在上一篇文章中,我们讨论了 AI Agent 的记忆系统。今天,我想分享一下如何设计和实现一个灵活的工具调用体系。说实话,这个模块我重构了好几次,每次都有新的感悟。
从简单到复杂
最开始实现工具调用时,我用的是最简单的方式:
def run_tool(name: str, args: dict) -> str:
    """Dispatch a tool call by name using hard-coded branches.

    Args:
        name: Tool identifier; only "search" and "calculate" are handled.
        args: Keyword arguments forwarded to the selected tool function.

    Returns:
        Whatever the selected tool function returns, or ``None`` when
        ``name`` matches no branch.
    """
    if name == "search":
        return search_function(**args)
    if name == "calculate":
        return calculate_function(**args)
    # ... more tools would be added here as further branches
    return None
这种方式虽然能用,但很快就遇到了问题:
- 每加一个工具都要改代码
- 参数验证全靠人工检查
- 错误处理非常脆弱
- 完全不支持异步操作
重新设计
经过几轮迭代,我总结出了一个相对完善的工具调用体系。先看整体架构:
import asyncio
import inspect
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol

from pydantic import BaseModel, create_model
class ToolResult(BaseModel):
    """Uniform outcome record returned by every tool invocation.

    ``result`` and ``error`` default to ``None`` so that a result can be
    built without spelling out the field that does not apply (pydantic v2
    treats an ``Optional[...]`` field without a default as required).
    """

    # Whether the invocation succeeded.
    success: bool
    # Value produced by the tool; None when there is nothing to report.
    result: Optional[Any] = None
    # Human-readable failure description; None on success.
    error: Optional[str] = None
    # Elapsed time of the invocation, in seconds.
    execution_time: float
class Tool(Protocol):
    """Structural interface that every registrable tool must satisfy."""

    # Key under which the tool is registered and looked up.
    name: str
    # Free-text explanation of what the tool does.
    description: str

    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool with keyword arguments and report the outcome."""
        ...

    def get_schema(self) -> Dict[str, Any]:
        """Return a dictionary describing the tool."""
        ...
class ToolRegistry:
    """Name-indexed collection of tools."""

    def __init__(self):
        # Maps tool name -> tool instance.
        self._tools: Dict[str, Tool] = {}

    def register(self, tool: Tool):
        """Store ``tool`` under ``tool.name``, replacing any previous entry."""
        self._tools[tool.name] = tool

    def get_tool(self, name: str) -> Optional[Tool]:
        """Return the tool registered as ``name``, or ``None`` if absent."""
        try:
            return self._tools[name]
        except KeyError:
            return None

    def list_tools(self) -> Dict[str, Dict[str, Any]]:
        """Return a mapping of tool name to that tool's ``get_schema()``."""
        schemas = {}
        for tool_name, registered in self._tools.items():
            schemas[tool_name] = registered.get_schema()
        return schemas
class BaseTool:
    """Base class providing schema generation, timing and error capture.

    Subclasses implement ``_execute`` with explicit, annotated parameters;
    ``execute`` wraps it and always returns a ``ToolResult``.
    """

    def __init__(self, name: str, description: str):
        """Store the tool's registry name and its description."""
        self.name = name
        self.description = description

    def get_schema(self) -> Dict[str, Any]:
        """Describe the tool and its arguments.

        The argument schema is derived from the signature of ``_execute``
        (``execute`` itself only accepts ``**kwargs`` and carries no useful
        parameter information).

        Returns:
            A dict with ``name``, ``description`` and ``parameters``, where
            ``parameters`` is the schema of a dynamically created pydantic
            model.
        """
        sig = inspect.signature(self._execute)
        fields: Dict[str, Any] = {}
        for param_name, param in sig.parameters.items():
            # *args / **kwargs cannot be expressed as named model fields.
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue
            # Un-annotated parameters accept anything.
            annotation = (
                Any if param.annotation is param.empty else param.annotation
            )
            # ``...`` marks a field as required; otherwise keep the default
            # so optional parameters are reported as optional.
            default = ... if param.default is param.empty else param.default
            fields[param_name] = (annotation, default)

        model = create_model(f"{self.name}Args", **fields)
        # pydantic v2 renamed ``schema()`` to ``model_json_schema()``;
        # prefer the new name when present and fall back for v1.
        build_schema = getattr(model, "model_json_schema", None) or model.schema
        return {
            "name": self.name,
            "description": self.description,
            "parameters": build_schema(),
        }

    async def execute(self, **kwargs) -> ToolResult:
        """Run ``_execute`` and wrap its outcome in a ``ToolResult``.

        Any ``Exception`` raised by ``_execute`` is captured and reported as
        a failed result instead of propagating to the caller.

        Args:
            **kwargs: Forwarded unchanged to ``_execute``.

        Returns:
            A ``ToolResult`` whose ``execution_time`` is the elapsed time in
            seconds.
        """
        start_time = datetime.now()
        try:
            result = await self._execute(**kwargs)
        except Exception as e:
            success = False
            result = None
            # Some exceptions stringify to "" (e.g. a bare
            # NotImplementedError); fall back to the class name so the
            # failure is never reported with an empty message.
            error = str(e) or type(e).__name__
        else:
            success = True
            error = None
        return ToolResult(
            success=success,
            result=result,
            error=error,
            execution_time=(datetime.now() - start_time).total_seconds(),
        )

    async def _execute(self, **kwargs) -> Any:
        """Perform the tool's actual work; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
有了这个基础框架,我们就可以很方便地添加新工具了:
from pydantic import BaseModel
class SearchArgs(BaseModel):
    """Argument model for a search request."""

    # Text to search for.
    query: str
    # Result-count argument; defaults to 10.
    limit: int = 10
class SearchTool(BaseTool):
    """Tool registered as "search" that queries ``vector_store``."""

    def __init__(self):
        super().__init__(name="search", description="搜索知识库中的相关内容")

    async def _execute(self, query: str, limit: int = 10) -> List[Dict]:
        """Search ``vector_store`` and return one dict per hit.

        Args:
            query: Passed to ``vector_store.search`` as ``query``.
            limit: Passed to ``vector_store.search`` as ``limit``.

        Returns:
            A list of ``{"content": ..., "score": ...}`` dicts, one for each
            object yielded by the search.
        """
        hits = await vector_store.search(query=query, limit=limit)
        rows = []
        for hit in hits:
            rows.append({"content": hit.content, "score": hit.score})
        return rows
class CalculatorArgs(BaseModel):
    """Argument model for a calculator request."""

    # The expression text to evaluate.
    expression: str
class CalculatorTool(BaseTool):
    """Tool registered as "calculator" that evaluates arithmetic expressions."""

    def __init__(self):
        super().__init__(name="calculator", description="计算数学表达式")

    async def _execute(self, expression: str) -> float:
        """Evaluate ``expression`` and return the numeric result.

        Args:
            expression: Arithmetic expression such as ``"1 + 2 * 3"``.

        Returns:
            The value of the expression as a float.
        """
        # Safe expression evaluation (never eval()).
        return await self._safe_eval(expression)

    async def _safe_eval(self, expression: str) -> float:
        """Evaluate a purely arithmetic expression without ``eval``.

        The expression is parsed to an AST and only numeric literals, the
        binary operators ``+ - * / // % **`` and unary ``+``/``-`` are
        accepted; names, calls, attribute access and everything else are
        rejected.

        Raises:
            SyntaxError: If ``expression`` is not valid expression syntax.
            ValueError: If it contains a disallowed element or an exponent
                larger than the guard limit.
        """
        # Imported locally so this class does not depend on module-level
        # imports it cannot see.
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {
            ast.UAdd: operator.pos,
            ast.USub: operator.neg,
        }
        # Guard against expressions like 9 ** 9 ** 9 that would consume
        # unbounded CPU and memory.
        max_exponent = 1000

        def evaluate(node):
            if isinstance(node, ast.Expression):
                return evaluate(node.body)
            if isinstance(node, ast.Constant):
                value = node.value
                # bool is a subclass of int; reject it along with strings,
                # bytes, complex numbers, None, etc.
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    raise ValueError(f"不支持的常量: {value!r}")
                return value
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                left = evaluate(node.left)
                right = evaluate(node.right)
                if isinstance(node.op, ast.Pow) and abs(right) > max_exponent:
                    raise ValueError("指数过大")
                return binary_ops[type(node.op)](left, right)
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](evaluate(node.operand))
            raise ValueError(f"不支持的表达式元素: {type(node).__name__}")

        tree = ast.parse(expression, mode="eval")
        return float(evaluate(tree))
使用示例:
# Create the tool registry
registry = ToolRegistry()
# Register the tools
registry.register(SearchTool())
registry.register(CalculatorTool())
# Fetch the schema of every registered tool, keyed by tool name
tools = registry.list_tools()
print(tools)
# Expected output (illustrative):
# {
# "search": {
# "name": "search",
# "description": "搜索知识库中的相关内容",
# "parameters": {
# "type": "object",
# "properties": {
# "query": {"type": "string"},
# "limit": {"type": "integer", "default": 10}
# },
# "required": ["query"]
# }
# },
# "calculator": {
# "name": "calculator",
# "description": "计算数学表达式",
# "parameters": {
# "type": "object",
# "properties": {
# "expression": {"type": "string"}
# },
# "required": ["expression"]
# }
# }
# }
# Execute a tool.
# NOTE: a bare top-level `await` only works where an event loop is already
# running (e.g. a notebook or `python -m asyncio`); in a plain script, put
# these lines in an `async def` and run it with `asyncio.run(...)`.
search_tool = registry.get_tool("search")
result = await search_tool.execute(
    query="Python装饰器",
    limit=5
)
print(result)
# Expected output (illustrative; shows the ToolResult fields):
# {
# "success": true,
# "result": [...],
# "error": null,
# "execution_time": 0.123
# }
进阶特性
基础功能有了,接下来我们来添加一些进阶特性:
1. 权限控制
from enum import Enum
from typing import Set
class PermissionLevel(Enum):
    """Closed set of permission labels a caller may hold."""

    # Each member's value is its lowercase label.
    READ = "read"
    WRITE = "write"
    ADMIN = "admin"
class SecureTool(BaseTool):
    """Tool that runs only when the caller holds every required permission."""

    def __init__(
        self,
        name: str,
        description: str,
        required_permissions: Set[PermissionLevel],
    ):
        """Store the permissions that a caller must hold to run the tool.

        Args:
            name: Registry name of the tool.
            description: Description of the tool.
            required_permissions: Permissions that must ALL be present in
                the caller's permission set.
        """
        super().__init__(name, description)
        self.required_permissions = required_permissions

    async def execute(
        self,
        user_permissions: Set[PermissionLevel],
        **kwargs
    ) -> ToolResult:
        """Check permissions, then delegate to ``BaseTool.execute``.

        Args:
            user_permissions: Permissions held by the caller.
            **kwargs: Tool arguments forwarded to the base implementation.

        Returns:
            A failed ``ToolResult`` with error "权限不足" when any required
            permission is missing; otherwise the base class's result.
        """
        # Permission check: every required permission must be held.
        if not self.required_permissions.issubset(user_permissions):
            # ``result`` is passed explicitly: pydantic v2 treats an
            # Optional field without a default as required.
            return ToolResult(
                success=False,
                result=None,
                error="权限不足",
                execution_time=0.0,
            )
        return await super().execute(**kwargs)
class DeleteTool(SecureTool):
    """Tool registered as "delete"; requires both WRITE and ADMIN."""

    def __init__(self):
        # Both permissions are required, not either one.
        needed = {PermissionLevel.WRITE, PermissionLevel.ADMIN}
        super().__init__(
            name="delete",
            description="删除文档",
            required_permissions=needed,
        )
2. 速率限制
from datetime import datetime, timedelta
from collections import deque
class RateLimitedTool(BaseTool):
    """Tool that rejects calls beyond ``max_calls`` per ``time_window``."""

    def __init__(
        self,
        name: str,
        description: str,
        max_calls: int,
        time_window: timedelta,
    ):
        """Configure the sliding-window limit.

        Args:
            name: Registry name of the tool.
            description: Description of the tool.
            max_calls: Maximum number of accepted calls inside the window.
            time_window: Length of the window.
        """
        super().__init__(name, description)
        self.max_calls = max_calls
        self.time_window = time_window
        # Timestamps of accepted calls, oldest first.
        self.call_history = deque()

    async def execute(self, **kwargs) -> ToolResult:
        """Enforce the rate limit, then delegate to ``BaseTool.execute``.

        Returns:
            A failed ``ToolResult`` with error "调用频率超限" when the limit
            is reached; otherwise the base class's result. Rejected calls
            are not recorded in ``call_history``.
        """
        # Drop call records that have aged out of the window.
        now = datetime.now()
        while (
            self.call_history
            and now - self.call_history[0] > self.time_window
        ):
            self.call_history.popleft()

        # Reject when the window is already full.
        if len(self.call_history) >= self.max_calls:
            # ``result`` is passed explicitly: pydantic v2 treats an
            # Optional field without a default as required.
            return ToolResult(
                success=False,
                result=None,
                error="调用频率超限",
                execution_time=0.0,
            )

        # Record this call before running it.
        self.call_history.append(now)
        return await super().execute(**kwargs)
class APITool(RateLimitedTool):
    """Tool registered as "api_call", limited to 100 calls per hour."""

    def __init__(self):
        one_hour = timedelta(hours=1)
        super().__init__(
            name="api_call",
            description="调用外部API",
            max_calls=100,
            time_window=one_hour,
        )
3. 链式调用
class ToolChain:
    """Run tools in sequence, feeding each output into the next tool."""

    # Annotations are quoted forward references so the class definition does
    # not need ``Tool`` / ``ToolResult`` to be resolvable at definition time.
    def __init__(self, tools: List["Tool"]):
        """Store the tools in the order they should run."""
        self.tools = tools

    async def execute(
        self,
        initial_input: Dict[str, Any]
    ) -> List["ToolResult"]:
        """Execute the chain.

        Args:
            initial_input: Keyword arguments for the first tool.

        Returns:
            The results of every tool that ran, in order. The chain stops
            after the first result whose ``success`` is false. An empty
            chain returns an empty list.

        Raises:
            TypeError: If an intermediate result cannot be converted into
                keyword arguments by ``_process_result``.
        """
        results = []
        current_input = initial_input
        pending = list(self.tools)
        last_index = len(pending) - 1
        for index, tool in enumerate(pending):
            result = await tool.execute(**current_input)
            results.append(result)
            if not result.success:
                break
            # Only convert when there is a next tool to feed; the final
            # tool's output has no consumer inside the chain.
            if index < last_index:
                current_input = self._process_result(result.result)
        return results

    def _process_result(
        self,
        result: Any
    ) -> Dict[str, Any]:
        """Convert one tool's output into the next tool's keyword arguments.

        The default handles only dict results, which are passed through
        unchanged. Override this method for business-specific mapping.

        Raises:
            TypeError: If ``result`` is not a dict.
        """
        if isinstance(result, dict):
            return result
        raise TypeError(
            f"无法将 {type(result).__name__} 类型的结果转换为下一个工具的输入,"
            "请重写 _process_result"
        )
# Usage example: search, then summarize, then translate.
# NOTE(review): SummarizeTool and TranslateTool are not defined in the
# visible source; they are presumably placeholders for further tools.
# The bare top-level `await` below only works where an event loop is already
# running (e.g. a notebook or `python -m asyncio`).
chain = ToolChain([
    SearchTool(),
    SummarizeTool(),
    TranslateTool()
])
results = await chain.execute({
    "query": "Python装饰器"
})
4. 错误重试
from typing import Type, Tuple
import asyncio
class RetryableTool(BaseTool):
    """Tool that re-runs ``_execute`` when it raises a retryable exception."""

    def __init__(
        self,
        name: str,
        description: str,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_exceptions: Tuple[Type[Exception], ...] = (
            Exception,
        )
    ):
        """Configure the retry policy.

        Args:
            name: Registry name of the tool.
            description: Description of the tool.
            max_retries: Maximum number of attempts in total.
            retry_delay: Base delay in seconds; the wait after the n-th
                failed attempt is ``retry_delay * n``.
            retry_exceptions: Exception types that trigger another attempt.
        """
        super().__init__(name, description)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_exceptions = retry_exceptions

    async def execute(self, **kwargs) -> ToolResult:
        """Run ``_execute`` with retries and wrap the outcome.

        ``_execute`` is called directly rather than through
        ``BaseTool.execute``: the base method converts every exception into
        a failed ``ToolResult``, so no exception would ever reach the retry
        logic.

        Returns:
            A successful ``ToolResult`` from the first attempt that does not
            raise; a failed one if a non-retryable exception is raised or
            all attempts are used up. ``execution_time`` covers all attempts
            including the waits between them.
        """
        start_time = datetime.now()

        def elapsed() -> float:
            return (datetime.now() - start_time).total_seconds()

        attempts = 0
        last_error = None
        while attempts < self.max_retries:
            try:
                result = await self._execute(**kwargs)
            except self.retry_exceptions as e:
                attempts += 1
                last_error = e
                # Linear back-off; no wait after the final attempt.
                if attempts < self.max_retries:
                    await asyncio.sleep(self.retry_delay * attempts)
            except Exception as e:
                # Not retryable: report the failure immediately, keeping the
                # contract that execute() returns a ToolResult.
                return ToolResult(
                    success=False,
                    result=None,
                    error=str(e),
                    execution_time=elapsed(),
                )
            else:
                return ToolResult(
                    success=True,
                    result=result,
                    error=None,
                    execution_time=elapsed(),
                )
        return ToolResult(
            success=False,
            result=None,
            error=f"重试{attempts}次后失败: {last_error}",
            execution_time=elapsed(),
        )
class UnstableAPITool(RetryableTool):
    """Tool registered as "unstable_api" that retries connection/timeout errors."""

    def __init__(self):
        # Only these exception types trigger another attempt.
        transient_errors = (ConnectionError, TimeoutError)
        super().__init__(
            name="unstable_api",
            description="调用不稳定的API",
            max_retries=3,
            retry_delay=2.0,
            retry_exceptions=transient_errors,
        )
实践心得
在实现这个工具调用体系的过程中,我学到了几点:
接口设计很重要
- 保持简单统一
- 预留扩展空间
- 考虑向后兼容
安全性第一
- 参数严格验证
- 权限精细控制
- 资源使用限制
可维护性是关键
- 完善的错误处理
- 详细的日志记录
- 方便的调试接口
写在最后
一个好的工具调用体系应该像乐高积木一样,每个工具都是一个标准化的模块,可以自由组合,又能保证安全和可靠。
在下一篇文章中,我会讲解如何实现 AI Agent 的规划系统。如果你对工具调用体系的设计有什么想法,欢迎在评论区交流。