get_async_keys.cpp

 
"""FastAPI backend for a character role-play chat served by a local causal LM.

The app renders the frontend page, looks up a character in the project
database, builds a system prompt from the character's traits and generates a
reply with a locally loaded Hugging Face model.
"""
import asyncio
import json
import logging
import traceback
from pathlib import Path

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from jinja2 import Environment, FileSystemLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

# Project-local database helpers.
import database

logger = logging.getLogger(__name__)

BASE_DIR = Path(__file__).resolve().parent.parent  # project root directory
FRONTEND_DIR = BASE_DIR / "frontend"

app = FastAPI()
# Build paths with pathlib so they work on every OS (the original used
# backslash-separated strings, which only resolve on Windows).
load_dotenv(dotenv_path=BASE_DIR / "backend" / "MyAI.env")
templates = Jinja2Templates(directory=str(FRONTEND_DIR))

# Serve static assets from the frontend directory.
app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")

# Jinja2 environment used to render the HTML page.
template_env = Environment(loader=FileSystemLoader(str(FRONTEND_DIR)))

# Module-level singletons, populated once by the startup handler.
model = None
tokenizer = None

# Marker that may trail a generated reply; everything after it is discarded.
_EOT_MARKER = "<|EOT|>"


def load_model():
    """Load the model and tokenizer synchronously.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is already in eval mode.
    """
    model_name = str(BASE_DIR / "model" / "DeepSeek-V3")

    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(model_name)

    print("Loading model...")
    m = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True,
    ).eval()
    return m, tok


async def start_load():
    """Load the model at application startup without blocking the event loop."""
    global model, tokenizer
    loop = asyncio.get_running_loop()
    # The load is synchronous and slow, so run it in the default thread pool.
    model, tokenizer = await loop.run_in_executor(None, load_model)
    print("✅ Model loaded during startup!")


def shutdown_event():
    """Release the model, tokenizer and CUDA cache when the app shuts down."""
    global model, tokenizer
    # Rebind to None instead of `del`: deleting the global names would make
    # any later `model is None` check raise NameError.
    model = None
    tokenizer = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("👋 Cleaned up model and CUDA cache on shutdown.")


app.add_event_handler("startup", start_load)
app.add_event_handler("shutdown", shutdown_event)


@app.get("/", response_class=HTMLResponse)
async def home():
    """Render the frontend page with the list of available characters."""
    template = template_env.get_template("myapp.html")
    characters = database.get_all_characters()
    characters_json = json.dumps(
        [{"id": c["id"], "name": c["name"], "trait": c["trait"]} for c in characters],
        ensure_ascii=False,
    )
    content = template.render(characters=characters, characters_json=characters_json)
    return HTMLResponse(content=content)


def _fetch_character(character_id):
    """Return the ``name``/``trait`` row for *character_id*, or None if absent."""
    conn = database.get_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT name, trait FROM characters WHERE id = %s", (character_id,))
        return cursor.fetchone()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()


def _generate_reply(lm, tok, system_prompt, user_message):
    """Run one blocking generation pass and return the assistant reply text."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    # add_generation_prompt tells the model to start the assistant turn.
    input_text = tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The chat template already contains the special tokens it needs, so do
    # not let the tokenizer add them a second time.
    inputs = tok(input_text, return_tensors="pt", add_special_tokens=False).to(lm.device)

    with torch.no_grad():
        outputs = lm.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.85,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.1,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.eos_token_id,  # avoids padding errors on decoder-only models
        )

    # Keep only the newly generated tokens. Slicing token ids is reliable,
    # whereas slicing the decoded string by len(input_text) misaligns whenever
    # decoding does not reproduce the prompt text exactly.
    prompt_length = inputs["input_ids"].shape[-1]
    reply = tok.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    # Defensive cleanup in case the end-of-turn marker is emitted as plain text.
    if _EOT_MARKER in reply:
        reply = reply.split(_EOT_MARKER)[0].strip()
    return reply


@app.post("/chat")
async def deepseek_chat(request: Request):
    """Answer a user message in the voice of the selected character.

    Expects a JSON body with ``character_id`` and ``message``. Returns
    ``{"reply": ...}`` on success, or ``{"error": ...}`` with status 400
    (missing fields), 404 (unknown character) or 500 (model not loaded or
    inference failure).
    """
    data = await request.json()
    character_id = data.get("character_id")
    user_message = data.get("message")
    if character_id is None or not user_message:
        return JSONResponse({"error": "缺少 character_id 或 message"}, status_code=400)

    # Look up the character definition.
    character = _fetch_character(character_id)
    if not character:
        return JSONResponse({"error": "角色不存在"}, status_code=404)

    # Build the system prompt from the character's name and traits.
    system_prompt = (
        f"你正在扮演 {character['name']}。\n"
        f"人物设定:{character['trait']}\n"
        "请始终以这个角色的身份、语气和思维方式回答问题。\n"
        "不要暴露你是 AI 的事实。"
    )

    try:
        # Snapshot the globals so a concurrent shutdown cannot swap them mid-request.
        lm, tok = model, tokenizer
        if lm is None or tok is None:
            return JSONResponse({"error": "模型尚未加载,请先启动模型"}, status_code=500)

        # Generation is CPU/GPU-bound and synchronous; run it in a worker
        # thread so other requests are not stalled while it runs.
        loop = asyncio.get_running_loop()
        reply = await loop.run_in_executor(
            None, _generate_reply, lm, tok, system_prompt, user_message
        )

        # Persist the exchange.
        database.save_conversation(character_id, user_message, reply)
        return JSONResponse({"reply": reply})
    except Exception as e:
        logger.exception("Inference failed for character_id=%s", character_id)
        # NOTE(review): "detail" exposes the full traceback to the client. It is
        # kept for compatibility with the existing frontend; consider removing
        # it before exposing this service beyond localhost.
        error_msg = traceback.format_exc()
        return JSONResponse(
            {"error": f"推理失败: {str(e)}", "detail": error_msg},
            status_code=500,
        )


# Non-llama.cpp implementation.
if __name__ == "__main__":
    uvicorn.run("myapp:app", host="127.0.0.1", port=8000, reload=True)

# Request from the original post: replace the local model loading with remote
# API requests (original text: 将本地模型加载换为远程api请求).
最新发布
09-26
import os import sys import json import argparse import re import time import threading import traceback from collections import defaultdict from typing import Dict, List, Any, Optional, Tuple, Set, Union import requests from colorama import init, Fore, Style, Back # 静态分析相关库 import ast import importlib.util import networkx as nx from glob import glob # 初始化colorama以支持彩色输出 init() class CodeStructureAnalyzer: """代码结构分析器,使用静态分析构建函数调用图和依赖关系""" def __init__(self, repo_path: str, use_ctags: bool = False, verbose: bool = False): self.repo_path = repo_path self.call_graph = nx.DiGraph() # 使用有向图存储调用关系 self.function_info = {} # 函数/方法详细信息 self.class_info = {} # 类信息 self.module_info = {} # 模块信息 self.import_map = {} # 导入映射 self.function_calls = {} # 函数调用关系 {caller: [callee1, callee2, ...]} self.use_ctags = use_ctags # 是否使用CTags进行C/C++代码分析 self.verbose = verbose # 是否输出详细信息 self.supported_extensions = { '.py': self._analyze_python_file, '.js': self._analyze_js_file, '.ts': self._analyze_ts_file, '.tsx': self._analyze_ts_file, '.jsx': self._analyze_js_file, '.c': self._analyze_c_file, '.h': self._analyze_c_file, '.cpp': self._analyze_c_file, '.cc': self._analyze_c_file, '.hpp': self._analyze_c_file } def _analyze_c_file(self, file_path: str): """分析C/C++语言文件,提取函数、结构体和包含关系,可选择使用CTags进行更精确的解析""" # 确保文件存在且可读 if not os.path.isfile(file_path) or not os.access(file_path, os.R_OK): raise FileNotFoundError(f"找不到文件或无法访问:{file_path}") rel_path = os.path.relpath(file_path, self.repo_path) module_name = os.path.splitext(rel_path.replace(os.path.sep, '.'))[0] # 初始化模块信息 self.module_info[module_name] = { 'path': rel_path, 'functions': [], 'structures': [], 'includes': [] } # 判断文件类型是C还是C++ ext = os.path.splitext(file_path)[1].lower() is_cpp = ext in ['.cpp', '.cc'] # 如果启用了CTags并且可用,则使用CTags进行解析 if self.use_ctags and self._is_ctags_available(): try: return self._analyze_with_ctags(file_path, module_name, is_cpp) except Exception as e: if self.verbose: print(Fore.YELLOW + f"使用CTags分析失败,回退到正则表达式分析: 
{str(e)}" + Style.RESET_ALL) # 如果CTags分析失败,回退到正则表达式分析 # 使用正则表达式进行分析(原始方法) try: with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: code = f.read() # 提取#include语句 include_pattern = r'#include\s+[<"]([^>"]+)[>"]' for match in re.finditer(include_pattern, code): header = match.group(1) self.module_info[module_name]['includes'].append(header) # 记录导入映射 self.import_map[f"{module_name}.{os.path.splitext(header)[0]}"] = header # 提取函数定义 - 匹配基本C函数定义模式 # 1. 返回类型 + 函数名 + 参数列表 # 这是一个简化的正则表达式,真实C代码会有更复杂的情况 function_pattern = r'(\w+(?:\s+\w+)*)\s+(\w+)\s*\((.*?)\)\s*(?:;|{)' for match in re.finditer(function_pattern, code): return_type = match.group(1).strip() func_name = match.group(2).strip() params = match.group(3).strip() # 过滤掉可能误匹配的宏或关键字 if func_name in ['if', 'for', 'while', 'switch', 'return']: continue # 过滤掉可能误匹配的类型定义或结构体 if return_type in ['typedef', 'struct', 'enum', 'union']: continue # 检查函数是否有函数体,即真正的定义而非声明 is_definition = False pos = match.end() if pos < len(code) and code[pos - 1] == '{': is_definition = True # 也可以尝试寻找匹配的'}'来确定函数体范围 qualified_name = f"{module_name}.{func_name}" self.function_info[qualified_name] = { 'name': func_name, 'qualified_name': qualified_name, 'module': module_name, 'path': rel_path, 'return_type': return_type, 'parameters': params, 'is_definition': is_definition, 'is_method': False, 'calls': [] } self.module_info[module_name]['functions'].append(func_name) # 提取结构体定义 struct_pattern = r'struct\s+(\w+)\s*{([^}]*)}' for match in re.finditer(struct_pattern, code): struct_name = match.group(1) # struct_content = match.group(2) qualified_name = f"{module_name}.{struct_name}" self.class_info[qualified_name] = { 'name': struct_name, 'module': module_name, 'path': rel_path, 'methods': [], # C语言结构体没有方法 'bases': [], # C语言结构体没有继承 'type': 'struct' } self.module_info[module_name].setdefault('structures', []).append(struct_name) # 提取函数调用关系 - 尝试识别哪些函数调用了其他函数 for func_name, func_info in self.function_info.items(): if func_info['module'] == 
module_name and func_info.get('is_definition', False): # 简单匹配:函数名后跟着括号 for other_func in self.module_info[module_name]['functions']: if other_func != func_info['name']: # 匹配模式:函数名后跟着括号,但前面不是字母或数字 call_pattern = r'[^\w]' + re.escape(other_func) + r'\s*\(' if re.search(call_pattern, code): func_info['calls'].append(other_func) except Exception as e: raise Exception(f"C/C++文件分析错误: {str(e)}") return def _analyze_with_ctags(self, file_path: str, module_name: str, is_cpp: bool = False): """使用CTags分析C/C++文件结构,提供更精确的函数和结构体提取""" import subprocess import tempfile rel_path = os.path.relpath(file_path, self.repo_path) if self.verbose: file_type = "C++" if is_cpp else "C" print(f"使用CTags分析{file_type}文件: {rel_path}") # 创建临时文件存储CTags输出 with tempfile.NamedTemporaryFile(delete=False) as temp_file: temp_path = temp_file.name try: # 设置CTags命令和参数 language_opt = "--language-force=C++" if is_cpp else "--language-force=C" cmd = [ "ctags", "--fields=+n+s+l+t+K", # 包含行号、命名空间、语言、类型信息等额外字段 "--kinds-c++=+p+l+c+f+m+s+v+d" if is_cpp else "--kinds-c=+p+l+c+f+s+v+d", # 包含各种类型的标签 "--output-format=json", # 使用JSON输出格式,更容易解析 language_opt, "-o", temp_path, file_path ] if self.verbose: print(f"执行CTags命令: {' '.join(cmd)}") # 运行CTags命令 start_time = time.time() result = subprocess.run(cmd, capture_output=True, text=True) if self.verbose: ctags_time = time.time() - start_time print(f"CTags命令执行耗时: {ctags_time:.3f}秒") if result.returncode != 0: if self.verbose: print(Fore.RED + f"CTags执行失败: {result.stderr}" + Style.RESET_ALL) raise Exception(f"CTags执行失败: {result.stderr}") # 读取CTags输出 with open(temp_path, 'r') as f: tags_output = f.readlines() if self.verbose: print(f"CTags生成了 {len(tags_output)} 个标签条目") # 读取原始文件,用于分析包含关系和提取代码上下文 with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: code = f.read() # 提取#include语句 include_pattern = r'#include\s+[<"]([^>"]+)[>"]' includes = [] for match in re.finditer(include_pattern, code): header = match.group(1) includes.append(header) 
self.module_info[module_name]['includes'].append(header) self.import_map[f"{module_name}.{os.path.splitext(header)[0]}"] = header if self.verbose and includes: print(f"提取了 {len(includes)} 个包含语句: {', '.join(includes[:5])}" + ("..." if len(includes) > 5 else "")) # 解析CTags输出 functions_count = 0 structs_count = 0 classes_count = 0 methods_count = 0 for line in tags_output: try: if not line.strip(): continue tag = json.loads(line) tag_name = tag.get('name', '') tag_kind = tag.get('kind', '') tag_line = int(tag.get('line', 0)) # 处理函数和方法 if tag_kind in ['function', 'method']: # 构建完整的函数名 func_name = tag_name qualified_name = f"{module_name}.{func_name}" # 从文件中提取函数定义所在行及其上下文 line_content = "" if tag_line > 0: # 读取函数定义所在行及其前一行 lines = code.split('\n') if tag_line <= len(lines): line_content = lines[tag_line - 1].strip() # 提取参数和返回类型 params = tag.get('signature', '').strip('()') return_type = tag.get('typeref', '').replace('typename:', '') if not return_type and 'scope' in tag: return_type = tag['scope'] # 记录函数信息 self.function_info[qualified_name] = { 'name': func_name, 'qualified_name': qualified_name, 'module': module_name, 'path': rel_path, 'return_type': return_type, 'parameters': params, 'is_definition': True, # CTags通常只会为定义生成标签 'is_method': tag_kind == 'method', 'line': tag_line, 'calls': [] } # 更新计数器 if tag_kind == 'function': functions_count += 1 else: methods_count += 1 # 更新模块函数列表 if func_name not in self.module_info[module_name]['functions']: self.module_info[module_name]['functions'].append(func_name) # 处理结构体、类、枚举等 elif tag_kind in ['struct', 'class', 'enum', 'union']: struct_name = tag_name qualified_name = f"{module_name}.{struct_name}" # 记录结构体/类信息 self.class_info[qualified_name] = { 'name': struct_name, 'module': module_name, 'path': rel_path, 'methods': [], 'bases': [], 'line': tag_line, 'type': tag_kind } # 如果是C++类,可以尝试提取基类 if is_cpp and tag_kind == 'class' and 'inherits' in tag: self.class_info[qualified_name]['bases'] = [base.strip() for base in 
tag['inherits'].split(',')] if struct_name not in self.module_info[module_name]['structures']: self.module_info[module_name]['structures'].append(struct_name) # 更新计数器 if tag_kind == 'struct': structs_count += 1 elif tag_kind == 'class': classes_count += 1 # 如果是类成员方法,添加到对应类的methods列表中 if tag_kind == 'method' and 'class' in tag: class_name = tag['class'] class_qualified_name = f"{module_name}.{class_name}" if class_qualified_name in self.class_info: self.class_info[class_qualified_name]['methods'].append(tag_name) except json.JSONDecodeError: # 处理非JSON行,可能是旧版本CTags输出格式 if self.verbose: print(Fore.YELLOW + f"警告: 无法解析CTags输出行: {line[:50]}..." + Style.RESET_ALL) continue except Exception as e: if self.verbose: print(Fore.YELLOW + f"解析标签时出错: {str(e)}, 标签行: {line[:50]}..." + Style.RESET_ALL) if self.verbose: print(f"从CTags提取了 {functions_count} 个函数, {methods_count} 个方法, " + f"{structs_count} 个结构体, {classes_count} 个类") # 提取函数调用关系(暂时使用简单的正则表达式方法) calls_count = 0 for func_name, func_info in self.function_info.items(): if func_info['module'] == module_name: # 为每个函数查找可能的调用 func_calls = [] for other_func in self.module_info[module_name]['functions']: if other_func != func_info['name']: # 匹配模式:函数名后跟着括号,但前面不是字母或数字 call_pattern = r'[^\w]' + re.escape(other_func) + r'\s*\(' if re.search(call_pattern, code): func_info['calls'].append(other_func) func_calls.append(other_func) calls_count += 1 # 更新函数调用关系字典 other_func_qualified = f"{module_name}.{other_func}" if func_name not in self.function_calls: self.function_calls[func_name] = [] if other_func_qualified not in self.function_calls[func_name]: self.function_calls[func_name].append(other_func_qualified) if self.verbose and func_calls and len(func_calls) <= 5: print(f"函数 {func_info['name']} 调用了: {', '.join(func_calls)}") if self.verbose: print(f"共识别出 {calls_count} 个函数调用关系") except Exception as e: if self.verbose: print(Fore.RED + f"使用CTags分析失败: {str(e)}" + Style.RESET_ALL) raise finally: # 清理临时文件 if os.path.exists(temp_path): 
os.unlink(temp_path) if self.verbose: print("已清理CTags临时文件") def _is_ctags_available(self): """检查系统中是否安装了CTags""" try: import subprocess result = subprocess.run(["ctags", "--version"], capture_output=True, text=True) return result.returncode == 0 except Exception: return False def analyze_codebase(self, file_list: Optional[List[str]] = None, verbose: bool = False): """分析代码库,构建调用图和依赖关系""" analyzed_files = 0 analyzed_by_ctags = 0 if verbose: print(Fore.GREEN + "开始静态代码分析..." + Style.RESET_ALL) if self.use_ctags: # 检查CTags是否可用 ctags_available = self._is_ctags_available() if ctags_available: print(Fore.GREEN + "检测到CTags可用,将用于增强C/C++代码分析" + Style.RESET_ALL) else: print(Fore.YELLOW + "警告: CTags未安装或不可用,将使用基础正则表达式分析C/C++代码" + Style.RESET_ALL) # 如果没有提供文件列表,则自动发现支持的文件 if not file_list: if verbose: print("自动发现支持的文件类型...") file_list = [] for ext in self.supported_extensions: pattern = os.path.join(self.repo_path, f"**/*{ext}") found_files = glob(pattern, recursive=True) # 验证找到的文件是否真的存在并可访问 for file_path in found_files: if os.path.isfile(file_path) and os.access(file_path, os.R_OK): file_list.append(file_path) if verbose and found_files: print(f" 发现 {len(found_files)} 个 {ext} 文件") # 过滤掉node_modules, venv等目录的文件 ignored_dirs = ['node_modules', 'venv', 'env', '__pycache__', '.git', 'dist', 'build'] original_count = len(file_list) file_list = [f for f in file_list if not any(d in f for d in ignored_dirs)] if verbose and original_count != len(file_list): print(f"过滤忽略目录后,文件数从 {original_count} 减少到 {len(file_list)}") if verbose: print(f"找到 {len(file_list)} 个可分析的文件") # 统计不同类型的文件数量 ext_counts = defaultdict(int) for f in file_list: ext = os.path.splitext(f)[1].lower() ext_counts[ext] += 1 print("文件类型分布:") for ext, count in sorted(ext_counts.items(), key=lambda x: x[1], reverse=True): print(f" {ext}: {count} 文件") # 创建进度条 total_files = len(file_list) progress_step = max(1, total_files // 20) # 每处理5%的文件更新一次进度 start_time = time.time() # 先分析所有文件,收集基本信息 for i, file_path in enumerate(file_list): 
if verbose and i % progress_step == 0: elapsed = time.time() - start_time progress = i / total_files if progress > 0: eta = elapsed / progress * (1 - progress) print(f"已分析 {i}/{total_files} 文件... ({progress:.1%} 完成, ETA: {eta:.1f}秒)") # 再次检查文件是否存在且可访问 if not os.path.isfile(file_path) or not os.access(file_path, os.R_OK): if verbose: print(Fore.YELLOW + f"跳过不可访问的文件: {file_path}" + Style.RESET_ALL) continue ext = os.path.splitext(file_path)[1].lower() if ext in self.supported_extensions: try: # 记录分析开始时间 if verbose and ( ext == '.c' or ext == '.cpp' or ext == '.cc' or ext == '.h' or ext == '.hpp') and self.use_ctags: file_start_time = time.time() # 分析文件 self.supported_extensions[ext](file_path) analyzed_files += 1 # 如果是C/C++文件并使用了CTags,记录这一信息 if ( ext == '.c' or ext == '.cpp' or ext == '.cc' or ext == '.h' or ext == '.hpp') and self.use_ctags and self._is_ctags_available(): analyzed_by_ctags += 1 if verbose: file_time = time.time() - file_start_time rel_path = os.path.relpath(file_path, self.repo_path) print( Fore.BLUE + f"使用CTags分析文件 {rel_path}, 耗时: {file_time:.3f}秒" + Style.RESET_ALL) except Exception as e: if verbose: print(Fore.YELLOW + f"分析文件 {file_path} 时出错: {str(e)}" + Style.RESET_ALL) # 构建全局调用图 if verbose: print(Fore.BLUE + "正在构建全局调用图..." 
+ Style.RESET_ALL) build_start_time = time.time() self._build_global_call_graph() if verbose: build_time = time.time() - build_start_time print(Fore.BLUE + f"全局调用图构建完成,耗时: {build_time:.2f}秒" + Style.RESET_ALL) total_time = time.time() - start_time if verbose: print( Fore.GREEN + f"静态分析完成,共分析了 {analyzed_files} 个文件,总耗时: {total_time:.2f}秒" + Style.RESET_ALL) if self.use_ctags: print(Fore.GREEN + f"其中 {analyzed_by_ctags} 个C/C++文件使用CTags进行了增强分析" + Style.RESET_ALL) print(f"发现 {len(self.function_info)} 个函数/方法, {len(self.class_info)} 个类") print(f"调用图包含 {self.call_graph.number_of_nodes()} 个节点和 {self.call_graph.number_of_edges()} 条边") def _analyze_python_file(self, file_path: str): """分析Python文件,提取函数、类和调用关系""" # 确保文件存在且可读 if not os.path.isfile(file_path) or not os.access(file_path, os.R_OK): raise FileNotFoundError(f"找不到文件或无法访问:{file_path}") rel_path = os.path.relpath(file_path, self.repo_path) module_name = os.path.splitext(rel_path.replace(os.path.sep, '.'))[0] # 读取并解析文件 try: with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: code = f.read() # 使用ast解析代码 tree = ast.parse(code, filename=file_path) # 为所有节点添加父节点引用 for node in ast.walk(tree): for child in ast.iter_child_nodes(node): setattr(child, 'parent', node) # 记录模块信息 self.module_info[module_name] = { 'path': rel_path, 'functions': [], 'classes': [], 'imports': [] } # 第一遍:收集所有类和函数定义 for node in ast.walk(tree): # 为节点添加父节点信息 for child in ast.iter_child_nodes(node): child.parent = node # 收集导入信息 if isinstance(node, ast.Import): for name in node.names: self.module_info[module_name]['imports'].append(name.name) elif isinstance(node, ast.ImportFrom) and node.module: for name in node.names: import_from = f"{node.module}.{name.name}" self.module_info[module_name]['imports'].append(import_from) # 记录导入映射,用于解析调用 if name.asname: self.import_map[f"{module_name}.{name.asname}"] = import_from else: self.import_map[f"{module_name}.{name.name}"] = import_from # 收集类定义 elif isinstance(node, ast.ClassDef): class_name = 
f"{module_name}.{node.name}" self.class_info[class_name] = { 'name': node.name, 'module': module_name, 'path': rel_path, 'methods': [], 'bases': [base.id for base in node.bases if isinstance(base, ast.Name)], 'lineno': node.lineno } self.module_info[module_name]['classes'].append(node.name) # 收集类的方法 for item in node.body: if isinstance(item, ast.FunctionDef): method_name = f"{class_name}.{item.name}" self.class_info[class_name]['methods'].append(item.name) # 记录方法信息 self.function_info[method_name] = { 'name': item.name, 'qualified_name': method_name, 'class': class_name, 'module': module_name, 'path': rel_path, 'lineno': item.lineno, 'is_method': True, 'calls': [] } # 收集顶层函数定义 elif isinstance(node, ast.FunctionDef): # 检查是否为顶层函数(父节点是模块) parent_is_module = False try: parent_is_module = isinstance(node.parent, ast.Module) except AttributeError: # 如果还没有添加parent属性,尝试判断是否为顶层 parent_is_module = isinstance(ast.iter_child_nodes(tree).__next__(), ast.FunctionDef) if parent_is_module: func_name = f"{module_name}.{node.name}" self.function_info[func_name] = { 'name': node.name, 'qualified_name': func_name, 'class': None, 'module': module_name, 'path': rel_path, 'lineno': node.lineno, 'is_method': False, 'calls': [] } self.module_info[module_name]['functions'].append(node.name) # 第二遍:收集调用关系 class CallVisitor(ast.NodeVisitor): def __init__(self, analyzer, current_scope): self.analyzer = analyzer self.current_scope = current_scope self.calls = [] def visit_Call(self, node): # 记录函数调用 if isinstance(node.func, ast.Name): # 直接调用:func() call_name = node.func.id self.calls.append(call_name) elif isinstance(node.func, ast.Attribute): # 属性调用:obj.method() if isinstance(node.func.value, ast.Name): # 简单的obj.method()形式 obj_name = node.func.value.id method_name = node.func.attr call_name = f"{obj_name}.{method_name}" self.calls.append(call_name) # 继续访问子节点 self.generic_visit(node) # 为每个函数/方法收集调用 for func_def in [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)]: # 确定当前作用域 if 
hasattr(func_def, 'parent') and isinstance(func_def.parent, ast.ClassDef): current_scope = f"{module_name}.{func_def.parent.name}.{func_def.name}" else: current_scope = f"{module_name}.{func_def.name}" # 如果函数在我们的记录中 if current_scope in self.function_info: visitor = CallVisitor(self, current_scope) visitor.visit(func_def) self.function_info[current_scope]['calls'] = visitor.calls except Exception as e: raise Exception(f"Python文件分析错误: {str(e)}") def _analyze_js_file(self, file_path: str): """分析JavaScript文件,提取函数、类和调用关系 注意:这是一个简化版实现,使用正则表达式匹配而非完整的AST解析""" rel_path = os.path.relpath(file_path, self.repo_path) module_name = os.path.splitext(rel_path.replace(os.path.sep, '.'))[0] # 读取文件 try: with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: code = f.read() # 记录模块信息 self.module_info[module_name] = { 'path': rel_path, 'functions': [], 'classes': [], 'imports': [] } # 匹配导入语句 import_patterns = [ r'import\s+(\w+)\s+from\s+[\'"]([^\'"]+)[\'"]', # import X from 'Y' r'import\s+\{\s*([^}]+)\s*\}\s+from\s+[\'"]([^\'"]+)[\'"]', # import { X } from 'Y' r'const\s+(\w+)\s+=\s+require\([\'"]([^\'"]+)[\'"]\)' # const X = require('Y') ] for pattern in import_patterns: for match in re.finditer(pattern, code): if len(match.groups()) == 2: imported_name = match.group(1) module_source = match.group(2) self.module_info[module_name]['imports'].append(f"{imported_name} from {module_source}") # 匹配函数定义 function_patterns = [ r'function\s+(\w+)\s*\([^)]*\)\s*\{', # function x() { r'const\s+(\w+)\s*=\s*function\s*\([^)]*\)\s*\{', # const x = function() { r'const\s+(\w+)\s*=\s*\([^)]*\)\s*=>\s*\{', # const x = () => { r'let\s+(\w+)\s*=\s*function\s*\([^)]*\)\s*\{', # let x = function() { r'let\s+(\w+)\s*=\s*\([^)]*\)\s*=>\s*\{', # let x = () => { r'var\s+(\w+)\s*=\s*function\s*\([^)]*\)\s*\{' # var x = function() { ] # 提取函数定义 for pattern in function_patterns: for match in re.finditer(pattern, code): func_name = match.group(1) qualified_name = f"{module_name}.{func_name}" # 
记录函数信息(简化版,没有行号) self.function_info[qualified_name] = { 'name': func_name, 'qualified_name': qualified_name, 'class': None, 'module': module_name, 'path': rel_path, 'lineno': None, # JavaScript解析简化版没有行号 'is_method': False, 'calls': [] } self.module_info[module_name]['functions'].append(func_name) # 匹配类定义和方法 class_pattern = r'class\s+(\w+)(?:\s+extends\s+(\w+))?\s*\{' for class_match in re.finditer(class_pattern, code): class_name = class_match.group(1) qualified_class_name = f"{module_name}.{class_name}" base_class = class_match.group(2) self.class_info[qualified_class_name] = { 'name': class_name, 'module': module_name, 'path': rel_path, 'methods': [], 'bases': [base_class] if base_class else [], 'lineno': None } self.module_info[module_name]['classes'].append(class_name) # 搜索类体中的方法 # 这是一个简化的方法,不处理嵌套类和其他复杂情况 class_pos = class_match.end() class_text = code[class_pos:] bracket_level = 0 for i, char in enumerate(class_text): if char == '{': bracket_level += 1 elif char == '}': bracket_level -= 1 if bracket_level < 0: # 到达类定义结束 class_text = class_text[:i] break # 在类体内搜索方法 method_pattern = r'(?:async\s+)?(\w+)\s*\([^)]*\)\s*\{' for method_match in re.finditer(method_pattern, class_text): method_name = method_match.group(1) if method_name != 'constructor': # 排除构造函数 qualified_method_name = f"{qualified_class_name}.{method_name}" self.class_info[qualified_class_name]['methods'].append(method_name) # 记录方法信息 self.function_info[qualified_method_name] = { 'name': method_name, 'qualified_name': qualified_method_name, 'class': qualified_class_name, 'module': module_name, 'path': rel_path, 'lineno': None, # 简化版实现没有行号 'is_method': True, 'calls': [] } # 简单提取函数调用关系(这是一个非常简化的方法) # 更复杂的代码需要完整的AST解析 for func_name, func_info in self.function_info.items(): if func_info['module'] == module_name: # 搜索函数体内的调用 # 这是一个简化实现,无法处理复杂情况 for other_func_name, other_func_info in self.function_info.items(): if other_func_info['module'] == module_name and other_func_info['name'] != func_info['name']: 
if re.search(r'\b' + re.escape(other_func_info['name']) + r'\s*\(', code): func_info['calls'].append(other_func_info['name']) except Exception as e: raise Exception(f"JavaScript文件分析错误: {str(e)}") def _analyze_ts_file(self, file_path: str): """分析TypeScript文件,类似于JavaScript但支持更多TypeScript特性 注意:这是一个简化版实现""" # TypeScript分析基本上与JavaScript相同,但有额外的类型信息 # 为简单起见,我们复用JavaScript分析器,但在实际应用中应该使用TypeScript专用解析器 return self._analyze_js_file(file_path) def _build_global_call_graph(self): """基于收集的信息构建全局调用图""" # 添加所有函数作为节点 for func_name, func_info in self.function_info.items(): self.call_graph.add_node(func_name, **func_info) # 添加调用关系作为边 for caller_name, caller_info in self.function_info.items(): module_name = caller_info['module'] for called_name in caller_info['calls']: # 尝试解析被调用函数的完全限定名 candidates = [] # 1. 同一模块中的函数 local_name = f"{module_name}.{called_name}" if local_name in self.function_info: candidates.append(local_name) # 2. 类方法调用 (Python, JS) if caller_info.get('class'): class_method = f"{caller_info['class']}.{called_name}" if class_method in self.function_info: candidates.append(class_method) # 3. 导入的函数 imported_name = self.import_map.get(f"{module_name}.{called_name}") if imported_name and imported_name in self.function_info: candidates.append(imported_name) # 4. 
针对C文件: 检查其他C文件中的函数 if not candidates and (module_name.endswith('.c') or module_name.endswith('.h')): # 在所有其他C文件中查找该函数 for func, info in self.function_info.items(): if info['name'] == called_name and ( info['module'].endswith('.c') or info['module'].endswith('.h')): candidates.append(func) # 添加所有可能的调用关系 for candidate in candidates: self.call_graph.add_edge(caller_name, candidate) def get_function_info(self, function_name: str) -> Optional[Dict]: """获取函数的详细信息""" # 尝试直接查找 if function_name in self.function_info: return self.function_info[function_name] # 尝试模糊匹配 matches = [] for name, info in self.function_info.items(): if function_name in name or info['name'] == function_name: matches.append((name, info)) if len(matches) == 1: return matches[0][1] elif len(matches) > 1: # 返回所有模糊匹配的函数信息 return {name: info for name, info in matches} return None def get_class_info(self, class_name: str) -> Optional[Dict]: """获取类的详细信息""" # 尝试直接查找 if class_name in self.class_info: return self.class_info[class_name] # 尝试模糊匹配 matches = [] for name, info in self.class_info.items(): if class_name in name or info['name'] == class_name: matches.append((name, info)) if len(matches) == 1: return matches[0][1] elif len(matches) > 1: # 返回所有模糊匹配的类信息 return {name: info for name, info in matches} return None def get_callers(self, function_name: str) -> List[str]: """获取调用指定函数的所有函数""" # 找到准确的函数名 func_info = self.get_function_info(function_name) if not func_info: return [] # 如果是一个字典(多个匹配),选择第一个 if isinstance(func_info, dict): names = list(func_info.keys()) if not names: return [] actual_name = names[0] else: actual_name = func_info['qualified_name'] # 查找所有调用者 try: return list(self.call_graph.predecessors(actual_name)) except nx.NetworkXError: return [] def get_callees(self, function_name: str) -> List[str]: """获取指定函数调用的所有函数""" # 找到准确的函数名 func_info = self.get_function_info(function_name) if not func_info: return [] # 如果是一个字典(多个匹配),选择第一个 if isinstance(func_info, dict): names = list(func_info.keys()) if not 
names: return [] actual_name = names[0] else: actual_name = func_info['qualified_name'] # 查找所有被调用的函数 try: return list(self.call_graph.successors(actual_name)) except nx.NetworkXError: return [] def get_call_hierarchy(self, function_name: str, max_depth: int = 3) -> Dict: """获取函数的调用层次结构""" func_info = self.get_function_info(function_name) if not func_info: return {} # 如果是一个字典(多个匹配),选择第一个 if isinstance(func_info, dict): names = list(func_info.keys()) if not names: return {} function_name = names[0] func_info = func_info[function_name] else: function_name = func_info['qualified_name'] # 构建调用树 def build_tree(node, current_depth=0, visited=None): if visited is None: visited = set() if node in visited or current_depth >= max_depth: return {"name": node, "calls": []} visited.add(node) calls = [] try: for callee in self.call_graph.successors(node): calls.append(build_tree(callee, current_depth + 1, visited.copy())) except nx.NetworkXError: pass return { "name": node, "info": self.function_info.get(node, {}), "calls": calls } return build_tree(function_name) def find_shortest_path(self, source_func: str, target_func: str) -> List[str]: """找到从一个函数到另一个函数的最短调用路径""" # 解析函数名 source_info = self.get_function_info(source_func) target_info = self.get_function_info(target_func) if not source_info or not target_info: return [] # 处理模糊匹配情况 if isinstance(source_info, dict): source_func = list(source_info.keys())[0] else: source_func = source_info['qualified_name'] if isinstance(target_info, dict): target_func = list(target_info.keys())[0] else: target_func = target_info['qualified_name'] # 检查节点是否在图中 if source_func not in self.call_graph or target_func not in self.call_graph: return [] # 尝试找到最短路径 try: return nx.shortest_path(self.call_graph, source_func, target_func) except nx.NetworkXNoPath: return [] def get_common_callers(self, func1: str, func2: str) -> List[str]: """找到同时调用两个函数的函数""" callers1 = set(self.get_callers(func1)) callers2 = set(self.get_callers(func2)) return 
list(callers1.intersection(callers2)) def export_call_graph(self, output_file: str = None): """导出调用图为JSON格式""" data = { "nodes": [], "links": [] }
08-09
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值