<think>我们遇到了一个错误:`NameError: name 'threading' is not defined`。这是因为在代码中使用了`threading.Lock()`,但没有导入`threading`模块。我们需要在脚本开头导入`threading`模块。
此外,我们还可以进行一些其他优化:
1. 确保所有用到的模块都已导入。
2. 在并行处理中,我们使用了`concurrent.futures`,但同时也使用了`threading.Lock`,这是正确的,因为我们需要在多个线程中安全地更新共享数据结构(如`call_graph`和`class_methods`)。
3. 检查代码中是否还有其他未导入的模块。
修复步骤:
- 在文件顶部添加`import threading`。
同时,我们注意到在优化版本的代码中,我们使用了线程池和锁,但在原始代码中并没有导入`threading`模块,因此导致了上述错误。
另外,在优化版本中,我们使用了`concurrent.futures`,这已经导入了,但`threading`模块需要单独导入。
我们将在代码顶部添加`import threading`。
此外,我们还可以对代码进行一些微调:
- 将`MAX_WORKERS`的计算放在类外部,避免在类定义中执行可能引发异常的操作(虽然当前没有问题)。
- 确保在类初始化时正确初始化锁。
修改后的代码将如下所示(仅显示修改后的完整代码,并确保修复了`NameError`问题):
注意:由于原始代码较长,我们只显示修改后的完整脚本,并在关键修改处添加注释。</think>

### 修复线程安全锁未导入问题及完整优化脚本
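根据错误信息 `NameError: name 'threading' is not defined`,问题在于代码调用了 `threading.Lock()`,却没有导入 `threading` 模块。下面是修复后的完整优化脚本,主要修改包括:
1. **在文件顶部添加 `import threading`**
2. **在 `__init__` 中创建实例锁 `self.lock = threading.Lock()`**
3. **多线程更新 `call_graph` 与 `processed_files` 时使用 `with self.lock:` 加锁**
4. **以"文件路径::类名"作为 `class_methods` 的键,避免不同文件中的同名类互相覆盖**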
```python
import argparse
import os
import json
import time
import re
import traceback
import concurrent.futures
import logging
import threading # 修复:添加缺失的threading模块导入
import javalang
from javalang import tree
from collections import deque, defaultdict
from typing import Dict, List, Union, Tuple, Set
from colorama import Fore, init
from javalang.tree import ClassDeclaration
from JavaCodeExtract import extract_method_definition
from AutoVulReport import generate_markdown_report
# Initialise colorama BEFORE any logging handler exists. colorama.init() swaps
# sys.stdout/sys.stderr for wrapping streams, while logging.StreamHandler()
# captures sys.stderr at construction time; calling init() after basicConfig
# left the console handler holding the unwrapped stream, bypassing colorama.
init(autoreset=True)

# Logging setup: coloured records on the console, plain records in the log
# file (a shared coloured format would write ANSI escape codes into the file).
_console_handler = logging.StreamHandler()
_console_handler.setFormatter(logging.Formatter(
    f"{Fore.LIGHTBLUE_EX}[%(asctime)s] {Fore.LIGHTGREEN_EX}%(levelname)s: {Fore.RESET}%(message)s"
))
_file_handler = logging.FileHandler("java_sink_tracer.log", encoding="utf-8")
_file_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
# basicConfig only assigns its own formatter to handlers that have none, so the
# per-handler formatters above are kept.
logging.basicConfig(level=logging.INFO, handlers=[_console_handler, _file_handler])
logger = logging.getLogger(__name__)

# Constants
# Simple names of method annotations that mark a request-handling entry point
# (Spring-style *Mapping and JAX-RS-style Path/GET/POST/PUT/DELETE names).
MAPPING_ANNOTATIONS = {
    "GetMapping", "PostMapping", "RequestMapping", "PutMapping", "DeleteMapping",
    "Path", "GET", "POST", "PUT", "DELETE"
}
# File extensions treated as configuration files by the weak-password scan.
CONFIG_EXTENSIONS = {'.properties', '.yml', '.yaml', '.xml', '.conf', '.cfg'}
MAX_WORKERS = os.cpu_count() or 4  # thread-pool size: CPU count, falling back to 4
MIN_PASSWORD_LENGTH = 8  # minimum password length requirement
MAX_DEPTH = 10  # maximum call-chain depth explored by the back-tracer
class JavaSinkTracer:
def __init__(self, project_path: str, rules_path: str):
self.project_path = os.path.normpath(project_path)
self.rules = self._load_rules(rules_path)
self.call_graph: Dict[str, List[str]] = {}
self.class_methods: Dict[str, Dict[str, Union[str, Dict[str, Dict[str, bool]]]]] = {}
self.ast_cache: Dict[str, javalang.parser.tree.CompilationUnit] = {}
self.weak_password_rules = [r for r in self.rules["sink_rules"] if r["sink_name"] == "WEAK_PASSWORD"]
self.vul_rules = [r for r in self.rules["sink_rules"] if r["sink_name"] != "WEAK_PASSWORD"]
self.processed_files: Set[str] = set()
self.lock = threading.Lock() # 修复:正确初始化线程锁
@staticmethod
def _load_rules(path: str) -> dict:
    """Read the JSON vulnerability rule file at *path* and return it.

    Logs how many "sink_rules" entries were loaded. A missing file or
    malformed JSON is logged and the original exception is re-raised.
    """
    try:
        with open(path, "r", encoding="utf-8") as rule_file:
            loaded = json.load(rule_file)
    except (FileNotFoundError, json.JSONDecodeError) as exc:
        logger.error(f"规则文件加载失败: {str(exc)}")
        raise
    rule_count = len(loaded["sink_rules"])
    logger.info(f"成功加载规则: {path},包含 {rule_count} 条规则")
    return loaded
def _is_excluded(self, file_path: str) -> bool:
"""检查文件是否在排除路径中"""
rel_path = os.path.relpath(file_path, self.project_path)
return any(p in rel_path.split(os.sep) for p in self.rules["path_exclusions"])
def build_ast(self):
    """Parse every non-excluded .java file under the project root and build the call graph.

    Files are processed by self._process_file on a ThreadPoolExecutor with
    MAX_WORKERS workers. Excluded directories are pruned during the walk, so
    os.walk never descends into them; the set of files found is unchanged
    because every file below an excluded directory is excluded anyway.
    """
    logger.info(f"开始构建项目AST: {self.project_path}")
    java_files = []
    for root, dirs, files in os.walk(self.project_path):
        # In-place slice assignment is how os.walk (top-down) is told to skip subdirectories.
        dirs[:] = [d for d in dirs if not self._is_excluded(os.path.join(root, d))]
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".java") and not self._is_excluded(file_path):
                java_files.append(file_path)
    logger.info(f"发现 {len(java_files)} 个Java文件,使用 {MAX_WORKERS} 个线程处理")
    # NOTE(review): javalang is a pure-Python parser, so these threads likely
    # serialise on the GIL; a process pool would need the shared state reworked.
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(self._process_file, path): path for path in java_files}
        for future in concurrent.futures.as_completed(futures):
            file_path = futures[future]
            try:
                future.result()
            except Exception as e:
                logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
    logger.info(f"AST构建完成,处理了 {len(self.processed_files)} 个文件")
def _process_file(self, file_path: str):
    """Parse one Java file and feed its classes and invocations into the shared state.

    The AST is reused from self.ast_cache when present; the file is read from
    disk only on a cache miss. Files without any class or enum declaration
    are skipped with a warning. Every error is logged rather than raised, so
    one bad file never aborts the build. On success the normalised path is
    added to self.processed_files under self.lock.

    Args:
        file_path: path of the .java file (normalised with os.path.normpath).
    """
    normalized_path = os.path.normpath(file_path)
    if normalized_path in self.processed_files:
        return
    try:
        code_tree = self.ast_cache.get(normalized_path)
        if code_tree is None:
            try:
                # errors="replace" substitutes undecodable bytes, so reading
                # cannot raise UnicodeDecodeError and needs no handler for it.
                with open(normalized_path, "r", encoding="utf-8", errors="replace") as f:
                    content = f.read()
            except FileNotFoundError:
                logger.error(f"文件不存在: {normalized_path}")
                return
            code_tree = javalang.parse.parse(content)
            self.ast_cache[normalized_path] = code_tree
        # Only files that declare at least one class or enum are analysed.
        if not any(isinstance(node, (javalang.tree.ClassDeclaration, javalang.tree.EnumDeclaration))
                   for _, node in code_tree):
            logger.warning(f"跳过非类/枚举文件: {file_path}")
            return
        # Extract class info, then build this file's call-graph edges.
        self._extract_class_info(code_tree, normalized_path)
        self._build_call_graph(code_tree, normalized_path)
        with self.lock:  # processed_files is shared between worker threads
            self.processed_files.add(normalized_path)
    except javalang.parser.JavaSyntaxError as e:
        logger.error(f"语法错误: {normalized_path}: {str(e)}")
    except Exception as e:
        logger.error(f"处理文件 {normalized_path} 时发生异常: {str(e)}")
        logger.debug(traceback.format_exc())
def _extract_class_info(self, code_tree, file_path: str):
    """Record per-class method flags for every class declared in *code_tree*.

    For each class, self.class_methods["file_path::ClassName"] is set to
    {"file_path": file_path, "methods": {name: flags}}, where flags holds
    "requires_params" (the method declares parameters) and
    "has_mapping_annotation" (it carries one of MAPPING_ANNOTATIONS).

    Overloads share one entry per method name, so their flags are OR-ed.
    Without merging, the last overload declared would overwrite the others;
    e.g. foo() declared after foo(String) would hide the parameterised variant.
    """
    for _, node in code_tree.filter(ClassDeclaration):
        methods_info: Dict[str, Dict[str, bool]] = {}
        for method_node in node.methods:
            requires_params = bool(method_node.parameters)
            has_mapping_annotation = any(
                # Compare the simple name so fully-qualified annotations such
                # as @org.springframework...GetMapping are recognised too.
                annotation.name.lstrip("@").rsplit(".", 1)[-1] in MAPPING_ANNOTATIONS
                for annotation in (method_node.annotations or ())
            )
            flags = methods_info.setdefault(
                method_node.name,
                {"requires_params": False, "has_mapping_annotation": False},
            )
            flags["requires_params"] |= requires_params
            flags["has_mapping_annotation"] |= has_mapping_annotation
        # Key by file path too, so same-named classes in different files do not collide.
        full_class_name = f"{file_path}::{node.name}"
        with self.lock:  # consistent with the other shared-state updates
            self.class_methods[full_class_name] = {
                "file_path": file_path,
                "methods": methods_info
            }
def _build_call_graph(self, file_code_tree, file_path: str):
    """Append one call-graph edge per method invocation in *file_code_tree*.

    Each edge goes from the enclosing method's full name (as returned by
    _get_current_method_from_path) to the callee string produced by
    _resolve_callee. Invocations outside a recognised method are ignored.
    Callees beginning with "[!]" (unresolved) are logged as warnings but still
    recorded. An error on one invocation is logged and the loop continues.
    """
    symbols = self.get_variable_symbols(file_code_tree)
    invocations = file_code_tree.filter(javalang.tree.MethodInvocation)
    for ancestors, invocation in invocations:
        try:
            caller = self._get_current_method_from_path(ancestors, file_path)
            if caller != "unknown:unknown":
                callee = self._resolve_callee(invocation, ancestors, symbols, caller)
                if not callee.startswith('[!]'):
                    logger.debug(f"调用图: {caller} -> {callee}")
                else:
                    logger.warning(f"调用图解析失败: {caller} -> {callee}")
                # call_graph is shared between worker threads.
                with self.lock:
                    self.call_graph.setdefault(caller, []).append(callee)
        except Exception as exc:
            logger.error(f"构建调用图时出错: {str(exc)}")
            logger.debug(traceback.format_exc())
def _resolve_callee(self, node, path, variable_symbols, caller) -> str:
"""解析被调用方法"""
if node.qualifier:
parts = node.qualifier.split('.')
base_type = parts[0] if parts and parts[0][0].isupper() else node.qualifier
base_type = variable_symbols.get(node.qualifier, base_type).split('<')[0]
return f"{base_type}:{node.member}"
elif node.qualifier is None:
if self.is_string_literal_caller(path):
return f"String:{node.member}"
parent_node = path[-2] if len(path) > 1 else None
if isinstance(parent_node, (javalang.tree.ClassCreator, javalang.tree.ClassReference)):
return f"{parent_node.type.name}:{node.member}"
# 尝试从调用图中获取最近的类型
if caller in self.call_graph and self.call_graph[caller]:
last_call = self.call_graph[caller][-1]
if ':' in last_call:
return f"{last_call.split(':')[0]}:{node.member}"
return f"[!]解析失败:{node.member}"
@staticmethod
def is_string_literal_caller(path) -> bool:
    """Return True if any node in *path* is a javalang Literal whose value is a str."""
    for ancestor in path:
        if isinstance(ancestor, javalang.tree.Literal) and isinstance(ancestor.value, str):
            return True
    return False
@staticmethod
def get_variable_symbols(file_code_tree) -> Dict[str, str]:
    """Map every variable name declared in the file to its declared type name.

    The map is file-wide, not scoped, so later declarations of the same name
    overwrite earlier ones. Sources, in this order:
      1. variable declarations: locals plus for / for-each loop variables,
      2. field declarations,
      3. formal parameters (methods, constructors, explicitly typed lambdas).

    A fully-qualified declared type such as ``java.io.File f`` maps to the
    simple name ``File``. javalang represents it as a ReferenceType chain
    linked through ``sub_type``, whose head holds only the first package
    segment ("java").
    """
    def simple_type_name(type_node) -> str:
        # Non-type nodes keep the original str() fallback.
        if not hasattr(type_node, 'name'):
            return str(type_node)
        while getattr(type_node, "sub_type", None) is not None:
            type_node = type_node.sub_type
        return type_node.name

    variable_symbols: Dict[str, str] = {}
    # VariableDeclaration also matches its subclass LocalVariableDeclaration and
    # additionally covers the declarations javalang builds for loop headers.
    for _, node in file_code_tree.filter(javalang.tree.VariableDeclaration):
        var_type = simple_type_name(node.type)
        for declarator in node.declarators:
            variable_symbols[declarator.name] = var_type
    # Fields
    for _, node in file_code_tree.filter(javalang.tree.FieldDeclaration):
        var_type = simple_type_name(node.type)
        for declarator in node.declarators:
            variable_symbols[declarator.name] = var_type
    # Parameters, including constructor parameters the method-only scan missed.
    for _, param in file_code_tree.filter(javalang.tree.FormalParameter):
        variable_symbols[param.name] = simple_type_name(param.type)
    return variable_symbols
def _get_current_method_from_path(self, path, file_path: str) -> str:
    """Return "file_path::Class:method" for the innermost method enclosing *path*.

    Returns "unknown:unknown" when no MethodDeclaration in *path* has an
    enclosing class or enum (per find_parent_class). Constructor bodies and
    field initialisers are not inside a MethodDeclaration, so they also give
    "unknown:unknown".
    """
    for index in range(len(path) - 1, -1, -1):
        node = path[index]
        if isinstance(node, javalang.tree.MethodDeclaration):
            # Look for the owner only among the method's own ancestors: a local
            # class declared inside the method body (between the method and the
            # invocation) must not be reported as the method's class.
            class_node = self.find_parent_class(path[:index])
            if class_node:
                # Qualify with the file path to avoid class-name collisions.
                return f"{file_path}::{class_node.name}:{node.name}"
    return "unknown:unknown"
def find_taint_paths(self) -> List[dict]:
    """Run every configured check and return one finding dict per hit.

    Weak-password rules are checked against configuration files first. Every
    other rule's "sinks" entries ("pkg.Class:m1|m2") are expanded into
    "Class:method" sink points, each traced back through the call graph up to
    MAX_DEPTH levels; sink points with at least one chain become findings.
    """
    logger.info(f"开始审计项目: {self.project_path}")
    findings: List[dict] = []

    if self.weak_password_rules:
        logger.info("开始弱口令配置检测...")
        weak_hits = self.scan_config_files()
        if weak_hits:
            logger.warning(f"发现 {len(weak_hits)} 个弱口令配置问题")
            findings.extend(weak_hits)

    for rule in self.vul_rules:
        if "sinks" not in rule:
            logger.error(f"规则缺少sinks字段: {rule}")
            continue
        for sink in rule["sinks"]:
            if sink.count(':') != 1:
                logger.error(f"无效的sink格式: {sink} (应为 '类名:方法名')")
                continue
            qualified_class, _, method_list = sink.partition(":")
            simple_class = qualified_class.rsplit('.', 1)[-1]
            for method in method_list.split("|"):
                sink_point = f"{simple_class}:{method}"
                logger.info(f"审计sink点: {sink_point}")
                chains = self._trace_back(sink_point, MAX_DEPTH)
                if not chains:
                    continue
                findings.append({
                    "vul_type": rule["sink_name"],
                    "sink_desc": rule["sink_desc"],
                    "severity": rule["severity_level"],
                    "sink": sink_point,
                    "call_chains": self.process_call_stacks(self.project_path, chains)
                })
    return findings
def _trace_back(self, sink: str, max_depth: int) -> List[List[str]]:
    """Walk the call graph backwards, breadth-first, from *sink* to entry points.

    Args:
        sink: callee string in the call graph's "Class:method" format.
        max_depth: a path is no longer extended once it has been extended
            this many times.

    Returns:
        Every chain [entry_point, ..., sink] found, where the first element
        satisfies is_entry_point. Callers for which is_has_parameters is False
        are not followed, and a method already on a chain is not added again.
    """
    # Callers are keyed by their full name "file_path::Class:method", but callees
    # are recorded as "Class:method". Index callers by callee once, and look each
    # chain head up by its "Class:method" suffix. Looking up the full name never
    # matched any callee, which limited every trace to the sink's direct callers;
    # the index also replaces a full call-graph scan on every BFS step.
    callers_of: Dict[str, List[str]] = defaultdict(list)
    for caller, callees in self.call_graph.items():
        for callee in dict.fromkeys(callees):  # de-duplicate, keep first-seen order
            callers_of[callee].append(caller)

    paths: List[List[str]] = []
    queue = deque([([sink], 0)])
    visited = set()
    while queue:
        current_path, current_depth = queue.popleft()
        # Depth limit and duplicate-path check.
        if current_depth >= max_depth or tuple(current_path) in visited:
            continue
        visited.add(tuple(current_path))
        current_key = current_path[0].rsplit("::", 1)[-1]
        caller_methods = callers_of.get(current_key, [])
        if not caller_methods:
            continue
        logger.debug(f"需要追溯调用点: {caller_methods}")
        for caller in caller_methods:
            if caller in current_path:
                # Recursion / mutual recursion: looping back adds no new chain.
                continue
            # Taint has to enter through parameters; skip parameterless methods.
            if not self.is_has_parameters(caller):
                logger.debug(f"忽略无参函数: {caller}")
                continue
            new_path = [caller] + current_path
            logger.debug(f"追溯路径: [{' → '.join(new_path)}]")
            if self.is_entry_point(caller):
                paths.append(new_path)
                logger.info(f"发现完整调用链: {new_path}")
            else:
                queue.append((new_path, current_depth + 1))
    return paths
def is_has_parameters(self, method_full: str) -> bool:
    """Report whether the method named by *method_full* declares parameters.

    *method_full* has the form "file_path::ClassName:methodName". When it
    cannot be parsed, or the class or method was never recorded, the answer
    is True so the method is still followed during tracing.
    """
    try:
        file_part, separator, remainder = method_full.partition("::")
        if not separator:
            return True
        class_name, colon, method_name = remainder.rsplit("::", 1)[-1].partition(":")
        if not colon:
            return True
        class_info = self.class_methods.get(f"{file_part}::{class_name}")
        if class_info is None:
            return True
        method_flags = class_info["methods"].get(method_name, {})
        return method_flags.get("requires_params", True)
    except Exception:
        # Best effort: any unexpected shape counts as "has parameters".
        return True
def is_entry_point(self, method_full: str) -> bool:
    """Return the recorded "has_mapping_annotation" flag for *method_full*.

    *method_full* has the form "file_path::ClassName:methodName". Anything
    that cannot be parsed or looked up is treated as not an entry point.
    """
    try:
        file_part, separator, remainder = method_full.partition("::")
        class_name, colon, method_name = remainder.rsplit("::", 1)[-1].partition(":")
        if not (separator and colon):
            return False
        class_info = self.class_methods.get(f"{file_part}::{class_name}")
        if class_info is None:
            return False
        return class_info["methods"].get(method_name, {}).get("has_mapping_annotation", False)
    except Exception:
        # Best effort: any unexpected shape counts as "not an entry point".
        return False
@staticmethod
def find_parent_class(path) -> Union[javalang.tree.ClassDeclaration, javalang.tree.EnumDeclaration, None]:
    """Return the last class or enum declaration in *path*, or None if there is none.

    *path* is an ancestor path ordered from the root, so the last match is the
    nearest enclosing declaration.
    """
    class_like = (javalang.tree.ClassDeclaration, javalang.tree.EnumDeclaration)
    matches = [node for node in path if isinstance(node, class_like)]
    return matches[-1] if matches else None
def scan_config_files(self) -> List[dict]:
    """Scan every non-excluded configuration file for weak-password findings.

    Configuration files are those whose name ends with one of
    CONFIG_EXTENSIONS. They are scanned in parallel via
    self._scan_single_config. Results are concatenated in sorted discovery
    order rather than completion order, so repeated runs report findings in
    the same order. A failure on one file is logged and that file is skipped.
    """
    extensions = tuple(CONFIG_EXTENSIONS)  # str.endswith accepts a tuple
    config_files = []
    for root, dirs, files in os.walk(self.project_path):
        # Prune excluded directories in place and sort for a stable walk order.
        dirs[:] = sorted(d for d in dirs if not self._is_excluded(os.path.join(root, d)))
        for file in sorted(files):
            file_path = os.path.join(root, file)
            if file.endswith(extensions) and not self._is_excluded(file_path):
                config_files.append(file_path)
    logger.info(f"发现 {len(config_files)} 个配置文件")

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        submitted = [(path, executor.submit(self._scan_single_config, path)) for path in config_files]
        for file_path, future in submitted:
            try:
                file_results = future.result()
            except Exception as e:
                logger.error(f"扫描配置文件 {file_path} 时出错: {str(e)}")
                continue
            if file_results:
                results.extend(file_results)
    return results
def _scan_single_config(self, file_path: str) -> List[dict]:
"""扫描单个配置文件"""
results = []
config = self._parse_config_file(file_path)
if not config:
return []
file_relpath = os.path.relpath(file_path, self.project_path)
for rule in self.weak_password_rules