unittest 框架全面使用教程
本教程将展示如何使用 Python 内置的 unittest
框架实现与 pytest 相同的测试功能。unittest
是 Python 标准库的一部分,提供了完整的测试解决方案,包括测试发现、组织和执行功能。
目录
基本测试结构
import unittest
import logging
from datetime import datetime
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 示例业务逻辑函数
def add(a, b):
return a + b
def divide(a, b):
if b == 0:
raise ValueError("除数不能为零")
return a / b
def create_user(username, role="user"):
if not username:
raise ValueError("用户名不能为空")
return {"username": username, "role": role, "created_at": datetime.now()}
class TestBasicOperations(unittest.TestCase):
"""基础测试示例"""
def test_addition(self):
"""测试加法"""
result = add(2, 3)
self.assertEqual(result, 5, "2+3应该等于5")
def test_division(self):
"""测试除法"""
result = divide(10, 2)
self.assertAlmostEqual(result, 5.0, places=2)
def test_user_creation(self):
"""测试用户创建"""
user = create_user("test_user")
self.assertEqual(user["role"], "user")
self.assertIsInstance(user["created_at"], datetime)
def test_negative_addition(self):
"""测试负数加法"""
result = add(-5, -3)
self.assertEqual(result, -8)
if __name__ == "__main__":
unittest.main()
参数化测试
unittest 本身不支持参数化,但可以通过 subTest
实现类似功能:
class TestParametrizedOperations(unittest.TestCase):
"""参数化测试示例"""
def test_parametrized_addition(self):
"""参数化加法测试"""
test_cases = [
(1, 2, 3), # 正整数
(0, 0, 0), # 零
(-1, 1, 0), # 负数与正数
(2.5, 3.5, 6.0) # 浮点数
]
for a, b, expected in test_cases:
with self.subTest(a=a, b=b, expected=expected):
result = add(a, b)
self.assertEqual(result, expected, f"{a}+{b}应该等于{expected}")
def test_parametrized_division(self):
"""参数化除法测试"""
test_cases = [
(10, 2, 5.0),
(15, 3, 5.0),
(1, 4, 0.25)
]
for a, b, expected in test_cases:
with self.subTest(a=a, b=b, expected=expected):
result = divide(a, b)
self.assertAlmostEqual(result, expected, places=4)
测试夹具
unittest 使用 setUp()
和 tearDown()
方法管理测试资源:
class TestDatabaseOperations(unittest.TestCase):
"""测试夹具示例"""
@classmethod
def setUpClass(cls):
"""类级别设置 - 整个测试类执行一次"""
cls.logger = logging.getLogger(f"{__name__}.{cls.__name__}")
cls.logger.info("\n=== 建立数据库连接 ===")
# 模拟数据库连接
cls.db = {"connected": True, "users": []}
@classmethod
def tearDownClass(cls):
"""类级别清理 - 整个测试类执行一次"""
cls.logger.info("\n=== 关闭数据库连接 ===")
cls.db["connected"] = False
def setUp(self):
"""测试方法级别设置 - 每个测试方法执行前调用"""
self.logger.info(f"\n开始测试: {self.id()}")
# 重置用户列表
self.db["users"] = []
def tearDown(self):
"""测试方法级别清理 - 每个测试方法执行后调用"""
self.logger.info(f"测试完成: {self.id()}")
def test_admin_user_creation(self):
"""测试管理员用户创建"""
user = create_user("admin_user", "admin")
self.db["users"].append(user)
self.assertEqual(user["role"], "admin")
self.assertEqual(len(self.db["users"]), 1)
def test_regular_user_creation(self):
"""测试普通用户创建"""
user = create_user("regular_user")
self.db["users"].append(user)
self.assertEqual(user["role"], "user")
self.assertEqual(len(self.db["users"]), 1)
def test_database_connection(self):
"""测试数据库连接状态"""
self.assertTrue(self.db["connected"])
跳过测试
unittest 提供了多种跳过测试的方式:
class TestSkipOperations(unittest.TestCase):
"""跳过测试示例"""
@unittest.skip("功能尚未实现")
def test_unimplemented_feature(self):
"""跳过未实现的功能测试"""
self.fail("这个测试不应该执行")
@unittest.skipIf(sys.version_info < (3, 8), "需要Python 3.8+")
def test_python38_feature(self):
"""条件跳过测试"""
# Python 3.8+ 的特性
self.assertEqual((x := 5), 5) # 海象运算符
@unittest.expectedFailure
def test_known_bug(self):
"""预期失败的测试"""
result = add(0.1, 0.2)
self.assertEqual(result, 0.3) # 浮点数精度问题
def test_skip_dynamically(self):
"""动态跳过测试"""
if os.environ.get("SKIP_TEST") == "true":
self.skipTest("环境变量要求跳过此测试")
# 正常测试逻辑
self.assertTrue(True)
异常测试
unittest 提供 assertRaises
方法来测试异常:
class TestExceptionHandling(unittest.TestCase):
"""异常测试示例"""
def test_divide_by_zero(self):
"""测试除零异常"""
with self.assertRaises(ValueError) as context:
divide(10, 0)
self.assertEqual(str(context.exception), "除数不能为零")
def test_invalid_username(self):
"""测试无效用户名"""
invalid_usernames = ["", None, " "]
for username in invalid_usernames:
with self.subTest(username=username):
with self.assertRaises(ValueError) as context:
create_user(username)
self.assertIn("用户名不能为空", str(context.exception))
临时文件处理
使用 tempfile
模块处理临时文件和目录:
import tempfile
import shutil
class TestFileOperations(unittest.TestCase):
"""临时文件处理示例"""
def setUp(self):
"""创建临时目录"""
self.test_dir = tempfile.mkdtemp()
self.logger.info(f"创建临时目录: {self.test_dir}")
def tearDown(self):
"""清理临时目录"""
shutil.rmtree(self.test_dir)
self.logger.info(f"删除临时目录: {self.test_dir}")
def test_file_operations(self):
"""测试文件操作"""
# 创建文件
test_file = os.path.join(self.test_dir, "test.txt")
with open(test_file, "w") as f:
f.write("Hello unittest!")
# 验证文件
self.assertTrue(os.path.exists(test_file))
# 读取内容
with open(test_file, "r") as f:
content = f.read()
self.assertEqual(content, "Hello unittest!")
def test_subdirectory_creation(self):
"""测试子目录创建"""
sub_dir = os.path.join(self.test_dir, "subdir")
os.makedirs(sub_dir)
self.assertTrue(os.path.isdir(sub_dir))
输出捕获
unittest 提供 assertLogs
和重定向方法捕获输出:
class TestOutputCapture(unittest.TestCase):
"""输出捕获示例"""
def test_stdout_capture(self):
"""测试标准输出捕获"""
with self.assertLogs(logger, level="INFO") as log_context:
logger.info("这是一条信息日志")
logger.warning("这是一条警告日志")
# 验证日志输出
self.assertIn("这是一条信息日志", log_context.output[0])
self.assertIn("这是一条警告日志", log_context.output[1])
def test_stderr_capture(self):
"""测试标准错误捕获"""
# 使用 StringIO 捕获 stderr
from io import StringIO
stderr_buffer = StringIO()
# 重定向 stderr
with unittest.mock.patch("sys.stderr", stderr_buffer):
print("错误消息", file=sys.stderr)
sys.stderr.write("另一条错误")
# 验证捕获内容
output = stderr_buffer.getvalue()
self.assertIn("错误消息", output)
self.assertIn("另一条错误", output)
并发测试
使用 concurrent.futures
实现并发测试:
import concurrent.futures
class TestConcurrentOperations(unittest.TestCase):
"""并发测试示例"""
def test_concurrent_execution(self):
"""并发执行测试"""
# 创建测试任务
tasks = [self._worker_function(i) for i in range(10)]
# 使用线程池执行
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
results = list(executor.map(lambda f: f(), tasks))
# 验证所有结果
self.assertTrue(all(results))
def _worker_function(self, index):
"""工作函数"""
def _task():
import time
time.sleep(0.1) # 模拟工作负载
return index < 10 # 总是成功
return _task
测试报告
生成文本报告
if __name__ == "__main__":
# 创建测试加载器
loader = unittest.TestLoader()
# 发现并加载所有测试
suite = loader.discover(".")
# 运行测试并生成文本报告
with open("test_report.txt", "w") as report_file:
runner = unittest.TextTestRunner(stream=report_file, verbosity=2)
result = runner.run(suite)
# 打印摘要
print(f"测试运行完成: {result.testsRun} 个测试")
print(f"失败: {len(result.failures)}")
print(f"错误: {len(result.errors)}")
print(f"跳过: {len(result.skipped)}")
生成 HTML 报告
安装扩展库:
pip install unittest-html-report
生成 HTML 报告:
from unittest_html_report import HTMLTestRunner
if __name__ == "__main__":
loader = unittest.TestLoader()
suite = loader.discover(".")
with open("test_report.html", "wb") as report_file:
runner = HTMLTestRunner(
stream=report_file,
title="单元测试报告",
description="系统功能测试结果",
verbosity=2
)
runner.run(suite)
高级功能
自定义测试加载器
class CustomTestLoader(unittest.TestLoader):
"""自定义测试加载器"""
def loadTestsFromModule(self, module):
"""从模块加载测试"""
tests = super().loadTestsFromModule(module)
# 过滤掉标记为慢速的测试
filtered = unittest.TestSuite()
for test in tests:
if not self._is_slow_test(test):
filtered.addTest(test)
return filtered
def _is_slow_test(self, test):
"""检查测试是否标记为慢速"""
# 实际实现中可以通过自定义属性或方法判断
return "slow" in str(test).lower()
if __name__ == "__main__":
loader = CustomTestLoader()
suite = loader.discover(".")
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)
猴子补丁
使用 unittest.mock
实现猴子补丁:
class TestMonkeyPatch(unittest.TestCase):
"""猴子补丁示例"""
def test_monkeypatch_example(self):
"""使用猴子补丁修改环境"""
# 修改环境变量
with unittest.mock.patch.dict("os.environ", {"APP_ENV": "testing"}):
self.assertEqual(os.getenv("APP_ENV"), "testing")
# 修改函数行为
def mock_add(a, b):
return 42
with unittest.mock.patch("__main__.add", mock_add):
self.assertEqual(add(2, 3), 42)
# 验证原始行为恢复
self.assertEqual(add(2, 3), 5)
def test_mock_network_call(self):
"""模拟网络调用"""
# 创建模拟响应
mock_response = unittest.mock.MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"result": "success"}
# 替换网络请求函数
with unittest.mock.patch("requests.get", return_value=mock_response) as mock_get:
# 调用被测试函数
result = make_network_request()
# 验证结果
self.assertEqual(result, "success")
# 验证函数调用
mock_get.assert_called_once_with("https://api.example.com/data")
测试执行顺序控制
class TestExecutionOrder(unittest.TestCase):
"""测试执行顺序控制"""
@classmethod
def setUpClass(cls):
cls.execution_order = []
def test_first(self):
"""第一个测试"""
self.execution_order.append(1)
self.logger.info("首先执行")
def test_third(self):
"""第三个测试"""
self.execution_order.append(3)
self.logger.info("第三执行")
# 验证执行顺序
self.assertEqual(self.execution_order, [1, 2, 3])
def test_second(self):
"""第二个测试"""
self.execution_order.append(2)
self.logger.info("第二执行")
# 使用 TestSuite 控制顺序
@classmethod
def make_ordered_suite(cls):
"""创建有序的测试套件"""
suite = unittest.TestSuite()
suite.addTest(cls("test_first"))
suite.addTest(cls("test_second"))
suite.addTest(cls("test_third"))
return suite
if __name__ == "__main__":
# 运行有序测试套件
suite = TestExecutionOrder.make_ordered_suite()
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)
unittest 最佳实践
-
组织结构:
project/ ├── src/ │ └── my_module.py └── tests/ ├── unit/ │ ├── test_module_a.py │ └── test_module_b.py ├── integration/ │ └── test_integration.py └── __init__.py
-
命名规范:
- 测试文件:
test_*.py
- 测试类:
Test*
- 测试方法:
test_*
- 测试文件:
-
测试发现:
# 运行所有测试 python -m unittest discover # 运行特定模块 python -m unittest tests.unit.test_module_a # 运行特定测试类 python -m unittest tests.unit.test_module_a.TestFeature # 运行特定测试方法 python -m unittest tests.unit.test_module_a.TestFeature.test_specific_case
-
测试覆盖率:
# 安装 coverage pip install coverage # 运行测试并收集覆盖率 coverage run -m unittest discover # 生成报告 coverage report coverage html # 生成HTML报告
-
与 CI/CD 集成:
# .gitlab-ci.yml 示例 test: image: python:3.9 script: - pip install coverage - coverage run -m unittest discover - coverage report - coverage xml # 用于集成分析
unittest vs pytest 对比
功能 | unittest | pytest |
---|---|---|
基本结构 | 需要继承 TestCase | 无需继承,普通函数即可 |
参数化 | 使用 subTest | 内置 @pytest.mark.parametrize |
测试夹具 | setUp/tearDown 方法 | @pytest.fixture 装饰器 |
跳过测试 | @unittest.skip | @pytest.mark.skip |
异常测试 | self.assertRaises | pytest.raises |
插件系统 | 有限支持 | 丰富的插件生态系统 |
测试发现 | 内置 | 更智能的发现机制 |
报告生成 | 需要扩展 | 内置多种报告格式 |
并发测试 | 需手动实现 | pytest-xdist 插件支持 |
猴子补丁 | unittest.mock | pytest monkeypatch 夹具 |
总结
unittest 是 Python 标准库的一部分,提供了强大的测试框架,无需额外依赖即可使用。虽然在某些高级功能上不如 pytest 便捷,但它提供了:
- 完整的测试结构和生命周期管理
- 丰富的断言方法
- 灵活的测试发现和执行机制
- 与 Python 生态系统的良好集成
- 适合大型项目和团队协作的稳定基础
通过本教程中的示例,您应该能够使用 unittest 实现各种测试场景,包括参数化测试、资源管理、异常处理和报告生成。对于需要更高级功能的项目,可以考虑结合使用 unittest 和 pytest,或者探索 unittest 的扩展库。