AI驱动的自动化测试生成:提升代码质量的智能化解决方案

摘要

自动化测试是现代软件开发的重要组成部分,但编写全面的测试用例往往耗时且容易遗漏边界情况。本文将深入探讨如何利用AI技术自动生成高质量的测试代码,从单元测试到集成测试,从性能测试到安全测试,全面覆盖测试生成的各个方面。通过分析awesome-chatgpt-prompts项目中的测试相关提示词,我们将构建一套完整的AI驱动测试生成系统。

目录

  1. AI测试生成概述
  2. 单元测试自动生成
  3. 集成测试智能设计
  4. 性能测试场景构建
  5. 安全测试用例生成
  6. 测试数据智能生成
  7. 测试覆盖率分析
  8. 测试维护与优化
  9. 实战案例分析
  10. 总结与最佳实践

AI测试生成概述

测试生成架构体系

AI测试生成系统
├── 代码分析模块
│   ├── 语法分析
│   ├── 依赖分析
│   └── 复杂度分析
├── 测试生成引擎
│   ├── 单元测试生成
│   ├── 集成测试生成
│   ├── 性能测试生成
│   └── 安全测试生成
├── 测试执行器
│   ├── 测试执行
│   ├── 结果收集
│   └── 错误捕获
└── 结果分析器
    ├── 覆盖率分析
    ├── 质量评估
    └── 优化建议

核心测试生成框架

import ast
import inspect
import unittest
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass
from abc import ABC, abstractmethod

@dataclass
class TestCase:
    """Plain record describing one generated test case."""
    # Test identifier; the code generator in this file emits it verbatim as
    # the test method name.
    name: str
    # Human-readable summary; emitted as the test method's docstring.
    description: str
    # Mapping of parameter name -> value to feed to the code under test.
    input_data: Dict[str, Any]
    # Expected result. The generators in this file use None as a placeholder
    # and the string "Exception" to request an assertRaises-style test.
    expected_output: Any
    # Category of test; the generators in this file only produce "unit".
    test_type: str
    # Numeric priority; the generators in this file use 1-3 and sort ascending.
    priority: int
    # Free-form labels such as "normal", "edge", "boundary".
    tags: List[str]

class CodeAnalyzer:
    """Extract structure and simple metrics from Python source via ``ast``.

    ``analyze_code`` is the entry point. It parses the source once and
    returns a dictionary describing functions, classes, imports, complexity
    counters and called names.
    """

    # Both ``def`` and ``async def`` are treated as function definitions.
    _FUNCTION_NODES = (ast.FunctionDef, ast.AsyncFunctionDef)

    def __init__(self):
        # Holders created on construction; no method in this class fills them.
        self.functions = []
        self.classes = []
        self.dependencies = []
        self.complexity_metrics = {}

    def analyze_code(self, code: str) -> Dict[str, Any]:
        """Parse ``code`` and return its structural analysis.

        Args:
            code: Python source text.

        Returns:
            A dict with the keys ``functions``, ``classes``, ``imports``,
            ``complexity`` and ``dependencies``. If the source cannot be
            parsed, a dict with the single key ``error`` is returned instead.
        """
        try:
            tree = ast.parse(code)
        except SyntaxError as e:
            return {'error': f'语法错误: {e}'}

        functions = self._extract_functions(tree)
        classes = self._extract_classes(tree)
        imports = self._extract_imports(tree)
        complexity = self._calculate_complexity(tree)
        # The tree alone does not carry the raw text, so the line count is
        # filled in here: number of non-blank source lines.
        complexity['lines_of_code'] = sum(
            1 for line in code.splitlines() if line.strip()
        )
        dependencies = self._analyze_dependencies(tree)

        return {
            'functions': functions,
            'classes': classes,
            'imports': imports,
            'complexity': complexity,
            'dependencies': dependencies
        }

    @staticmethod
    def _node_to_source(node: ast.AST) -> str:
        """Render an expression node as readable source text.

        Plain names return their identifier. Other nodes are unparsed when
        ``ast.unparse`` exists (Python 3.9+); otherwise dotted attribute
        chains and call targets are rebuilt by hand and anything else falls
        back to ``ast.dump``. This avoids ``str(node)``, which only yields an
        object address.
        """
        if isinstance(node, ast.Name):
            return node.id
        if hasattr(ast, 'unparse'):
            return ast.unparse(node)
        if isinstance(node, ast.Attribute):
            return f"{CodeAnalyzer._node_to_source(node.value)}.{node.attr}"
        if isinstance(node, ast.Call):
            return CodeAnalyzer._node_to_source(node.func)
        return ast.dump(node)

    def _extract_functions(self, tree: ast.AST) -> List[Dict[str, Any]]:
        """Describe every function definition found anywhere in ``tree``.

        The whole tree is walked, so methods and nested functions are
        reported alongside module-level functions.
        """
        functions = []

        for node in ast.walk(tree):
            if not isinstance(node, self._FUNCTION_NODES):
                continue
            functions.append({
                'name': node.name,
                'args': [arg.arg for arg in node.args.args],
                'returns': self._get_return_type(node),
                'docstring': ast.get_docstring(node),
                'line_number': node.lineno,
                'decorators': [self._node_to_source(d) for d in node.decorator_list],
                'complexity': self._calculate_function_complexity(node)
            })

        return functions

    def _extract_classes(self, tree: ast.AST) -> List[Dict[str, Any]]:
        """Describe every class in ``tree`` together with its direct methods."""
        classes = []

        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue

            # Only definitions directly in the class body count as methods.
            methods = [
                {
                    'name': item.name,
                    'args': [arg.arg for arg in item.args.args],
                    'is_private': item.name.startswith('_'),
                    'is_property': any(
                        isinstance(d, ast.Name) and d.id == 'property'
                        for d in item.decorator_list
                    )
                }
                for item in node.body
                if isinstance(item, self._FUNCTION_NODES)
            ]

            classes.append({
                'name': node.name,
                'methods': methods,
                'attributes': [],
                'inheritance': [self._node_to_source(base) for base in node.bases],
                'docstring': ast.get_docstring(node),
                'line_number': node.lineno
            })

        return classes

    def _extract_imports(self, tree: ast.AST) -> List[Dict[str, Any]]:
        """List every ``import`` and ``from ... import`` binding in ``tree``."""
        imports = []

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    imports.append({
                        'type': 'import',
                        'module': alias.name,
                        'alias': alias.asname
                    })
            elif isinstance(node, ast.ImportFrom):
                for alias in node.names:
                    imports.append({
                        'type': 'from_import',
                        'module': node.module,
                        'name': alias.name,
                        'alias': alias.asname
                    })

        return imports

    def _calculate_complexity(self, tree: ast.AST) -> Dict[str, int]:
        """Count branching constructs, functions and classes in ``tree``.

        ``cognitive`` is never updated and stays 0. ``lines_of_code`` is 0
        here as well; ``analyze_code`` fills it in from the source text.
        """
        complexity = {
            'cyclomatic': 0,
            'cognitive': 0,
            'lines_of_code': 0,
            'functions_count': 0,
            'classes_count': 0
        }

        for node in ast.walk(tree):
            if isinstance(node, (ast.If, ast.While, ast.For, ast.Try)):
                complexity['cyclomatic'] += 1
            elif isinstance(node, self._FUNCTION_NODES):
                complexity['functions_count'] += 1
            elif isinstance(node, ast.ClassDef):
                complexity['classes_count'] += 1

        return complexity

    def _calculate_function_complexity(self, func_node: ast.FunctionDef) -> int:
        """Return 1 plus the number of branching nodes inside ``func_node``.

        ``if``/``while``/``for``/``try``/``with`` statements and each
        ``and``/``or`` operator add one.
        """
        complexity = 1  # base path through the function

        for node in ast.walk(func_node):
            if isinstance(node, (ast.If, ast.While, ast.For, ast.Try, ast.With)):
                complexity += 1
            elif isinstance(node, (ast.And, ast.Or)):
                complexity += 1

        return complexity

    def _get_return_type(self, func_node: ast.FunctionDef) -> Optional[str]:
        """Return the return annotation as text, or ``None`` if absent.

        Constant annotations (``-> None``, ``-> "Foo"``) are rendered with
        ``str`` of their value; every other annotation is rendered as source
        text, so subscripted types such as ``List[int]`` are preserved.
        """
        annotation = func_node.returns
        if annotation is None:
            return None
        if isinstance(annotation, ast.Constant):
            return str(annotation.value)
        return self._node_to_source(annotation)

    def _analyze_dependencies(self, tree: ast.AST) -> List[str]:
        """Collect the distinct names called in ``tree``.

        Records ``name`` for ``name(...)`` calls and ``obj.attr`` for
        ``obj.attr(...)`` calls whose receiver is a plain name. The order of
        the returned list is unspecified because it is built from a set.
        """
        dependencies = set()

        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            target = node.func
            if isinstance(target, ast.Name):
                dependencies.add(target.id)
            elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
                dependencies.add(f"{target.value.id}.{target.attr}")

        return list(dependencies)

class TestGenerator(ABC):
    """Abstract interface that every concrete test generator implements."""

    @abstractmethod
    def generate_tests(self, code_analysis: Dict[str, Any]) -> List[TestCase]:
        """Build the test cases for the given code-analysis dictionary."""

    @abstractmethod
    def get_test_template(self) -> str:
        """Return the text template used to render the generated tests."""

class UnitTestGenerator(TestGenerator):
    """Unit-test generator driven by analysed function and class metadata.

    Input values are guessed from parameter names. Expected outputs are left
    as ``None`` placeholders, except for error cases, which carry the string
    ``"Exception"``.
    """

    def __init__(self):
        # Scenario key -> display label.
        self.test_patterns = dict(
            normal_case='正常情况测试',
            edge_case='边界情况测试',
            error_case='异常情况测试',
            null_case='空值测试',
        )

    def generate_tests(self, code_analysis: Dict[str, Any]) -> List[TestCase]:
        """Return cases for every analysed function, then for every class."""
        collected = [
            case
            for func in code_analysis.get('functions', [])
            for case in self._generate_function_tests(func)
        ]
        collected += [
            case
            for cls in code_analysis.get('classes', [])
            for case in self._generate_class_tests(cls)
        ]
        return collected

    def _generate_function_tests(self, func_info: Dict[str, Any]) -> List[TestCase]:
        """Build the normal, edge and (when warranted) error case for one function."""
        target = func_info['name']
        params = func_info['args']

        cases = [
            # Happy path; the expected value has to be filled in by hand.
            TestCase(
                name=f"test_{target}_normal_case",
                description=f"测试{target}函数的正常情况",
                input_data=self._generate_normal_inputs(params),
                expected_output=None,
                test_type="unit",
                priority=1,
                tags=["normal", "unit"]
            ),
            # Boundary values.
            TestCase(
                name=f"test_{target}_edge_case",
                description=f"测试{target}函数的边界情况",
                input_data=self._generate_edge_inputs(params),
                expected_output=None,
                test_type="unit",
                priority=2,
                tags=["edge", "unit"]
            ),
        ]

        # Error path, only for functions judged likely to raise.
        if self._should_test_exceptions(func_info):
            cases.append(TestCase(
                name=f"test_{target}_error_case",
                description=f"测试{target}函数的异常情况",
                input_data=self._generate_error_inputs(params),
                expected_output="Exception",
                test_type="unit",
                priority=3,
                tags=["error", "unit"]
            ))

        return cases

    def _generate_class_tests(self, class_info: Dict[str, Any]) -> List[TestCase]:
        """Build an initialization case plus one case per public method."""
        owner = class_info['name']

        cases = [TestCase(
            name=f"test_{owner}_initialization",
            description=f"测试{owner}类的初始化",
            input_data={},
            expected_output=None,
            test_type="unit",
            priority=1,
            tags=["init", "unit"]
        )]

        # Methods whose name starts with an underscore are skipped.
        public_methods = (
            m for m in class_info['methods'] if not m['name'].startswith('_')
        )
        for method in public_methods:
            cases.extend(self._generate_method_tests(owner, method))

        return cases

    def _generate_method_tests(self, class_name: str, method_info: Dict[str, Any]) -> List[TestCase]:
        """Build the single normal-input case for one method."""
        method_name = method_info['name']
        # The first argument is dropped as the receiver (``self``).
        call_args = method_info['args'][1:]

        return [TestCase(
            name=f"test_{class_name}_{method_name}",
            description=f"测试{class_name}.{method_name}方法",
            input_data=self._generate_normal_inputs(call_args),
            expected_output=None,
            test_type="unit",
            priority=1,
            tags=["method", "unit"]
        )]

    def _generate_normal_inputs(self, args: List[str]) -> Dict[str, Any]:
        """Map each argument (except ``self``) to a typical value guessed from its name."""

        def typical(arg_name: str) -> Any:
            lowered = arg_name.lower()
            # First matching keyword wins; the order matters.
            if 'id' in lowered:
                return 1
            if 'name' in lowered:
                return "test_name"
            if 'email' in lowered:
                return "test@example.com"
            if 'age' in lowered:
                return 25
            if 'count' in lowered or 'num' in lowered:
                return 10
            if 'list' in lowered or 'items' in lowered:
                return [1, 2, 3]
            if 'dict' in lowered or 'data' in lowered:
                return {"key": "value"}
            return "test_value"

        return {arg: typical(arg) for arg in args if arg != 'self'}

    def _generate_edge_inputs(self, args: List[str]) -> Dict[str, Any]:
        """Map each argument (except ``self``) to a boundary value guessed from its name."""

        def boundary(arg_name: str) -> Any:
            lowered = arg_name.lower()
            if 'id' in lowered:
                return 0  # smallest id
            if 'age' in lowered:
                return 0  # smallest age
            if 'count' in lowered:
                return 0  # empty count
            if 'list' in lowered:
                return []  # empty list
            if 'dict' in lowered:
                return {}  # empty dict
            if 'name' in lowered:
                return ""  # empty string
            return None  # no value

        return {arg: boundary(arg) for arg in args if arg != 'self'}

    def _generate_error_inputs(self, args: List[str]) -> Dict[str, Any]:
        """Map each argument (except ``self``) to a value intended to trigger an error."""

        def faulty(arg_name: str) -> Any:
            lowered = arg_name.lower()
            if 'id' in lowered:
                return -1  # negative id
            if 'age' in lowered:
                return -5  # negative age
            if 'email' in lowered:
                return "invalid_email"  # malformed address
            if 'list' in lowered:
                return "not_a_list"  # wrong type
            if 'dict' in lowered:
                return "not_a_dict"  # wrong type
            return object()  # opaque object of an unexpected type

        return {arg: faulty(arg) for arg in args if arg != 'self'}

    def _should_test_exceptions(self, func_info: Dict[str, Any]) -> bool:
        """Tell whether an error case should be generated for the function.

        True when the recorded complexity exceeds 3 or the function name
        contains one of ``validate``, ``check``, ``parse``, ``convert``.
        """
        if func_info.get('complexity', 0) > 3:
            return True
        lowered = func_info['name'].lower()
        return any(
            keyword in lowered
            for keyword in ['validate', 'check', 'parse', 'convert']
        )

    def get_test_template(self) -> str:
        """Return the ``str.format``-style skeleton of a unittest module."""
        template = '''
import unittest
from unittest.mock import Mock, patch
import sys
import os

# 添加项目路径到sys.path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from {module_name} import {class_or_function_name}

class Test{ClassName}(unittest.TestCase):
    """
    {class_or_function_name}的单元测试类
    """
    
    def setUp(self):
        """测试前的准备工作"""
        pass
    
    def tearDown(self):
        """测试后的清理工作"""
        pass
    
    {test_methods}

if __name__ == '__main__':
    unittest.main()
'''
        return template

class TestCodeGenerator:
    """Turn Python source into the text of a ``unittest`` module.

    The source is analysed with ``CodeAnalyzer``, test cases are produced by
    the generator registered for the requested test type, and the cases are
    rendered as a test class skeleton.
    """

    def __init__(self):
        # Test type -> generator instance; ``None`` marks types that are not
        # implemented yet.
        self.generators = {
            'unit': UnitTestGenerator(),
            'integration': None,
            'performance': None,
            'security': None
        }

    def generate_test_code(self, code: str, test_type: str = 'unit') -> str:
        """Return generated test code for ``code``.

        Args:
            code: Python source to generate tests for.
            test_type: Key into ``self.generators``; defaults to ``'unit'``.

        Returns:
            The generated module text. When analysis fails or no generator is
            registered for ``test_type``, a one-line ``#`` comment describing
            the problem is returned instead.
        """
        analyzer = CodeAnalyzer()
        analysis = analyzer.analyze_code(code)

        if 'error' in analysis:
            return f"# 代码分析失败: {analysis['error']}"

        generator = self.generators.get(test_type)
        if not generator:
            return f"# 不支持的测试类型: {test_type}"

        test_cases = generator.generate_tests(analysis)

        return self._build_test_code(analysis, test_cases, test_type)

    def _build_test_code(self, analysis: Dict[str, Any], test_cases: List[TestCase], test_type: str) -> str:
        """Dispatch to the renderer for ``test_type``; only ``'unit'`` is rendered."""
        if test_type == 'unit':
            return self._build_unit_test_code(analysis, test_cases)
        else:
            return "# 其他测试类型的代码生成待实现"

    def _build_unit_test_code(self, analysis: Dict[str, Any], test_cases: List[TestCase]) -> str:
        """Assemble imports, the test class and the ``unittest.main()`` footer."""
        imports = self._generate_imports(analysis)
        test_class = self._generate_test_class(analysis, test_cases)

        return f"""
{imports}

{test_class}

if __name__ == '__main__':
    unittest.main()
"""

    def _generate_imports(self, analysis: Dict[str, Any]) -> str:
        """Return the import header of the generated module.

        When the analysis found any function or class, a commented-out
        placeholder import for the module under test is appended.
        """
        imports = [
            "import unittest",
            "from unittest.mock import Mock, patch, MagicMock",
            "import sys",
            "import os"
        ]

        functions = analysis.get('functions', [])
        classes = analysis.get('classes', [])

        if functions or classes:
            imports.append("")
            imports.append("# 导入被测试的模块")
            imports.append("# from your_module import YourClass, your_function")

        return "\n".join(imports)

    def _generate_test_class(self, analysis: Dict[str, Any], test_cases: List[TestCase]) -> str:
        """Render the ``TestGeneratedCode`` class with one method per test case."""
        class_name = "TestGeneratedCode"

        class_code = [
            f"class {class_name}(unittest.TestCase):",
            '    """自动生成的测试类"""',
            "",
            "    def setUp(self):",
            '        """测试前的准备工作"""',
            "        pass",
            "",
            "    def tearDown(self):",
            '        """测试后的清理工作"""',
            "        pass",
            ""
        ]

        for test_case in test_cases:
            class_code.extend(self._generate_test_method(test_case))
            class_code.append("")

        return "\n".join(class_code)

    @staticmethod
    def _render_literal(value: Any) -> str:
        """Return source text that evaluates to ``value`` in the generated file.

        ``repr`` is used whenever it round-trips through ``ast.literal_eval``
        (strings are therefore correctly quoted and escaped). Values without
        a literal form, such as a bare ``object()`` or ``float('inf')``, are
        replaced by an ``object()`` placeholder with an explanatory comment,
        so the generated module always remains valid Python.
        """
        rendered = repr(value)
        try:
            ast.literal_eval(rendered)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return f"object()  # 占位: {type(value).__name__} 无法表示为字面量"
        return rendered

    def _generate_test_method(self, test_case: TestCase) -> List[str]:
        """Render one test case as the source lines of a test method.

        Input values are emitted as local assignments. A case whose
        ``expected_output`` is the string ``"Exception"`` gets an
        ``assertRaises`` skeleton; any other case gets a commented
        call/assert skeleton.
        """
        method_lines = [
            f"    def {test_case.name}(self):",
            f'        """{test_case.description}"""',
        ]

        if test_case.input_data:
            method_lines.append("        # 准备测试数据")
            for key, value in test_case.input_data.items():
                method_lines.append(f'        {key} = {self._render_literal(value)}')
            method_lines.append("")

        if test_case.expected_output == "Exception":
            method_lines.extend([
                "        # 测试异常情况",
                "        with self.assertRaises(Exception):",
                "            # 调用被测试的函数/方法",
                "            pass  # 替换为实际的函数调用"
            ])
        else:
            method_lines.extend([
                "        # 执行测试",
                "        # result = your_function(test_args)",
                "        ",
                "        # 验证结果",
                "        # self.assertEqual(result, expected_value)",
                "        pass  # 替换为实际的断言"
            ])

        return method_lines

# 使用示例
def demo_test_generation():
    """Generate tests for a small sample module and print the results.

    Prints the generated test code first, then a summary of the analysis
    (counts, complexity metrics, discovered functions and classes).
    """
    divider = "=" * 50

    # Sample source: a calculator class plus a stand-alone validator.
    sample_code = '''
class Calculator:
    """简单计算器类"""
    
    def __init__(self):
        self.history = []
    
    def add(self, a, b):
        """加法运算"""
        if not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
            raise TypeError("参数必须是数字")
        
        result = a + b
        self.history.append(f"{a} + {b} = {result}")
        return result
    
    def divide(self, a, b):
        """除法运算"""
        if not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
            raise TypeError("参数必须是数字")
        
        if b == 0:
            raise ValueError("除数不能为零")
        
        result = a / b
        self.history.append(f"{a} / {b} = {result}")
        return result
    
    def get_history(self):
        """获取计算历史"""
        return self.history.copy()
    
    def clear_history(self):
        """清空计算历史"""
        self.history.clear()

def validate_email(email):
    """验证邮箱格式"""
    if not isinstance(email, str):
        raise TypeError("邮箱必须是字符串")
    
    if "@" not in email:
        raise ValueError("邮箱格式不正确")
    
    parts = email.split("@")
    if len(parts) != 2:
        raise ValueError("邮箱格式不正确")
    
    username, domain = parts
    if not username or not domain:
        raise ValueError("邮箱格式不正确")
    
    return True
'''

    # Step 1: render the unit tests and show them.
    code_gen = TestCodeGenerator()
    generated = code_gen.generate_test_code(sample_code, 'unit')

    print("生成的测试代码:")
    print(divider)
    print(generated)

    # Step 2: analyse the same source and summarise what was found.
    analysis = CodeAnalyzer().analyze_code(sample_code)
    found_functions = analysis.get('functions', [])
    found_classes = analysis.get('classes', [])

    print("\n代码分析结果:")
    print(divider)
    print(f"函数数量: {len(found_functions)}")
    print(f"类数量: {len(found_classes)}")
    print(f"复杂度指标: {analysis.get('complexity', {})}")

    print("\n发现的函数:")
    for entry in found_functions:
        print(f"- {entry['name']} (复杂度: {entry['complexity']})")

    print("\n发现的类:")
    for entry in found_classes:
        print(f"- {entry['name']} (方法数: {len(entry['methods'])})")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo_test_generation()

单元测试自动生成

智能测试用例设计

class IntelligentTestCaseDesigner:
    """Design test cases for one function using classic test-design techniques.

    Parameter types are guessed from parameter names. Each registered
    strategy (equivalence partitioning, boundary values, decision table,
    state transition, error guessing) then contributes cases, which are
    de-duplicated by name and sorted by priority.
    """

    def __init__(self):
        # Strategy name -> bound method taking (function_info, param_types).
        self.test_strategies = {
            'equivalence_partitioning': self._equivalence_partitioning,
            'boundary_value_analysis': self._boundary_value_analysis,
            'decision_table': self._decision_table_testing,
            'state_transition': self._state_transition_testing,
            'error_guessing': self._error_guessing
        }
        # Inferred type name -> zero-argument producer of a sample value.
        self.data_generators = {
            'string': self._generate_string_data,
            'number': self._generate_number_data,
            'list': self._generate_list_data,
            'dict': self._generate_dict_data,
            'boolean': self._generate_boolean_data
        }

    def design_test_cases(self, function_info: Dict[str, Any]) -> List[TestCase]:
        """Return the optimized list of test cases for one function.

        Args:
            function_info: Function description; ``name`` is required and
                ``args`` (list of parameter names) is optional.

        Returns:
            Cases from all strategies, unique by name, sorted by ascending
            priority.
        """
        test_cases = []

        param_types = self._analyze_parameter_types(function_info)

        for strategy_func in self.test_strategies.values():
            test_cases.extend(strategy_func(function_info, param_types))

        return self._optimize_test_cases(test_cases)

    def _analyze_parameter_types(self, function_info: Dict[str, Any]) -> Dict[str, str]:
        """Guess a type name for each parameter from keywords in its name.

        ``self`` is skipped. The first matching keyword group wins, and
        parameters that match nothing default to ``'string'``.
        """
        param_types = {}

        for param in function_info.get('args', []):
            if param == 'self':
                continue

            param_lower = param.lower()
            if any(keyword in param_lower for keyword in ['id', 'count', 'num', 'age', 'size']):
                param_types[param] = 'number'
            elif any(keyword in param_lower for keyword in ['name', 'email', 'text', 'str']):
                param_types[param] = 'string'
            elif any(keyword in param_lower for keyword in ['list', 'items', 'array']):
                param_types[param] = 'list'
            elif any(keyword in param_lower for keyword in ['dict', 'data', 'config']):
                param_types[param] = 'dict'
            elif any(keyword in param_lower for keyword in ['flag', 'enabled', 'active']):
                param_types[param] = 'boolean'
            else:
                param_types[param] = 'string'  # default type

        return param_types

    def _equivalence_partitioning(self, function_info: Dict[str, Any], param_types: Dict[str, str]) -> List[TestCase]:
        """Produce one valid-class and one invalid-class case per parameter."""
        test_cases = []
        func_name = function_info['name']

        for param, param_type in param_types.items():
            # Valid equivalence class.
            test_cases.append(TestCase(
                name=f"test_{func_name}_{param}_valid_equivalence",
                description=f"测试{param}参数的有效等价类",
                input_data={param: self._generate_valid_data(param_type)},
                expected_output=None,
                test_type="unit",
                priority=1,
                tags=["equivalence", "valid"]
            ))

            # Invalid equivalence class; expected to raise.
            test_cases.append(TestCase(
                name=f"test_{func_name}_{param}_invalid_equivalence",
                description=f"测试{param}参数的无效等价类",
                input_data={param: self._generate_invalid_data(param_type)},
                expected_output="Exception",
                test_type="unit",
                priority=2,
                tags=["equivalence", "invalid"]
            ))

        return test_cases

    def _boundary_value_analysis(self, function_info: Dict[str, Any], param_types: Dict[str, str]) -> List[TestCase]:
        """Produce boundary-value cases for number, string and list parameters."""
        test_cases = []
        func_name = function_info['name']

        for param, param_type in param_types.items():
            if param_type == 'number':
                boundary_values = [0, 1, -1, 100, -100, 999999, -999999]
                for value in boundary_values:
                    # A raw "-1" in the name would not be a valid identifier
                    # for the generated test method, so negatives are
                    # spelled "neg_<abs>".
                    label = f"neg_{abs(value)}" if value < 0 else str(value)
                    test_cases.append(TestCase(
                        name=f"test_{func_name}_{param}_boundary_{label}",
                        description=f"测试{param}参数的边界值{value}",
                        input_data={param: value},
                        expected_output=None,
                        test_type="unit",
                        priority=2,
                        tags=["boundary", "number"]
                    ))

            elif param_type == 'string':
                boundary_strings = ["", "a", "a" * 255, "a" * 1000]
                for value in boundary_strings:
                    test_cases.append(TestCase(
                        name=f"test_{func_name}_{param}_boundary_len_{len(value)}",
                        description=f"测试{param}参数的边界字符串(长度{len(value)})",
                        input_data={param: value},
                        expected_output=None,
                        test_type="unit",
                        priority=2,
                        tags=["boundary", "string"]
                    ))

            elif param_type == 'list':
                boundary_lists = [[], [1], list(range(100)), list(range(1000))]
                for value in boundary_lists:
                    test_cases.append(TestCase(
                        name=f"test_{func_name}_{param}_boundary_len_{len(value)}",
                        description=f"测试{param}参数的边界列表(长度{len(value)})",
                        input_data={param: value},
                        expected_output=None,
                        test_type="unit",
                        priority=2,
                        tags=["boundary", "list"]
                    ))

        return test_cases

    def _decision_table_testing(self, function_info: Dict[str, Any], param_types: Dict[str, str]) -> List[TestCase]:
        """Produce one case per True/False combination of the boolean parameters.

        Nothing is produced unless at least two parameters were inferred to
        be boolean.
        """
        test_cases = []
        func_name = function_info['name']

        boolean_params = [param for param, param_type in param_types.items() if param_type == 'boolean']

        if len(boolean_params) >= 2:
            import itertools
            combinations = itertools.product([True, False], repeat=len(boolean_params))

            for i, combination in enumerate(combinations):
                input_data = dict(zip(boolean_params, combination))
                test_cases.append(TestCase(
                    name=f"test_{func_name}_decision_table_{i+1}",
                    description=f"决策表测试组合{i+1}: {input_data}",
                    input_data=input_data,
                    expected_output=None,
                    test_type="unit",
                    priority=2,
                    tags=["decision_table", "combination"]
                ))

        return test_cases

    def _state_transition_testing(self, function_info: Dict[str, Any], param_types: Dict[str, str]) -> List[TestCase]:
        """Produce a state-transition case when the function name suggests one.

        The name is matched against keywords such as ``start``, ``stop`` or
        ``open``; ``param_types`` is accepted for signature parity with the
        other strategies and is not used.
        """
        test_cases = []
        func_name = function_info['name']

        state_keywords = ['start', 'stop', 'pause', 'resume', 'init', 'close', 'open', 'enable', 'disable']

        if any(keyword in func_name.lower() for keyword in state_keywords):
            test_cases.append(TestCase(
                name=f"test_{func_name}_state_transition",
                description=f"测试{func_name}的状态转换",
                input_data={},
                expected_output=None,
                test_type="unit",
                priority=2,
                tags=["state_transition"]
            ))

        return test_cases

    def _error_guessing(self, function_info: Dict[str, Any], param_types: Dict[str, str]) -> List[TestCase]:
        """Produce experience-based error cases for every parameter.

        Each parameter is paired with each of five suspicious values (null,
        empty string, negative, zero, very large); every case expects an
        exception.
        """
        test_cases = []
        func_name = function_info['name']

        error_scenarios = [
            {'name': 'null_input', 'data': None, 'description': '空值输入'},
            {'name': 'empty_input', 'data': '', 'description': '空字符串输入'},
            {'name': 'negative_input', 'data': -1, 'description': '负数输入'},
            {'name': 'zero_input', 'data': 0, 'description': '零值输入'},
            {'name': 'large_input', 'data': 999999999, 'description': '超大数值输入'}
        ]

        for scenario in error_scenarios:
            for param in param_types:
                test_cases.append(TestCase(
                    name=f"test_{func_name}_{param}_{scenario['name']}",
                    description=f"错误推测: {param}参数{scenario['description']}",
                    input_data={param: scenario['data']},
                    expected_output="Exception",
                    test_type="unit",
                    priority=3,
                    tags=["error_guessing", scenario['name']]
                ))

        return test_cases

    def _generate_valid_data(self, param_type: str) -> Any:
        """Return a representative valid value for ``param_type``.

        Unknown type names yield ``"default_value"``.
        """
        generators = {
            'string': lambda: "valid_string",
            'number': lambda: 42,
            'list': lambda: [1, 2, 3],
            'dict': lambda: {"key": "value"},
            'boolean': lambda: True
        }
        return generators.get(param_type, lambda: "default_value")()

    def _generate_invalid_data(self, param_type: str) -> Any:
        """Return a value of the wrong type for ``param_type``.

        Unknown type names yield a bare ``object()``.
        """
        invalid_generators = {
            'string': lambda: 123,  # number instead of string
            'number': lambda: "not_a_number",  # string instead of number
            'list': lambda: "not_a_list",  # string instead of list
            'dict': lambda: "not_a_dict",  # string instead of dict
            'boolean': lambda: "not_a_boolean"  # string instead of boolean
        }
        return invalid_generators.get(param_type, lambda: object())()

    # The five producers below back ``self.data_generators``. The constructor
    # references them, so they must exist for the class to be instantiable.
    # NOTE(review): their intended signature is not shown anywhere in this
    # file; zero-argument producers of a valid sample are assumed - confirm.

    def _generate_string_data(self) -> Any:
        """Return a sample valid string value."""
        return self._generate_valid_data('string')

    def _generate_number_data(self) -> Any:
        """Return a sample valid numeric value."""
        return self._generate_valid_data('number')

    def _generate_list_data(self) -> Any:
        """Return a sample valid list value."""
        return self._generate_valid_data('list')

    def _generate_dict_data(self) -> Any:
        """Return a sample valid dict value."""
        return self._generate_valid_data('dict')

    def _generate_boolean_data(self) -> Any:
        """Return a sample valid boolean value."""
        return self._generate_valid_data('boolean')

    def _optimize_test_cases(self, test_cases: List[TestCase]) -> List[TestCase]:
        """Drop cases with duplicate names and sort by ascending priority.

        The first case seen for a given name is kept. The sort is stable, so
        cases with equal priority keep their relative order.
        """
        unique_cases = []
        seen_names = set()

        for case in test_cases:
            if case.name not in seen_names:
                unique_cases.append(case)
                seen_names.add(case.name)

        unique_cases.sort(key=lambda x: x.priority)

        return unique_cases

class MockDataGenerator:
    """Produce canned and random sample values for generated tests.

    Semantic kinds (email, name, phone, url, ip, date, uuid) are drawn from
    the fixed pools in ``self.patterns``. Basic kinds (int, float, str,
    bool, list, dict) are generated randomly.
    """

    def __init__(self):
        # Fixed pools of sample values, keyed by semantic data kind.
        self.patterns = {
            'email': ['test@example.com', 'user@domain.org', 'admin@company.net'],
            'name': ['张三', '李四', '王五', 'John Doe', 'Jane Smith'],
            'phone': ['13800138000', '15912345678', '+86-138-0013-8000'],
            'url': ['http://example.com', 'https://test.org', 'ftp://files.com'],
            'ip': ['192.168.1.1', '10.0.0.1', '127.0.0.1'],
            'date': ['2024-01-01', '2023-12-31', '2024-07-14'],
            'uuid': ['550e8400-e29b-41d4-a716-446655440000']
        }

    def generate_test_data(self, data_type: str, count: int = 1) -> List[Any]:
        """Return ``count`` sample values of the requested kind.

        Args:
            data_type: A key of ``self.patterns`` or one of ``int``,
                ``float``, ``str``, ``bool``, ``list``, ``dict``.
            count: Number of values to produce.

        Returns:
            A list of ``count`` values. Pattern kinds are sampled with
            replacement from their pool; unknown kinds yield ``None``
            entries.
        """
        # Imported before any branch. The lambdas below close over this
        # function-local name, so importing it only inside the pattern
        # branch left it unbound (NameError) for every basic kind.
        import random

        if data_type in self.patterns:
            return random.choices(self.patterns[data_type], k=count)

        # Random generators for the basic kinds.
        generators = {
            'int': lambda: random.randint(1, 1000),
            'float': lambda: round(random.uniform(0.1, 999.9), 2),
            'str': lambda: f"test_string_{random.randint(1, 1000)}",
            'bool': lambda: random.choice([True, False]),
            'list': lambda: [random.randint(1, 100) for _ in range(random.randint(1, 10))],
            'dict': lambda: {f"key_{i}": f"value_{i}" for i in range(random.randint(1, 5))}
        }

        generator = generators.get(data_type, lambda: None)
        return [generator() for _ in range(count)]

    def generate_edge_cases(self, data_type: str) -> List[Any]:
        """Return boundary values for ``data_type``.

        Supported kinds are ``int``, ``float``, ``str``, ``list`` and
        ``dict``; any other kind yields ``[None]``.
        """
        edge_cases = {
            'int': [0, 1, -1, 2147483647, -2147483648],
            'float': [0.0, 0.1, -0.1, float('inf'), float('-inf')],
            'str': ['', 'a', 'A' * 1000, '中文测试', '🚀🎉'],
            'list': [[], [None], list(range(1000))],
            'dict': [{}, {'': ''}, {str(i): i for i in range(100)}]
        }

        return edge_cases.get(data_type, [None])

    def generate_invalid_data(self, expected_type: str) -> List[Any]:
        """Return values that do not match ``expected_type``.

        Supported kinds are ``int``, ``float``, ``str``, ``list``, ``dict``
        and ``bool``; any other kind yields a list holding one bare
        ``object()``.
        """
        invalid_data = {
            'int': ['not_int', [], {}, None, float('nan')],
            'float': ['not_float', [], {}, None],
            'str': [123, [], {}, None],
            'list': ['not_list', 123, {}, None],
            'dict': ['not_dict', 123, [], None],
            'bool': ['not_bool', 123, [], {}]
        }

        return invalid_data.get(expected_type, [object()])

# 使用示例
def demo_intelligent_test_design():
    """Print a sample of designed test cases and some generated mock data.

    Feeds a hand-written description of ``calculate_discount`` to
    ``IntelligentTestCaseDesigner``, prints at most the first ten resulting
    cases, then prints three ``MockDataGenerator`` samples.
    """
    divider = "=" * 50

    # Description of the function the designer should build cases for.
    target_function = {
        'name': 'calculate_discount',
        'args': ['price', 'discount_rate', 'is_member'],
        'complexity': 3,
        'docstring': '计算折扣价格'
    }

    case_designer = IntelligentTestCaseDesigner()
    designed_cases = case_designer.design_test_cases(target_function)

    print("智能生成的测试用例:")
    print(divider)

    # Only the first ten cases are shown to keep the output short.
    preview = designed_cases[:10]
    for position, designed in enumerate(preview, start=1):
        print(f"{position}. {designed.name}")
        print(f"   描述: {designed.description}")
        print(f"   输入: {designed.input_data}")
        print(f"   标签: {designed.tags}")
        print(f"   优先级: {designed.priority}")
        print()

    # Mock data samples: pooled values, integer boundaries, wrong-type values.
    mock_source = MockDataGenerator()

    print("生成的测试数据示例:")
    print(divider)

    print("邮箱数据:", mock_source.generate_test_data('email', 3))
    print("整数边界:", mock_source.generate_edge_cases('int'))
    print("字符串无效数据:", mock_source.generate_invalid_data('str'))

# Script entry point: run the test-design demo only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    demo_intelligent_test_design()

集成测试智能设计

服务间测试生成

class IntegrationTestGenerator:
    """Generate integration-level ``TestCase`` objects from a service map.

    A service map is a dict keyed by service name. Each value is a dict
    that may carry the keys read by this class: ``api_calls``,
    ``database_access``, ``message_queues``, ``data_inputs``,
    ``data_outputs``, ``sample_data`` and ``data_transformations``.
    """

    def __init__(self):
        # Initialised empty; no method of this class reads or writes them.
        self.service_dependencies = {}
        self.api_contracts = {}
        self.test_scenarios = {}

    def analyze_service_dependencies(self, service_map: Dict[str, Any]) -> Dict[str, List[str]]:
        """Collect the dependencies declared by every service.

        Args:
            service_map: Mapping of service name to its description dict.

        Returns:
            Mapping of service name to its de-duplicated dependency names,
            in first-seen order: API call targets, then database names,
            then message queue names.
        """
        dependencies = {}

        for service_name, service_info in service_map.items():
            deps = []

            # API calls: only entries naming a target service count.
            for api_call in service_info.get('api_calls', []):
                target_service = api_call.get('target_service')
                if target_service:
                    deps.append(target_service)

            # Databases and message queues are listed by name directly.
            deps.extend(service_info.get('database_access', []))
            deps.extend(service_info.get('message_queues', []))

            # dict.fromkeys de-duplicates while keeping first-seen order.
            # A set would make the order, and so the order of generated
            # tests, vary between runs under string hash randomization.
            dependencies[service_name] = list(dict.fromkeys(deps))

        return dependencies

    def generate_integration_tests(self, service_map: Dict[str, Any]) -> List[TestCase]:
        """Build every integration test case for ``service_map``.

        Returns:
            Communication tests, then data-flow tests, then end-to-end
            tests, concatenated in that order.
        """
        test_cases = []

        dependencies = self.analyze_service_dependencies(service_map)

        test_cases.extend(self._generate_service_communication_tests(dependencies))
        test_cases.extend(self._generate_data_flow_tests(service_map))
        test_cases.extend(self._generate_end_to_end_tests(service_map))

        return test_cases

    def _generate_service_communication_tests(self, dependencies: Dict[str, List[str]]) -> List[TestCase]:
        """Create normal, timeout and error tests for each dependency edge.

        Args:
            dependencies: Mapping of service name to dependency names, as
                returned by ``analyze_service_dependencies``.

        Returns:
            Three test cases per (service, dependency) pair.
        """
        test_cases = []

        for service, deps in dependencies.items():
            for dep in deps:
                # Happy path. The "到" ("to") separator keeps the two
                # service names from running together in the description.
                normal_test = TestCase(
                    name=f"test_{service}_to_{dep}_communication",
                    description=f"测试{service}到{dep}服务的正常通信",
                    input_data={
                        'source_service': service,
                        'target_service': dep,
                        'test_type': 'normal_communication'
                    },
                    expected_output="success",
                    test_type="integration",
                    priority=1,
                    tags=["communication", "integration"]
                )
                test_cases.append(normal_test)

                # Timeout handling.
                timeout_test = TestCase(
                    name=f"test_{service}_to_{dep}_timeout",
                    description=f"测试{service}到{dep}服务的超时处理",
                    input_data={
                        'source_service': service,
                        'target_service': dep,
                        'test_type': 'timeout',
                        'timeout_duration': 5
                    },
                    expected_output="timeout_handled",
                    test_type="integration",
                    priority=2,
                    tags=["timeout", "integration"]
                )
                test_cases.append(timeout_test)

                # Error-response handling.
                error_test = TestCase(
                    name=f"test_{service}_to_{dep}_error_handling",
                    description=f"测试{service}到{dep}服务的错误处理",
                    input_data={
                        'source_service': service,
                        'target_service': dep,
                        'test_type': 'error_response',
                        'error_code': 500
                    },
                    expected_output="error_handled",
                    test_type="integration",
                    priority=2,
                    tags=["error_handling", "integration"]
                )
                test_cases.append(error_test)

        return test_cases

    def _generate_data_flow_tests(self, service_map: Dict[str, Any]) -> List[TestCase]:
        """Create one test per data flow found by ``_identify_data_flows``."""
        test_cases = []

        data_flows = self._identify_data_flows(service_map)

        for flow in data_flows:
            test_case = TestCase(
                name=f"test_data_flow_{flow['name']}",
                description=f"测试数据流: {' -> '.join(flow['path'])}",
                input_data={
                    'flow_path': flow['path'],
                    'test_data': flow['test_data'],
                    'expected_transformations': flow['transformations']
                },
                expected_output="data_flow_success",
                test_type="integration",
                priority=1,
                tags=["data_flow", "integration"]
            )
            test_cases.append(test_case)

        return test_cases

    def _generate_end_to_end_tests(self, service_map: Dict[str, Any]) -> List[TestCase]:
        """Create one test per flow found by ``_identify_business_flows``."""
        test_cases = []

        business_flows = self._identify_business_flows(service_map)

        for flow in business_flows:
            test_case = TestCase(
                name=f"test_e2e_{flow['name']}",
                description=f"端到端测试: {flow['description']}",
                input_data={
                    'business_flow': flow['steps'],
                    'test_scenario': flow['scenario'],
                    'expected_outcome': flow['expected_outcome']
                },
                expected_output="business_flow_success",
                test_type="integration",
                priority=1,
                tags=["e2e", "business_flow"]
            )
            test_cases.append(test_case)

        return test_cases

    def _identify_data_flows(self, service_map: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Describe a single-service data flow for each eligible service.

        A service is eligible when its description has both ``data_inputs``
        and ``data_outputs``. This is a simplified detection: each flow's
        path holds only the service itself.
        """
        flows = []

        for service_name, service_info in service_map.items():
            if 'data_inputs' in service_info and 'data_outputs' in service_info:
                flow = {
                    'name': f"{service_name}_data_flow",
                    'path': [service_name],
                    'test_data': service_info.get('sample_data', {}),
                    'transformations': service_info.get('data_transformations', [])
                }
                flows.append(flow)

        return flows

    def _identify_business_flows(self, service_map: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the built-in business flows that match ``service_map``.

        The candidates are hard-coded (user registration and order
        processing); matching is delegated to ``_flow_matches_services``.
        """
        flows = []

        # Built-in catalogue of common business flows.
        common_flows = [
            {
                'name': 'user_registration',
                'description': '用户注册流程',
                'steps': ['validate_input', 'create_user', 'send_confirmation'],
                'scenario': 'new_user_signup',
                'expected_outcome': 'user_created_and_notified'
            },
            {
                'name': 'order_processing',
                'description': '订单处理流程',
                'steps': ['validate_order', 'check_inventory', 'process_payment', 'fulfill_order'],
                'scenario': 'customer_places_order',
                'expected_outcome': 'order_completed'
            }
        ]

        for flow in common_flows:
            if self._flow_matches_services(flow, service_map):
                flows.append(flow)

        return flows

    def _flow_matches_services(self, flow: Dict[str, Any], service_map: Dict[str, Any]) -> bool:
        """Tell whether a flow shares a keyword with any service name.

        Keywords are the underscore-separated parts of the flow name and of
        each lower-cased service name. One shared keyword is enough.
        """
        flow_keywords = set(flow['name'].split('_'))
        service_keywords = set()

        for service_name in service_map:
            service_keywords.update(service_name.lower().split('_'))

        return bool(flow_keywords & service_keywords)

class APIContractTester:
    """Generate contract test cases from registered API contracts.

    A contract is a dict with an ``endpoints`` list. Each endpoint dict
    must provide ``path`` and ``method`` and may provide ``headers``,
    ``sample_request``, ``success_status``, ``parameters`` and
    ``error_statuses``.
    """

    def __init__(self):
        # Contracts keyed by service name.
        self.contracts = {}
        # Initialised empty; no method of this class reads or writes it.
        self.test_generators = {}

    def register_api_contract(self, service_name: str, contract: Dict[str, Any]):
        """Store ``contract`` for ``service_name``, replacing any earlier one."""
        self.contracts[service_name] = contract

    def generate_contract_tests(self, service_name: str) -> List[TestCase]:
        """Return contract tests for every endpoint of a registered service.

        Returns:
            The concatenated per-endpoint test cases, or an empty list when
            no contract is registered under ``service_name``.
        """
        if service_name not in self.contracts:
            return []

        contract = self.contracts[service_name]
        test_cases = []

        for endpoint in contract.get('endpoints', []):
            test_cases.extend(self._generate_endpoint_tests(service_name, endpoint))

        return test_cases

    @staticmethod
    def _path_to_identifier(endpoint_path: str) -> str:
        """Turn a URL path into a fragment that is safe inside a test name.

        Each run of non-word characters becomes one underscore and the
        outer underscores are trimmed, so ``/users/{id}`` becomes
        ``users_id``. A path with no word characters becomes ``root``.
        """
        import re

        slug = re.sub(r'\W+', '_', endpoint_path).strip('_')
        return slug or 'root'

    def _generate_endpoint_tests(self, service_name: str, endpoint: Dict[str, Any]) -> List[TestCase]:
        """Create normal, missing-parameter and error tests for one endpoint.

        Args:
            service_name: Service owning the endpoint; used in test names.
            endpoint: Endpoint description; ``path`` and ``method`` are
                required keys.

        Returns:
            One normal-request test, one test per required parameter, and
            one test per error status (``[400, 401, 403, 404, 500]`` when
            the endpoint declares none).

        Raises:
            KeyError: If ``endpoint`` lacks ``path`` or ``method``.
        """
        test_cases = []
        endpoint_path = endpoint['path']
        method = endpoint['method']

        # Shared name prefix. The path is slugged so that templated paths
        # such as "/users/{id}" do not put braces or doubled underscores
        # into the generated test names.
        name_prefix = f"test_{service_name}_{method}_{self._path_to_identifier(endpoint_path)}"

        # Normal request.
        normal_test = TestCase(
            name=f"{name_prefix}_normal",
            description=f"测试{service_name} {method} {endpoint_path}正常请求",
            input_data={
                'method': method,
                'path': endpoint_path,
                'headers': endpoint.get('headers', {}),
                'body': endpoint.get('sample_request', {}),
                'expected_status': endpoint.get('success_status', 200)
            },
            expected_output="api_success",
            test_type="integration",
            priority=1,
            tags=["api", "contract", "normal"]
        )
        test_cases.append(normal_test)

        # Parameter validation: one test per required parameter.
        for param in endpoint.get('parameters', []):
            if param.get('required', False):
                missing_param_test = TestCase(
                    name=f"{name_prefix}_missing_{param['name']}",
                    description=f"测试缺少必需参数{param['name']}的情况",
                    input_data={
                        'method': method,
                        'path': endpoint_path,
                        'missing_param': param['name'],
                        'expected_status': 400
                    },
                    expected_output="validation_error",
                    test_type="integration",
                    priority=2,
                    tags=["api", "validation", "error"]
                )
                test_cases.append(missing_param_test)

        # Error statuses: one test per declared (or default) status code.
        error_statuses = endpoint.get('error_statuses', [400, 401, 403, 404, 500])
        for status in error_statuses:
            error_test = TestCase(
                name=f"{name_prefix}_error_{status}",
                description=f"测试{status}错误状态",
                input_data={
                    'method': method,
                    'path': endpoint_path,
                    'force_error': status,
                    'expected_status': status
                },
                expected_output="api_error",
                test_type="integration",
                priority=3,
                tags=["api", "error", f"status_{status}"]
            )
            test_cases.append(error_test)

        return test_cases

# 使用示例
def demo_integration_testing():
    """Print sample integration tests and API contract tests.

    Builds a three-service map, prints at most the first eight cases from
    ``IntegrationTestGenerator``, then registers a two-endpoint contract
    with ``APIContractTester`` and prints every contract test it yields.
    """
    divider = "=" * 50

    # Per-service descriptions making up the sample service map.
    user_service = {
        'api_calls': [
            {'target_service': 'auth_service', 'endpoint': '/validate'},
            {'target_service': 'notification_service', 'endpoint': '/send'}
        ],
        'database_access': ['user_db'],
        'data_inputs': ['user_data'],
        'data_outputs': ['user_profile'],
        'sample_data': {'username': 'test_user', 'email': 'test@example.com'}
    }
    auth_service = {
        'api_calls': [
            {'target_service': 'user_service', 'endpoint': '/user/info'}
        ],
        'database_access': ['auth_db'],
        'data_inputs': ['credentials'],
        'data_outputs': ['auth_token']
    }
    notification_service = {
        'message_queues': ['email_queue', 'sms_queue'],
        'data_inputs': ['notification_request'],
        'data_outputs': ['delivery_status']
    }
    service_map = {
        'user_service': user_service,
        'auth_service': auth_service,
        'notification_service': notification_service
    }

    integration_generator = IntegrationTestGenerator()
    integration_cases = integration_generator.generate_integration_tests(service_map)

    print("生成的集成测试用例:")
    print(divider)

    # Only the first eight cases are shown.
    for position, generated in enumerate(integration_cases[:8], start=1):
        print(f"{position}. {generated.name}")
        print(f"   描述: {generated.description}")
        print(f"   类型: {generated.test_type}")
        print(f"   标签: {generated.tags}")
        print()

    # API contract tests for the user service.
    contract_tester = APIContractTester()

    create_user_endpoint = {
        'path': '/users',
        'method': 'POST',
        'parameters': [
            {'name': 'username', 'required': True, 'type': 'string'},
            {'name': 'email', 'required': True, 'type': 'string'}
        ],
        'success_status': 201,
        'error_statuses': [400, 409],
        'sample_request': {'username': 'testuser', 'email': 'test@example.com'}
    }
    get_user_endpoint = {
        'path': '/users/{id}',
        'method': 'GET',
        'parameters': [
            {'name': 'id', 'required': True, 'type': 'integer'}
        ],
        'success_status': 200,
        'error_statuses': [404]
    }
    user_api_contract = {'endpoints': [create_user_endpoint, get_user_endpoint]}

    contract_tester.register_api_contract('user_service', user_api_contract)
    contract_cases = contract_tester.generate_contract_tests('user_service')

    print("生成的API契约测试:")
    print(divider)

    for position, generated in enumerate(contract_cases, start=1):
        print(f"{position}. {generated.name}")
        print(f"   描述: {generated.description}")
        print(f"   标签: {generated.tags}")
        print()

# Script entry point: run the integration-testing demo only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    demo_integration_testing()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CarlowZJ

我的文章对你有用的话,可以支持

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值