AI Agent 的工具调用体系:打造可扩展的能力框架

在上一篇文章中,我们讨论了 AI Agent 的记忆系统。今天,我想分享一下如何设计和实现一个灵活的工具调用体系。说实话,这个模块我重构了好几次,每次都有新的感悟。

从简单到复杂

最开始实现工具调用时,我用的是最简单的方式:

def run_tool(name: str, args: dict) -> str:
    """Dispatch a tool call by comparing ``name`` against hard-coded tool names.

    Args:
        name: Identifier of the tool to run.
        args: Keyword arguments forwarded to the selected tool function.

    Returns:
        Whatever the selected tool function returns, or ``None`` when
        ``name`` matches none of the known tools.
    """
    if name == "search":
        return search_function(**args)
    if name == "calculate":
        return calculate_function(**args)
    # ... more tools
    return None

这种方式虽然能用,但很快就遇到了问题:

  • 每加一个工具都要改代码
  • 参数验证全靠人工检查
  • 错误处理非常脆弱
  • 完全不支持异步操作

重新设计

经过几轮迭代,我总结出了一个相对完善的工具调用体系。先看整体架构:

import asyncio
import inspect
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol

from pydantic import BaseModel, create_model

class ToolResult(BaseModel):
    """Outcome of a single tool invocation.

    ``result`` and ``error`` default to ``None`` so a failure can be built
    without spelling out ``result`` and a success without ``error``. Without
    explicit defaults pydantic v2 treats ``Optional[...]`` fields as required
    and rejects such calls with a validation error.
    """

    # True when the call succeeded, False when it failed or was refused.
    success: bool
    # Value produced by the tool; None when there is nothing to report.
    result: Optional[Any] = None
    # Human-readable failure message; None on success.
    error: Optional[str] = None
    # Elapsed time of the call, in seconds.
    execution_time: float

class Tool(Protocol):
    """Structural interface that every tool is expected to satisfy."""

    # Key under which ToolRegistry stores the tool.
    name: str
    # Free-text description of what the tool does.
    description: str

    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool with keyword arguments and report a ToolResult."""
        ...

    def get_schema(self) -> Dict[str, Any]:
        """Describe the tool as a plain dictionary."""
        ...

class ToolRegistry:
    """In-memory lookup table of tools keyed by their ``name`` attribute."""

    def __init__(self):
        # Maps tool name -> tool instance.
        self._tools: Dict[str, Tool] = {}

    def register(self, tool: Tool):
        """Store ``tool`` under ``tool.name``, replacing any earlier entry."""
        key = tool.name
        self._tools[key] = tool

    def get_tool(self, name: str) -> Optional[Tool]:
        """Return the tool registered as ``name``, or ``None`` if absent."""
        try:
            return self._tools[name]
        except KeyError:
            return None

    def list_tools(self) -> Dict[str, Dict[str, Any]]:
        """Return a mapping of tool name to that tool's ``get_schema()``."""
        schemas: Dict[str, Dict[str, Any]] = {}
        for tool_name, registered in self._tools.items():
            schemas[tool_name] = registered.get_schema()
        return schemas

class BaseTool:
    """Common base for tools: schema generation plus a timed, exception-safe ``execute``.

    Subclasses override ``_execute`` with explicitly named, annotated
    parameters. ``get_schema`` derives the argument schema from that
    signature, and ``execute`` wraps the call in a ``ToolResult``.
    """

    def __init__(self, name: str, description: str):
        """Remember the tool's registry name and its description."""
        self.name = name
        self.description = description

    def get_schema(self) -> Dict[str, Any]:
        """Build a schema dictionary from the signature of ``_execute``.

        Returns:
            A dict with the keys ``name``, ``description`` and ``parameters``;
            ``parameters`` is the JSON schema of a pydantic model generated
            from the ``_execute`` parameters.
        """
        # Inspect _execute, not execute: execute only declares **kwargs, which
        # would yield a single meaningless "kwargs" field.
        sig = inspect.signature(self._execute)
        fields = {}
        for param_name, param in sig.parameters.items():
            if param_name == "self":
                continue
            # *args / **kwargs catch-alls carry no schema information.
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            annotation = (
                Any
                if param.annotation is inspect.Parameter.empty
                else param.annotation
            )
            # ``...`` marks a field as required for create_model; a real
            # default keeps the parameter optional in the schema.
            default = (
                ...
                if param.default is inspect.Parameter.empty
                else param.default
            )
            fields[param_name] = (annotation, default)

        model = create_model(f"{self.name}Args", **fields)

        # pydantic v2 renamed schema() to model_json_schema(); fall back to
        # the old name so pydantic v1 keeps working.
        build_schema = getattr(model, "model_json_schema", None)
        if build_schema is None:
            build_schema = model.schema

        return {
            "name": self.name,
            "description": self.description,
            "parameters": build_schema()
        }

    async def execute(self, **kwargs) -> ToolResult:
        """Run ``_execute`` and report the outcome without raising.

        Any ``Exception`` raised by ``_execute`` is converted into a failed
        ``ToolResult`` whose ``error`` is ``str(exception)``.
        """
        start_time = datetime.now()
        try:
            result = await self._execute(**kwargs)
        except Exception as e:
            return ToolResult(
                success=False,
                result=None,
                error=str(e),
                execution_time=(
                    datetime.now() - start_time
                ).total_seconds()
            )
        return ToolResult(
            success=True,
            result=result,
            error=None,
            execution_time=(
                datetime.now() - start_time
            ).total_seconds()
        )

    async def _execute(self, **kwargs) -> Any:
        """Do the actual work; subclasses must override this."""
        raise NotImplementedError

有了这个基础框架,我们就可以很方便地添加新工具了:

from pydantic import BaseModel

# Pydantic model with the same fields as the parameters of
# SearchTool._execute: a required ``query`` and a ``limit`` defaulting to 10.
# NOTE(review): nothing in the visible code references this model — confirm
# whether it is still needed.
class SearchArgs(BaseModel):
    query: str
    limit: int = 10

class SearchTool(BaseTool):
    """Tool registered as ``search`` that queries ``vector_store``."""

    def __init__(self):
        super().__init__("search", "搜索知识库中的相关内容")

    async def _execute(
        self,
        query: str,
        limit: int = 10
    ) -> List[Dict]:
        """Search ``vector_store`` and return one dict per hit.

        Each dict carries the hit's ``content`` and ``score`` attributes.
        """
        hits = await vector_store.search(query=query, limit=limit)
        formatted = []
        for hit in hits:
            formatted.append({"content": hit.content, "score": hit.score})
        return formatted

# Pydantic model with the same single field as the parameter of
# CalculatorTool._execute: the expression string.
# NOTE(review): nothing in the visible code references this model — confirm
# whether it is still needed.
class CalculatorArgs(BaseModel):
    expression: str

class CalculatorTool(BaseTool):
    """Tool registered as ``calculator`` that evaluates arithmetic expressions."""

    def __init__(self):
        super().__init__(
            name="calculator",
            description="计算数学表达式"
        )

    async def _execute(
        self,
        expression: str
    ) -> float:
        """Evaluate ``expression`` and return the numeric result."""
        # Safe evaluation: never hand the string to eval().
        return await self._safe_eval(expression)

    async def _safe_eval(self, expression: str) -> float:
        """Evaluate a plain arithmetic expression without using ``eval``.

        The expression is parsed into an AST and only numeric literals, the
        binary operators ``+ - * / // % **`` and unary ``+``/``-`` are
        accepted. All arithmetic is done on floats, so huge powers fail fast
        with ``OverflowError`` instead of building enormous integers.

        Raises:
            SyntaxError: If ``expression`` is not valid Python syntax.
            ValueError: If it contains anything other than the supported
                literals and operators (names, calls, attributes, ...).
        """
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {
            ast.UAdd: operator.pos,
            ast.USub: operator.neg,
        }

        def evaluate(node):
            if isinstance(node, ast.Constant):
                value = node.value
                # bool is a subclass of int; reject it along with str, complex, ...
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    raise ValueError(f"unsupported constant: {value!r}")
                return float(value)
            if isinstance(node, ast.BinOp):
                apply = binary_ops.get(type(node.op))
                if apply is None:
                    raise ValueError(
                        f"unsupported operator: {type(node.op).__name__}"
                    )
                return apply(evaluate(node.left), evaluate(node.right))
            if isinstance(node, ast.UnaryOp):
                apply = unary_ops.get(type(node.op))
                if apply is None:
                    raise ValueError(
                        f"unsupported operator: {type(node.op).__name__}"
                    )
                return apply(evaluate(node.operand))
            raise ValueError(
                f"unsupported expression element: {type(node).__name__}"
            )

        tree = ast.parse(expression, mode="eval")
        return float(evaluate(tree.body))

使用示例(下面的代码包含顶层 await,需要放在异步函数中,或在支持顶层 await 的环境(如 Jupyter、python -m asyncio)中运行):

# Create the tool registry
registry = ToolRegistry()

# Register the tools
registry.register(SearchTool())
registry.register(CalculatorTool())

# Fetch the schemas of all registered tools
tools = registry.list_tools()
print(tools)
# Output:
# {
#     "search": {
#         "name": "search",
#         "description": "搜索知识库中的相关内容",
#         "parameters": {
#             "type": "object",
#             "properties": {
#                 "query": {"type": "string"},
#                 "limit": {"type": "integer", "default": 10}
#             },
#             "required": ["query"]
#         }
#     },
#     "calculator": {
#         "name": "calculator",
#         "description": "计算数学表达式",
#         "parameters": {
#             "type": "object",
#             "properties": {
#                 "expression": {"type": "string"}
#             },
#             "required": ["expression"]
#         }
#     }
# }

# Run a tool
# NOTE(review): top-level `await` only works inside a coroutine or an
# async-aware REPL — confirm how this snippet is meant to be run.
search_tool = registry.get_tool("search")
result = await search_tool.execute(
    query="Python装饰器",
    limit=5
)
print(result)
# Output:
# {
#     "success": true,
#     "result": [...],
#     "error": null,
#     "execution_time": 0.123
# }

进阶特性

基础功能有了,接下来我们来添加一些进阶特性:

1. 权限控制

from enum import Enum
from typing import Set

class PermissionLevel(Enum):
    """Permission levels that SecureTool compares against a caller's permission set."""

    READ = "read"
    WRITE = "write"
    ADMIN = "admin"

class SecureTool(BaseTool):
    """Tool that refuses to run unless the caller holds every required permission."""

    def __init__(
        self,
        name: str,
        description: str,
        required_permissions: Set[PermissionLevel]
    ):
        """Store the permissions a caller must hold, in addition to name/description."""
        super().__init__(name, description)
        self.required_permissions = required_permissions

    async def execute(
        self,
        user_permissions: Set[PermissionLevel],
        **kwargs
    ) -> ToolResult:
        """Check permissions, then delegate to ``BaseTool.execute``.

        Args:
            user_permissions: Permissions held by the caller.
            **kwargs: Arguments forwarded to the tool when access is granted.

        Returns:
            A failed ``ToolResult`` with error ``"权限不足"`` when any required
            permission is missing; otherwise the result of the base
            ``execute``.
        """
        # Every required permission must be present in the caller's set.
        if not self.required_permissions.issubset(user_permissions):
            # result is passed explicitly: pydantic v2 treats an Optional
            # field without a default as required.
            return ToolResult(
                success=False,
                result=None,
                error="权限不足",
                execution_time=0
            )

        return await super().execute(**kwargs)

class DeleteTool(SecureTool):
    """Secure tool registered as ``delete``; requires both WRITE and ADMIN.

    Only the constructor is defined in this snippet.
    """

    def __init__(self):
        needed = {PermissionLevel.WRITE, PermissionLevel.ADMIN}
        super().__init__("delete", "删除文档", needed)

2. 速率限制

from datetime import datetime, timedelta
from collections import deque

class RateLimitedTool(BaseTool):
    """Tool that allows at most ``max_calls`` executions per sliding ``time_window``."""

    def __init__(
        self,
        name: str,
        description: str,
        max_calls: int,
        time_window: timedelta
    ):
        """Store the limit settings and start with an empty call history."""
        super().__init__(name, description)
        self.max_calls = max_calls
        self.time_window = time_window
        # Timestamps of accepted calls, oldest first.
        self.call_history = deque()

    async def execute(self, **kwargs) -> ToolResult:
        """Enforce the rate limit, then delegate to ``BaseTool.execute``.

        Returns:
            A failed ``ToolResult`` with error ``"调用频率超限"`` when the
            window is already full; otherwise the result of the base
            ``execute``.
        """
        # Drop timestamps that have fallen out of the window.
        now = datetime.now()
        while (
            self.call_history and
            now - self.call_history[0] > self.time_window
        ):
            self.call_history.popleft()

        # Reject the call when the window is already full.
        if len(self.call_history) >= self.max_calls:
            # result is passed explicitly: pydantic v2 treats an Optional
            # field without a default as required.
            return ToolResult(
                success=False,
                result=None,
                error="调用频率超限",
                execution_time=0
            )

        # Record this call before running it, so failed runs count too.
        self.call_history.append(now)

        return await super().execute(**kwargs)

class APITool(RateLimitedTool):
    """Rate-limited tool registered as ``api_call``: 100 calls per hour.

    Only the constructor is defined in this snippet.
    """

    def __init__(self):
        hourly_window = timedelta(hours=1)
        super().__init__("api_call", "调用外部API", 100, hourly_window)

3. 链式调用

class ToolChain:
    """Run a sequence of tools, feeding each tool's output into the next."""

    def __init__(self, tools: List[Tool]):
        """Remember the tools in the order they should run."""
        self.tools = tools

    async def execute(
        self,
        initial_input: Dict[str, Any]
    ) -> List[ToolResult]:
        """Run the tools in order and collect their results.

        Args:
            initial_input: Keyword arguments for the first tool.

        Returns:
            One ``ToolResult`` per executed tool. The chain stops after the
            first unsuccessful result, which is the last list element.
        """
        results = []
        current_input = initial_input

        for index, tool in enumerate(self.tools):
            # Convert the previous output only when another tool will consume
            # it, so the final tool's result is never converted needlessly.
            if index > 0:
                current_input = self._process_result(results[-1].result)

            result = await tool.execute(**current_input)
            results.append(result)

            if not result.success:
                break

        return results

    def _process_result(
        self,
        result: Any
    ) -> Dict[str, Any]:
        """Turn one tool's output into keyword arguments for the next tool.

        The default passes a dict through unchanged. Any other shape is
        business-specific, so subclasses must override this hook.

        Raises:
            NotImplementedError: If ``result`` is not a dict and the hook has
                not been overridden.
        """
        if isinstance(result, dict):
            return result
        raise NotImplementedError(
            "override _process_result to convert "
            f"{type(result).__name__} output into the next tool's arguments"
        )

# Usage example
# NOTE(review): SummarizeTool and TranslateTool are not defined in the
# visible code — confirm they exist elsewhere.
chain = ToolChain([
    SearchTool(),
    SummarizeTool(),
    TranslateTool()
])

# NOTE(review): top-level `await` only works inside a coroutine or an
# async-aware REPL — confirm how this snippet is meant to be run.
results = await chain.execute({
    "query": "Python装饰器"
})

4. 错误重试

from typing import Type, Tuple
import asyncio

class RetryableTool(BaseTool):
    """Tool that retries ``_execute`` when it raises one of ``retry_exceptions``."""

    def __init__(
        self,
        name: str,
        description: str,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_exceptions: Tuple[Type[Exception], ...] = (
            Exception,
        )
    ):
        """Store the retry policy.

        Args:
            name: Registry name of the tool.
            description: Description of the tool.
            max_retries: Maximum number of attempts.
            retry_delay: Base delay in seconds; the wait before the next
                attempt is ``retry_delay`` times the number of failed attempts.
            retry_exceptions: Exception types that trigger another attempt.
        """
        super().__init__(name, description)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_exceptions = retry_exceptions

    async def execute(self, **kwargs) -> ToolResult:
        """Run ``_execute`` with retries and report the outcome without raising.

        ``_execute`` is called directly rather than through
        ``BaseTool.execute``: the base method converts every exception into a
        failed ``ToolResult``, so nothing would ever reach the retry logic.

        Returns:
            A successful ``ToolResult`` from the first attempt that does not
            raise. A failed one is returned immediately for an exception
            outside ``retry_exceptions``, or after ``max_retries`` retryable
            failures.
        """
        start_time = datetime.now()
        attempts = 0
        last_error = None

        while attempts < self.max_retries:
            try:
                result = await self._execute(**kwargs)
            except self.retry_exceptions as e:
                attempts += 1
                last_error = e

                # Back off linearly, but do not sleep after the final attempt.
                if attempts < self.max_retries:
                    await asyncio.sleep(
                        self.retry_delay * attempts
                    )
            except Exception as e:
                # Not retryable: report it the same way BaseTool.execute does.
                return ToolResult(
                    success=False,
                    result=None,
                    error=str(e),
                    execution_time=(
                        datetime.now() - start_time
                    ).total_seconds()
                )
            else:
                return ToolResult(
                    success=True,
                    result=result,
                    error=None,
                    execution_time=(
                        datetime.now() - start_time
                    ).total_seconds()
                )

        return ToolResult(
            success=False,
            result=None,
            error=f"重试{attempts}次后失败: {last_error}",
            execution_time=(
                datetime.now() - start_time
            ).total_seconds()
        )

class UnstableAPITool(RetryableTool):
    """Retryable tool registered as ``unstable_api``.

    Configured for 3 attempts with a base delay of 2.0 seconds, retrying on
    ``ConnectionError`` and ``TimeoutError``. Only the constructor is defined
    in this snippet.
    """

    def __init__(self):
        transient_errors = (ConnectionError, TimeoutError)
        super().__init__(
            "unstable_api",
            "调用不稳定的API",
            3,
            2.0,
            transient_errors,
        )

实践心得

在实现这个工具调用体系的过程中,我学到了几点:

  1. 接口设计很重要

    • 保持简单统一
    • 预留扩展空间
    • 考虑向后兼容
  2. 安全性第一

    • 参数严格验证
    • 权限精细控制
    • 资源使用限制
  3. 可维护性是关键

    • 完善的错误处理
    • 详细的日志记录
    • 方便的调试接口

写在最后

一个好的工具调用体系应该像乐高积木一样,每个工具都是一个标准化的模块,可以自由组合,又能保证安全和可靠。

在下一篇文章中,我会讲解如何实现 AI Agent 的规划系统。如果你对工具调用体系的设计有什么想法,欢迎在评论区交流。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值