deepseek: 切分类和长函数到同名文件中

import re
import sys
import os
import ast
from tokenize import generate_tokens, COMMENT, STRING, NL, INDENT, DEDENT
import io

def extract_entities(filename):
    """提取类和函数到单独文件"""
    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read()
    
    # 计算总行数
    total_lines = content.count('\n') + 1
    long_function_threshold = max(80, total_lines // 12)
    
    # 初始化结果
    entities = []
    current_pos = 0
    
    # 尝试使用AST解析获取更准确的位置信息
    try:
        tree = ast.parse(content)
        ast_entities = []
        
        for node in tree.body:
            if isinstance(node, ast.ClassDef):
                start_line = node.lineno
                end_line = node.end_lineno
                ast_entities.append(('class', node.name, start_line, end_line))
            
            elif isinstance(node, ast.FunctionDef):
                start_line = node.lineno
                end_line = node.end_lineno
                ast_entities.append(('function', node.name, start_line, end_line))
        
        # 使用AST信息辅助提取
        lines = content.split('\n')
        for e_type, name, start_line, end_line in ast_entities:
            start_index = sum(len(line) + 1 for line in lines[:start_line-1])
            end_index = sum(len(line) + 1 for line in lines[:end_line]) - 1
            
            # 对于函数,检查是否为长函数
            if e_type == 'function':
                func_lines = end_line - start_line + 1
                if func_lines > long_function_threshold:
                    func_content = content[start_index:end_index]
                    entities.append(('function', name, start_index, end_index, func_content))
            
            # 类总是提取
            elif e_type == 'class':
                class_content = content[start_index:end_index]
                entities.append(('class', name, start_index, end_index, class_content))
    
    except SyntaxError:
        # 如果AST解析失败,使用基于缩进的方法
        return extract_entities_with_indent(content, total_lines, long_function_threshold)
    
    return entities, content, total_lines

def extract_entities_with_indent(content, total_lines, long_function_threshold):
    """当AST解析失败时使用基于缩进的方法"""
    entities = []
    current_pos = 0
    
    # 类定义正则(支持装饰器和继承)
    class_pattern = re.compile(
        r'^([ \t]*)@?.*?\b(class|struct)\s+(\w+)\s*[\(:]?[^{:]*[:{]?', 
        re.MULTILINE
    )
    
    # 函数定义正则(支持装饰器和类型注解)
    func_pattern = re.compile(
        r'^([ \t]*)@?.*?\b(def)\s+(\w+)\s*\([^{:]*\)\s*[^{:]*[:{]?', 
        re.MULTILINE
    )
    
    # 提取类
    for match in class_pattern.finditer(content):
        indent = match.group(1)
        class_name = match.group(3)
        start_index = match.start()
        
        # 找到类体的结束位置
        end_index = find_block_end(content, start_index, indent)
        if end_index == -1:
            continue
            
        class_content = content[start_index:end_index]
        entities.append(('class', class_name, start_index, end_index, class_content))
        current_pos = end_index
    
    # 提取函数(只提取顶级函数,忽略类内方法)
    for match in func_pattern.finditer(content):
        func_name = match.group(3)
        start_index = match.start()
        indent = match.group(1)
        
        # 跳过类内的方法
        if any(start_index > c_start and start_index < c_end 
               for (t, _, c_start, c_end, _) in entities if t == 'class'):
            continue
            
        # 找到函数体的结束位置
        end_index = find_block_end(content, start_index, indent)
        if end_index == -1:
            continue
            
        func_content = content[start_index:end_index]
        func_lines = func_content.count('\n') + 1
        
        # 检查是否为长函数
        if func_lines > long_function_threshold:
            entities.append(('function', func_name, start_index, end_index, func_content))
            current_pos = end_index
    
    # 按起始位置排序
    entities.sort(key=lambda x: x[2])
    
    return entities, content, total_lines

def find_block_end(content, start_index, base_indent):
    """找到代码块的结束位置"""
    base_indent_level = len(base_indent) if base_indent else 0
    current_index = start_index
    stack = []
    in_string = False
    string_char = None
    
    while current_index < len(content):
        char = content[current_index]
        
        # 处理字符串字面量
        if not in_string and char in ('"', "'"):
            in_string = True
            string_char = char
        elif in_string and char == string_char:
            # 检查是否是转义的引号
            if content[current_index-1] == '\\':
                # 检查转义字符本身是否被转义
                backslash_count = 0
                i = current_index - 1
                while i >= 0 and content[i] == '\\':
                    backslash_count += 1
                    i -= 1
                if backslash_count % 2 == 0:  # 偶数个反斜杠,引号未被转义
                    in_string = False
                    string_char = None
            else:
                in_string = False
                string_char = None
        
        if in_string:
            current_index += 1
            continue
        
        # 处理括号
        if char in '([{':
            stack.append(char)
        elif char in ')]}':
            if not stack:
                return -1  # 不匹配的括号
            last_open = stack.pop()
            if (last_open == '(' and char != ')') or \
               (last_open == '[' and char != ']') or \
               (last_open == '{' and char != '}'):
                return -1  # 括号不匹配
        
        # 检查代码块结束
        if char == '\n':
            next_line_start = current_index + 1
            if next_line_start >= len(content):
                return next_line_start  # 文件结束
            
            # 检查下一行的缩进级别
            next_line_end = content.find('\n', next_line_start)
            if next_line_end == -1:
                next_line_end = len(content)
                
            next_line = content[next_line_start:next_line_end]
            indent_level = len(next_line) - len(next_line.lstrip())
            
            # 如果缩进小于基础缩进且没有未闭合的括号,则块结束
            if indent_level <= base_indent_level and not stack:
                # 确保不是空行或注释
                stripped_line = next_line.strip()
                if stripped_line and not stripped_line.startswith('#'):
                    return next_line_start
        
        current_index += 1
    
    return len(content)  # 到达文件末尾

def write_entities(entities, content, filename):
    """将实体写入文件并生成剩余内容"""
    # 创建输出目录
    base_name = os.path.splitext(os.path.basename(filename))[0]
    output_dir = f"{base_name}_split"
    os.makedirs(output_dir, exist_ok=True)
    
    # 提取覆盖范围
    covered_ranges = []
    for entity in entities:
        e_type, e_name, start, end, e_content = entity
        
        # 确保每个实体都写入单独的文件
        output_path = os.path.join(output_dir, f"{e_name}.py")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(e_content)
        
        covered_ranges.append((start, end))
    
    # 生成剩余内容
    covered_ranges.sort(key=lambda x: x[0])
    remaining_parts = []
    last_pos = 0
    
    for start, end in covered_ranges:
        remaining_parts.append(content[last_pos:start])
        last_pos = end
    
    remaining_parts.append(content[last_pos:])
    remaining_content = ''.join(remaining_parts)
    
    # 写入剩余文件
    left_path = os.path.join(output_dir, "left.py")
    with open(left_path, 'w', encoding='utf-8') as f:
        f.write(remaining_content)
    
    return output_dir

def main():
    if len(sys.argv) != 2:
        print("Usage: python split_code.py <source_file.py>")
        sys.exit(1)
    
    filename = sys.argv[1]
    if not os.path.isfile(filename):
        print(f"Error: File not found - {filename}")
        sys.exit(1)
    
    try:
        entities, content, total_lines = extract_entities(filename)
        output_dir = write_entities(entities, content, filename)
        
        class_count = sum(1 for e in entities if e[0] == 'class')
        func_count = sum(1 for e in entities if e[0] == 'function')
        
        print(f"Processed {filename}:")
        print(f"- Total lines: {total_lines}")
        print(f"- Long function threshold: {max(80, total_lines//12)} lines")
        print(f"- Extracted: {class_count} classes, {func_count} long functions")
        print(f"- Output directory: {output_dir}")
        print(f"- Remaining code in: {os.path.join(output_dir, 'left.py')}")
        
    except Exception as e:
        print(f"Error processing file: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

if __name__ == "__main__":
    main()

需求: 

写一个python脚本,阅读python源代码,i, 把源文件中各个class 切分到文件中,文件名是类名,ii,长函数(代码行超过本文件长度1/12, 或者超过80行,两个条件满足一个即可)切分到文件中,文件名是函数名。 iii,剩余部分放在left.py文件中,谢谢

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值