# safe_tool.py
import asyncio
import functools
import inspect
import sys
import traceback
from functools import wraps

from fastapi import HTTPException
from fastapi.responses import FileResponse, JSONResponse

from tools.logs import logger
# Global switch: when truthy, the decorators in this module catch exceptions
# raised by the wrapped function; when falsy, exceptions propagate unchanged.
# The wrappers read this value on every call, so it can be flipped at runtime.
ENABLE_TRY = True
def catch_exception_api(func):
    """Wrap an API endpoint so unexpected errors become a JSON 500 response.

    Works for both ``async def`` and plain ``def`` functions. The module-level
    ``ENABLE_TRY`` flag is read on every call: when it is falsy the wrapped
    function runs unguarded and exceptions propagate normally.

    ``HTTPException`` raised by the endpoint is re-raised untouched, so FastAPI
    can answer with the status code the endpoint chose instead of a 500.

    Args:
        func: The function (sync or async) to wrap.

    Returns:
        An async wrapper for coroutine functions, otherwise a sync wrapper.
        When the wrapped function raises any other ``Exception``, the wrapper
        logs it and returns a ``JSONResponse`` with status 500. The body holds
        the error message plus the file and line of the innermost traceback
        frame.
    """

    def _error_response(exc, msg_key):
        """Log *exc* and build the 500 response; *msg_key* names the message field."""
        frames = traceback.extract_tb(exc.__traceback__)
        if frames:
            # Innermost frame = where the exception was actually raised.
            filename, lineno = frames[-1].filename, frames[-1].lineno
        else:
            filename, lineno = "unknown", 0
        logger.error(f"API Error in {func.__name__}: {str(exc)}")
        logger.error(f"Error location: {filename}, line {lineno}")
        # Called from inside the except clause, so format_exc() sees *exc*.
        logger.error(f"Full traceback: {traceback.format_exc()}")
        # NOTE(review): this exposes server file paths and raw exception text
        # to API clients; consider omitting them outside development.
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                msg_key: str(exc),
                "file": filename,
                "line": lineno,
            },
        )

    @wraps(func)
    async def async_wrapper(*args, **kwargs):
        if not ENABLE_TRY:
            return await func(*args, **kwargs)
        try:
            return await func(*args, **kwargs)
        except HTTPException:
            # Deliberate HTTP errors carry their own status; let FastAPI handle them.
            raise
        except Exception as e:
            # NOTE(review): async endpoints report the message under "msg",
            # sync ones under "message"; kept as-is for client compatibility.
            return _error_response(e, "msg")

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        if not ENABLE_TRY:
            return func(*args, **kwargs)
        try:
            return func(*args, **kwargs)
        except HTTPException:
            raise
        except Exception as e:
            return _error_response(e, "message")

    # Pick the wrapper that matches the wrapped function's calling convention.
    if inspect.iscoroutinefunction(func):
        return async_wrapper
    return sync_wrapper
def catch_exceptions(return_value=None, log_level='error', include_traceback=True):
    """Build a decorator that logs exceptions and returns a fallback value.

    Works for both plain and ``async def`` functions. The module-level
    ``ENABLE_TRY`` flag is read on every call: when it is falsy, the wrapped
    function runs unguarded and exceptions propagate normally.

    Args:
        return_value: Value returned when the wrapped function raises. It is
            returned as-is (not copied), so a mutable object is shared by
            every failing call.
        log_level: Name of the ``logger`` method used for both the failure
            message and the traceback, e.g. 'error', 'warning', 'info'.
        include_traceback: Whether to also log the full stack trace.

    Returns:
        A decorator. Only ``Exception`` subclasses are caught.

    Raises:
        AttributeError: When the decorator is applied, if ``logger`` has no
            method named ``log_level``.
    """
    def decorator(func):
        # Resolve the log method up front. An invalid level then fails here,
        # instead of raising inside the except clause and hiding the error
        # we were trying to report.
        log_method = getattr(logger, log_level)

        def _report(exc):
            """Log *exc*, and optionally its traceback, at the configured level."""
            log_method(f"Function {func.__name__} failed: {str(exc)}")
            if include_traceback:
                # Called from inside an except clause, so format_exc() sees *exc*.
                log_method(f"Traceback for {func.__name__}:\n{traceback.format_exc()}")

        if inspect.iscoroutinefunction(func):
            # Exceptions of a coroutine surface at await time, so they must be
            # caught inside an async wrapper rather than around the call.
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                if not ENABLE_TRY:
                    return await func(*args, **kwargs)
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    _report(e)
                    return return_value
            return async_wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not ENABLE_TRY:
                return func(*args, **kwargs)
            try:
                return func(*args, **kwargs)
            except Exception as e:
                _report(e)
                return return_value
        return wrapper
    return decorator
def catch_exceptions_old(func):
    """Older variant of ``catch_exceptions``: log the error location, return None.

    When the module-level ``ENABLE_TRY`` flag is truthy, any ``Exception`` raised
    by *func* is logged (name, message, and the file/line of the innermost
    traceback frame) and the wrapper returns ``None``. When it is falsy,
    exceptions propagate unchanged. The flag is read on every call.
    """
    @wraps(func)  # keep __name__/__doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        if not ENABLE_TRY:
            return func(*args, **kwargs)
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # The traceback always contains at least this wrapper's frame;
            # the last entry is the innermost one, where the error was raised.
            innermost = traceback.extract_tb(e.__traceback__)[-1]
            logger.error(f"catch err: {func.__name__}: {e}")
            logger.error(f"err loc: file:{innermost.filename}, lineno:{innermost.lineno}")
            return None
    return wrapper
# ================= Usage example =================
@catch_exceptions(return_value=-1)
def div_val(a, b):
    """Return ``a / b``.

    When ``ENABLE_TRY`` is truthy and the division raises (e.g. b == 0),
    the decorator logs the error and -1 is returned instead.
    """
    quotient = a / b
    return quotient
if __name__ == '__main__':
    # Switch on: div_val's decorator swallows the error and returns -1.
    ENABLE_TRY = True
    print('res', div_val(10, 0))

    # Switch off: the ZeroDivisionError propagates to the caller.
    ENABLE_TRY = False
    try:
        aaa = div_val(10, 0)
        print('res', aaa)
    except ZeroDivisionError as err:
        print('raised', type(err).__name__, err)