Chapter 4 Formalization and Core Implementation of ReAct

======================= Do not reproduce without permission; infringements will be pursued =========================

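Goal: guided by mathematical rigor and engineering practicality, systematically redesign the core components of a ReAct system, namely the T-A-O (Thought-Action-Observation) loop, the prompt schema, the parser, the tool contract, and the concurrency strategy, for production deployments that demand high reliability, low cost, and auditability.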


4.1 Formal Definition of the T-A-O Loop

4.1.0 Overview (Expert Statement)


4.1.1 Formal Model: ReAct from a POMDP Perspective

Definition 1: POMDP Representation

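Design trade-offs:
  • Although the system can in principle be modeled as a full POMDP, LLM outputs are highly uncertain in practice, so we treat it as a simulator driven by a non-stationary policy rather than as a strict Bayesian inference system.
  • The emphasis is therefore on designing the state manager UpdateState and the parser Parser so that they stabilize the policy's execution path.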
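For reference, the textbook POMDP tuple and one plausible mapping of the T-A-O loop onto it are written below. The symbols τ_t (thought string), u_t (tool call or final answer) and h_t (interaction history) are our own notation and an assumption, not necessarily the author's Definition 1. Because a thought changes no external state, T and O depend only on u_t; and because the LLM keeps no exact belief b_t(s), the prompt history h_t is the policy's input, with UpdateState as the map from h_t to h_{t+1}.

```latex
\mathcal{M} = \langle \mathcal{S}, \mathcal{A}, T, R, \Omega, O, \gamma \rangle,
\qquad \mathcal{A} = \Sigma^{*} \times \mathcal{U},
\qquad \mathcal{U} = \mathcal{U}_{\text{tools}} \cup \{\texttt{answer}\}

\begin{aligned}
h_t     &= (o_0, a_0, o_1, \dots, a_{t-1}, o_t) \\
a_t     &= (\tau_t, u_t) \sim \pi_\theta(\cdot \mid h_t) \\
s_{t+1} &\sim T(\cdot \mid s_t, u_t), \qquad o_{t+1} \sim O(\cdot \mid s_{t+1}, u_t) \\
h_{t+1} &= \mathrm{UpdateState}(h_t, a_t, o_{t+1})
\end{aligned}
```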

4.1.2 Episode Termination Conditions and Cost Metrics (Precise Definitions)

Definition 2: Termination Condition Predicate Set
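A termination predicate set can be sketched as a dictionary of boolean predicates over a snapshot of the loop: the episode stops as soon as any predicate fires, and dictionary order doubles as priority order. The reason codes mirror those returned by the 4.5.1 loop; the EpisodeView fields and the priority order are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

@dataclass
class EpisodeView:
    """Hypothetical read-only snapshot of the loop state that the predicates inspect."""
    step: int
    max_steps: int
    answer_verified: bool
    tokens_used: int
    token_budget: int
    no_progress: bool  # e.g. produced by a detect_no_progress(state) helper

# Predicate set P = {p_i}: the episode stops as soon as any p_i(view) is True.
# Dict insertion order doubles as priority order when several predicates fire together.
TERMINATION_PREDICATES: Dict[str, Callable[[EpisodeView], bool]] = {
    "ANSWER_VERIFIED": lambda v: v.answer_verified,
    "BUDGET_EXCEEDED": lambda v: v.tokens_used > v.token_budget,
    "NO_PROGRESS_DETECTED": lambda v: v.no_progress,
    "MAX_STEPS_REACHED": lambda v: v.step >= v.max_steps,
}

def termination_reason(view: EpisodeView) -> Optional[str]:
    """Return the highest-priority reason that fires, or None to keep looping."""
    for reason, predicate in TERMINATION_PREDICATES.items():
        if predicate(view):
            return reason
    return None
```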

Definition 3: Cost Vector and Scalarization Function

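⚠️ Engineering tip: inject the remaining budget into the prompt (e.g. {budget_left: {token: 2000, usd: 0.1}}) and instruct the LLM to take this constraint into account when making decisions.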
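A minimal sketch of a cost vector, a weighted-sum scalarization (the "cost vector + weighted scalarization" named in the conclusion), and the budget_left rendering from the tip. The choice of three components (tokens, USD, latency) and the weight values are assumptions; weighted sum is only one common scalarization.

```python
import json
from dataclasses import dataclass

@dataclass
class CostVector:
    """c = (tokens, usd, latency_ms); the three components are an assumed choice."""
    tokens: int = 0
    usd: float = 0.0
    latency_ms: int = 0

def scalarize(c: CostVector, w_tokens: float, w_usd: float, w_latency: float) -> float:
    """Weighted-sum scalarization J(c) = w_tokens*tokens + w_usd*usd + w_latency*latency_ms."""
    return w_tokens * c.tokens + w_usd * c.usd + w_latency * c.latency_ms

def budget_left_json(used: CostVector, budget: CostVector) -> str:
    """Render the remaining budget in the {budget_left: {token, usd}} shape for the prompt."""
    left = {
        "token": max(budget.tokens - used.tokens, 0),
        "usd": round(max(budget.usd - used.usd, 0.0), 4),  # round away float noise
    }
    return json.dumps({"budget_left": left})
```

Clamping at zero keeps the prompt from ever showing a negative budget; the hard stop itself belongs to the BUDGET_EXCEEDED termination check, not to the prompt.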


4.2 Prompt Schema and Instruction Engineering (Prompt Engineering)

4.2.1 Design Elements of the Instruction Section (System Prompt), Treated as a Contract

The system prompt should act as a "policy interface contract" and contain the following:

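| Category | Example content |
|---|---|
| Output format requirements | "Output only valid JSON that conforms to the specified schema." |
| Tool list declaration | "available_tools": [{"name": "search", "signature": "q:string"}] |
| Safety and compliance constraints | "Do not access private fields"; "If sensitive terms appear, ask the user for confirmation" |
| Utility and budget guidance | "2000 tokens remain; prefer the local cache tool." |
| Self-check and correction instructions | "If parsing fails, return the corrected JSON within 100 tokens." |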
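The five contract elements can be assembled mechanically into one system prompt. The function name, section headings and wording below are illustrative assumptions. The key points are that every element is present and that they are rendered in a fixed order, which keeps versioned prompt files easy to diff.

```python
import json
from typing import Dict, List

def build_system_prompt(tools: List[Dict[str, str]], constraints: List[str],
                        budget_left: Dict[str, float], fix_token_limit: int = 100) -> str:
    """Assemble the five contract sections in a fixed order (wording is illustrative)."""
    sections = [
        "## Output format\nOutput only valid JSON that conforms to the specified schema.",
        "## Tools\n" + json.dumps({"available_tools": tools}),
        "## Safety and compliance\n" + "\n".join(f"- {c}" for c in constraints),
        "## Budget\n" + json.dumps({"budget_left": budget_left}),
        f"## Self-correction\nIf parsing fails, return corrected JSON within {fix_token_limit} tokens.",
    ]
    return "\n\n".join(sections)
```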

✅ Engineering practice: store the system prompt under version control (e.g. system_v1.2.txt) and run regression tests with every deployment.


4.2.2 Few-Shot Examples: Selection Principles and Dynamic Retrieval

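Principles:
  • Coverage first: cover normal flows, edge cases, and error-handling paths.
  • Minimal context overhead: avoid long examples that crowd out the context window.
  • Teach correction: provide "original error → corrected sample" pairs so the model learns how to repair its output.
  • Version sync: update examples together with the current schema/output contract.
Recommended practices:
  • Use an embedding retrieval engine (e.g. FAISS) to dynamically select the most relevant few-shot examples.
  • Cache frequently used examples in memory to speed up retrieval.
  • Have a human annotation team review example quality regularly.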
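Dynamic retrieval can be sketched without FAISS by ranking stored examples on cosine similarity to the query embedding and keeping the top k. The embeddings here are toy 3-dimensional vectors and the example ids are hypothetical. In production the vectors would come from an embedding model, and a FAISS inner-product index (IndexFlatIP) over L2-normalized vectors gives the same ranking at scale.

```python
import math
from typing import List, Sequence, Tuple

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0

def top_k_examples(query_vec: Sequence[float],
                   example_bank: List[Tuple[str, Sequence[float]]], k: int = 2) -> List[str]:
    """Brute-force stand-in for a vector index: ids of the k most similar examples."""
    scored = sorted(example_bank, key=lambda ex: cosine(query_vec, ex[1]), reverse=True)
    return [ex_id for ex_id, _ in scored[:k]]
```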

4.2.3 Output Constraints (Strict Schema + Field Semantics)

Recommended minimal field set (JSON Schema):
{
  "type": "object",
  "properties": {
    "thought": { "type": "string", "maxLength": 500 },
    "action": {
      "oneOf": [
        {
          "type": "object",
          "properties": {
            "tool": { "enum": ["search", "calculator", "..."] },
            "input": { "type": "object" }
          },
          "required": ["tool"]
        },
        { "type": "null" }
      ]
    },
    "answer": { "type": ["string", "null"] },
    "confidence": { "type": "number", "minimum": 0, "maximum": 1 },
    "debug": { "type": ["object", "null"], "additionalProperties": true }
  },
  "required": ["thought"]
}

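🔐 Engineering reminder: every action.tool must be checked against a whitelist, and action.input must be validated a second time, against the tool's own input schema, to prevent injection attacks.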
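Below is a minimal sketch of the two checks in this reminder: a whitelist lookup on action.tool, then validation of action.input against that tool's own contract. The contract table, its field names and the type-only checks are hypothetical simplifications of a full per-tool JSON Schema.

```python
from typing import Any, Dict

# Hypothetical per-tool input contracts: field name -> required Python type.
TOOL_INPUT_CONTRACTS: Dict[str, Dict[str, type]] = {
    "search": {"q": str},
    "calculator": {"expression": str},
}

def check_action(action: Dict[str, Any]) -> Dict[str, Any]:
    """Whitelist the tool name, then validate its input against that tool's contract."""
    tool = action.get("tool")
    if tool not in TOOL_INPUT_CONTRACTS:
        raise PermissionError(f"tool not in whitelist: {tool!r}")
    contract = TOOL_INPUT_CONTRACTS[tool]
    inp = action.get("input") or {}
    if not isinstance(inp, dict):
        raise ValueError("action.input must be an object")
    unknown = set(inp) - set(contract)
    if unknown:
        # Reject unexpected fields instead of silently forwarding them to the tool.
        raise ValueError(f"unexpected input fields: {sorted(unknown)}")
    for field_name, typ in contract.items():
        if not isinstance(inp.get(field_name), typ):
            raise ValueError(f"field {field_name!r} missing or not {typ.__name__}")
    return inp
```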


4.3 Output Parser Design (Engineered, Provably Robust)

4.3.1 Layered Parsing Strategy (Recommended Order)

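| Level | Method | Description |
|---|---|---|
| 1 | Strict JSON Parse | Validate pure JSON directly (the optimal path) |
| 2 | Lenient Extractor | Extract the first JSON block and strip surrounding content |
| 3 | Regex Fallback | Only for fixed formats, with minimal assumptions |
| 4 | LLM-assisted Correction Loop | Final fallback, with retry control |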
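Levels 1 to 3 of the table can be chained as shown below. The function returns None so the caller can fall through to level 4. The fixed format used by the regex fallback (a classic `Action: tool[argument]` line) and the mapping of its argument to an input field named q are assumptions for illustration.

```python
import json
import re
from typing import Any

# Level-3 fallback assumes one fixed line format, e.g. "Action: search[react agents]".
ACTION_LINE = re.compile(r"^Action:\s*(\w+)\[(.*)\]\s*$", re.MULTILINE)

def layered_parse(text: str) -> Any:
    # Level 1: strict JSON
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Level 2: lenient extraction of the outermost {...} span
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            pass
    # Level 3: regex fallback for the single fixed format above
    m = ACTION_LINE.search(text)
    if m:
        return {"thought": "", "action": {"tool": m.group(1), "input": {"q": m.group(2)}},
                "answer": None, "confidence": 0.0}
    # Level 4 (LLM-assisted correction) is left to the caller
    return None
```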

4.3.2 LLM-Assisted Correction: Algorithm and Round Limits

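Pseudocode, made runnable (assumes the third-party jsonschema package, a build_fix_prompt(original, errors) helper such as the one in react_agent.py, and an llm object exposing generate(prompt, max_tokens)):
import json
from jsonschema import Draft7Validator  # third-party: pip install jsonschema

class ParseFailureException(Exception):
    pass

def lenient_json_extract(text):
    # Take the outermost {...} span; return None if absent or still not valid JSON
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None

def parse_and_validate(raw_text, schema, llm, max_retries=2, fix_max_tokens=100):
    validator = Draft7Validator(schema)
    candidate = raw_text
    for attempt in range(max_retries + 1):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            parsed = lenient_json_extract(candidate)
        if parsed is None:
            errors = ["output is not valid JSON"]
        else:
            # Concrete, path-qualified messages, e.g. "action/tool: 'foo' is not one of [...]"
            errors = [f"{'/'.join(map(str, e.path)) or '<root>'}: {e.message}"
                      for e in validator.iter_errors(parsed)]
            if not errors:
                return parsed
        if attempt == max_retries:
            break
        # Feed the errors of the latest attempt (not the first one) back to the model
        fix_prompt = build_fix_prompt(candidate, errors)
        candidate = llm.generate(fix_prompt, max_tokens=fix_max_tokens)
    raise ParseFailureException(raw_text)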
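Design points:
  • Error messages must be specific (e.g. field missing, invalid enum value).
  • Retry at most 1 to 2 times to keep costs from ballooning.
  • Log every attempt so that failures can later be harvested as training samples.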

4.3.3 JSON Schema and Type Safety (Practical Recommendations)

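  • Use JSON Schema Draft-7 or later.
  • Apply field coercion (e.g. string → int, float, datetime).
  • Run every external input through anti-injection checks.
  • Add schema regression tests to CI so that new changes cannot silently alter existing behavior.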
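Field coercion can be driven by a small table keyed by JSON Schema type or format names. The table and function names are hypothetical. Failed conversions raise ValueError, so the caller can report INPUT_INVALID instead of passing bad data to a tool.

```python
from datetime import datetime
from typing import Any, Callable, Dict

# Hypothetical coercion table: JSON Schema type/format name -> converter for string values.
COERCERS: Dict[str, Callable[[str], Any]] = {
    "integer": int,
    "number": float,
    "date-time": datetime.fromisoformat,
}

def coerce_fields(data: Dict[str, Any], field_types: Dict[str, str]) -> Dict[str, Any]:
    """Convert string values to their declared types; raises ValueError on bad input."""
    out = dict(data)  # never mutate the caller's dict
    for field_name, type_name in field_types.items():
        value = out.get(field_name)
        if isinstance(value, str) and type_name in COERCERS:
            out[field_name] = COERCERS[type_name](value.strip())
    return out
```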

4.4 Tool Abstraction and Interface Specification (Tool API Contract)

4.4.1 Tool Declaration Template (JSON/YAML)

id: search.v1
name: web_search
version: 1.0.2
description: Query web index and return top-k results.
inputs_schema:
  type: object
  properties:
    q: { type: string }
  required: [q]
outputs_schema:
  type: array
  items:
    type: object
    properties:
      title: { type: string }
      url: { type: string }
cost:
  usd_per_call: 0.002
  token_est: 80
idempotent: true
timeout_ms: 3000
retries:
  max: 2
  backoff: exponential
side_effects: none
auth:
  required: true
  scopes: [search:read]
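Engineering requirements:
  • The tool registry supports runtime lookup and version control.
  • Any tool change must trigger a schema compatibility check.
  • Tool metadata is injected into the system prompt so that the agent can consult it.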
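One rough way to implement the compatibility check is to compare the old and new inputs_schema and flag changes that could break existing callers: newly required fields, removed fields, and changed types. This is a heuristic sketch, not a complete JSON Schema compatibility checker, since it ignores nested schemas, enums and formats.

```python
from typing import Any, Dict, List

def breaking_changes(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
    """Flag input-schema changes that could break existing callers (heuristic)."""
    problems: List[str] = []
    old_props, new_props = old.get("properties", {}), new.get("properties", {})
    for name in sorted(set(new.get("required", [])) - set(old.get("required", []))):
        problems.append(f"new required field: {name}")
    for name in sorted(set(old_props) - set(new_props)):
        problems.append(f"removed field: {name}")
    for name in sorted(set(old_props) & set(new_props)):
        if old_props[name].get("type") != new_props[name].get("type"):
            problems.append(f"type changed: {name}")
    return problems
```

Under semantic versioning, a non-empty result would call for a major-version bump (for example 1.0.2 to 2.0.0) rather than a patch release.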

4.4.2 Idempotency, Transaction Semantics, and Failure Model

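Failure classification:
| Class | Characteristics | Handling |
|---|---|---|
| Transient | Recoverable faults (timeouts, network jitter) | Retry with exponential backoff |
| Permanent | Unrecoverable errors (invalid parameters) | Abort the plan and report the error |
| Partial | Partial success | Mark as partial and merge the results |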
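The Transient row, combined with `retries: {max: 2, backoff: exponential}` from the tool declaration, might look like the sketch below. The exception class names mirror those used in the 4.5.1 loop. The base and cap delays and the use of full jitter (a random delay up to the exponential bound) are assumptions.

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")

class TransientError(Exception):
    """Recoverable fault: timeout, network jitter."""

class PermanentError(Exception):
    """Unrecoverable fault: invalid parameters."""

def call_with_backoff(fn: Callable[[], T], max_retries: int = 2,
                      base_delay_s: float = 0.1, max_delay_s: float = 2.0) -> T:
    """Retry only TransientError, with capped exponential backoff plus full jitter."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except TransientError:
            if attempt == max_retries:
                raise  # retries exhausted: surface the last transient error
            delay = min(max_delay_s, base_delay_s * (2 ** attempt))
            time.sleep(random.uniform(0, delay))
        # PermanentError (or any other exception) is not caught, so it propagates at once.
    raise AssertionError("unreachable")
```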
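Transaction semantics:
  • For operations with side effects (e.g. database writes, payments), use two-phase commit or the Saga pattern.
  • Tool interfaces must state explicitly whether they are idempotent.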
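Here is a minimal in-process Saga sketch for side-effecting tools: the steps run in order, and if one fails, the compensations for the steps that already completed run in reverse order. The step names are hypothetical. Compensation failures are not handled here, and a production saga would persist its log so that recovery survives a crash.

```python
from typing import Callable, List, Tuple

Step = Tuple[str, Callable[[], None], Callable[[], None]]  # (name, action, compensation)

def run_saga(steps: List[Step]) -> List[str]:
    """Run steps in order; on failure, compensate completed steps in reverse and stop."""
    completed: List[Step] = []
    log: List[str] = []
    for step in steps:
        name, action, _ = step
        try:
            action()
            completed.append(step)
            log.append(f"done:{name}")
        except Exception:
            log.append(f"failed:{name}")
            for done_name, _, compensate in reversed(completed):
                compensate()
                log.append(f"compensated:{done_name}")
            break
    return log
```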

4.5 Pseudocode: Engineering-Grade ReAct Main Loop (Single-Threaded and Concurrent)

4.5.1 Single-Threaded Main Loop (Basic Version)

def react_loop(state, llm, tools, schema, budget, max_steps=20):
    for step in range(max_steps):
        prompt = build_prompt(state)
        raw_output = llm.generate(prompt, max_tokens=2048)
        try:
            parsed = parse_and_validate(raw_output, schema, llm)  # returns a dict
        except ParseFailureException:
            return failure("PARSE_FAILED")

        # Token cost estimation; check the budget here so the `continue` branches below cannot bypass it
        state.costs.token += estimate_token_count(raw_output)
        if budget.exceeded(state.costs):
            return failure("BUDGET_EXCEEDED")

        thought = parsed["thought"]
        action = parsed.get("action")

        if parsed.get("answer"):
            if verify_answer(parsed["answer"], state):
                return success(parsed["answer"], meta=state)
            # Feed the failed verification back as an observation so the model can revise
            state = update_state(state, thought=thought, action=None, observation={"error": "verification_failed"})
            continue

        if not action:
            if should_ask_clarification(state):
                state = update_state(state, thought=thought, action=None, observation=None)
                continue
            return failure("NO_ACTION_NO_ANSWER")

        # Whitelist lookup: an unknown tool becomes an observation, not an exception
        tool = tools.get(action["tool"])
        if not tool:
            state = update_state(state, thought=thought, action=action, observation={"error": "UNKNOWN_TOOL"})
            continue

        # Second validation layer: the tool's own input schema
        try:
            normalized_input = coerce_and_validate(action.get("input", {}), tool.schema.input)
        except ValidationError as ve:
            state = update_state(state, thought=thought, action=action, observation={"error": "INPUT_INVALID", "details": str(ve)})
            continue

        try:
            obs = tool.call(normalized_input, timeout=tool.timeout_ms)
        except TransientError as te:
            obs = {"status": "TRANSIENT_ERROR", "details": str(te)}
        except PermanentError as pe:
            obs = {"status": "PERMANENT_ERROR", "details": str(pe)}

        state = update_state(state, thought=thought, action=action, observation=obs)
        state.costs.monetary += obs.get("monetary_cost", 0)
        state.costs.latency += obs.get("latency_ms", 0)

        if budget.exceeded(state.costs):
            return failure("BUDGET_EXCEEDED")

        if detect_no_progress(state):
            return failure("NO_PROGRESS_DETECTED")

    return failure("MAX_STEPS_REACHED")
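The loop leaves several helpers undefined. Below is a minimal sketch of two of them, update_state and detect_no_progress, together with state and cost containers that match how the loop uses state.costs.token, .monetary and .latency. The split between a visible history and internal costs echoes the conclusion's "visible state vs internal state". That split, the field names, and the "same action three times in a row" heuristic are all assumptions.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class Costs:
    token: int = 0
    monetary: float = 0.0
    latency: int = 0

@dataclass
class AgentState:
    goal: str
    history: List[Dict[str, Any]] = field(default_factory=list)  # visible: rendered into prompts
    costs: Costs = field(default_factory=Costs)                   # internal: bookkeeping only

def update_state(state: AgentState, thought: str, action: Optional[Dict[str, Any]],
                 observation: Any) -> AgentState:
    """Append one T-A-O record; mutates and returns the same state object."""
    state.history.append({"thought": thought, "action": action, "observation": observation})
    return state

def detect_no_progress(state: AgentState, window: int = 3) -> bool:
    """Heuristic: the last `window` steps all issued the same non-null action."""
    recent = [rec["action"] for rec in state.history[-window:]]
    return (len(recent) == window and recent[0] is not None
            and all(a == recent[0] for a in recent))
```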

4.5.2 Concurrent / Batch Call Extension (Design and Pseudocode)

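When concurrency applies:
  • the actions have no dependencies on one another;
  • tool latency is high;
  • overall latency has become the bottleneck.
Concurrent pseudocode: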
import asyncio

async def concurrent_react_loop(state, llm, tools, schema, budget, max_steps=20):
    for step in range(max_steps):
        prompt = build_prompt(state)
        raw = await llm.generate_async(prompt)
        parsed = parse_and_validate(raw, schema, llm)  # dict; ParseFailureException propagates

        if parsed.get("answer"):
            return success(parsed["answer"], meta=state)

        # Accept either a list under "actions" or the single "action" field
        actions = parsed.get("actions") or ([parsed["action"]] if parsed.get("action") else [])
        if not actions:
            return failure("NO_ACTION_NO_ANSWER")
        grouped_actions = partition_by_independence(actions, tools)

        for group in grouped_actions:
            semaphore = asyncio.Semaphore(group.max_concurrency)

            async def exec_one(action, semaphore=semaphore):
                async with semaphore:
                    tool = tools[action["tool"]]  # an unknown tool fails this task only
                    inp = coerce_and_validate(action.get("input", {}), tool.schema.input)
                    return await tool.call_async(inp, timeout=tool.timeout_ms)

            tasks = [asyncio.create_task(exec_one(a)) for a in group.actions]
            # asyncio.wait returns (done, pending); its timeout is in seconds, not milliseconds
            done, pending = await asyncio.wait(tasks, timeout=group.group_timeout_ms / 1000)
            for t in pending:
                t.cancel()  # timed-out calls are dropped, i.e. the "Partial" failure class
            await asyncio.gather(*pending, return_exceptions=True)  # let cancellations finish

            observations = [t.result() for t in done if t.exception() is None]
            merged_obs = group.merge_fn(observations)
            state = update_state(state, thought=parsed["thought"], actions=group.actions, observation=merged_obs)
            state.costs.monetary += sum(obs.get("estimated_cost", 0) for obs in observations)

        if budget.exceeded(state.costs):
            return failure("BUDGET_EXCEEDED")

    return failure("MAX_STEPS_REACHED")
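The concurrent loop calls partition_by_independence and expects each group to carry max_concurrency, group_timeout_ms and merge_fn. One hypothetical policy, following the checklist's advice to parallelize only idempotent tools, is sketched below. It takes a plain tool-name-to-idempotent map instead of the registry, and it assumes that actions emitted in the same step never consume each other's outputs.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class ActionGroup:
    actions: List[Dict[str, Any]]
    max_concurrency: int = 1
    group_timeout_ms: int = 5000

    def merge_fn(self, observations: List[Any]) -> Dict[str, Any]:
        # Simplest merge strategy: keep every successful observation, in order
        return {"results": observations}

def partition_by_independence(actions: List[Dict[str, Any]], idempotent: Dict[str, bool],
                              max_concurrency: int = 4) -> List[ActionGroup]:
    """Batch consecutive idempotent calls; each side-effecting call runs alone as a barrier."""
    groups: List[ActionGroup] = []
    batch: List[Dict[str, Any]] = []
    for action in actions:
        if idempotent.get(action["tool"], False):  # unknown tools are treated as unsafe
            batch.append(action)
            continue
        if batch:
            groups.append(ActionGroup(batch, max_concurrency))
            batch = []
        groups.append(ActionGroup([action], max_concurrency=1))
    if batch:
        groups.append(ActionGroup(batch, max_concurrency))
    return groups
```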

Design Trade-offs and Engineering Recommendations (Expert Checklist)

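| Component | Key trade-off | Recommendation |
|---|---|---|
| Schema-first | Whether to trade flexibility for stability | Strongly recommended: every interface is schema-validated |
| Least-trust principle | How to balance LLM output freedom against safety | Every output must pass schema validation before any action is executed |
| Parse-correction strategy | How to reduce the cost of LLM-assisted correction | Allow at most 1 to 2 corrections; log corrections for training |
| Concurrency control | How to balance performance gains against complexity | Enable concurrency only for idempotent tools; bound it with a group timeout |
| Permissions and sensitive operations | How to secure sensitive operations | Add a manual approval flag or a human-in-the-loop step |
| Logging and monitoring | How to trace failure causes and direct improvements | Keep complete raw/parsed/observation pipeline logs |
| Regression testing | How to prevent upgrades from degrading behavior | Add few-shot/schema regression test sets to CI |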

Conclusion: Key Dimensions for Building an Industrial-Grade ReAct Agent

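| Dimension | Technical points |
|---|---|
| State modeling | Clearly separate visible state from internal state |
| Policy control | Combine LLM + Parser + Schema into a reliable policy |
| Cost control | Cost vector + weighted scalarization + budget-driven prompting |
| Tool governance | Tool registry + schema-first design + idempotency guarantees |
| Parsing robustness | Multi-level parser + LLM-assisted correction loop |
| Concurrency scheduling | Group-wise concurrency + merge strategy + timeout control |
| Security and compliance | Permission checks + interception of sensitive operations + human-in-the-loop support |
| Auditability | Full-chain logs covering raw → parsed → executed → observed |
| Iterative improvement | Harvest parse-failure samples → fine-tune the LLM → update few-shot examples |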

Source Code in Practice

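  • File 1: react_agent.py, the main module (no external dependencies, standard library only, with a pluggable interface for a real LLM)

  • File 2: test_react_agent.py, the unit tests (runnable with either unittest or pytest)

  • How to run: put both files in the same directory, then run python -m pytest test_react_agent.py, or python test_react_agent.py if you use unittest.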

react_agent.py
"""
react_agent.py

Engineering-grade Python module draft for a ReAct-style agent:
- Parser with LLM-assisted correction loop
- Tool registry with metadata and sync/async call helpers
- CostCounter for tracking token/monetary/latency costs
- A minimal single-threaded react_loop demonstrating usage

This module is intentionally dependency-free (standard library only).
Replace DummyLLM with a real LLM wrapper for production use.
"""

from typing import Any, Dict, Optional, Callable, List
import json
import time
import asyncio
import re

# ---------------------------
# Utilities: JSON extraction
# ---------------------------

def extract_first_json_object(text: str) -> Optional[str]:
    """
    Extract the first balanced JSON object from text by scanning braces.
    Returns the JSON substring or None if not found.
    """
    start = text.find('{')
    if start == -1:
        return None
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return text[start:i+1]
    return None

def strict_json_parse(text: str) -> Any:
    """
    Strictly parse JSON -- raises json.JSONDecodeError on failure.
    """
    return json.loads(text)

def lenient_extract_and_parse(text: str) -> Any:
    """
    Attempt to extract a JSON object and parse it. Raises json.JSONDecodeError
    if parsing the extracted JSON fails, or ValueError if none found.
    """
    js = extract_first_json_object(text)
    if js is None:
        raise ValueError("No JSON object found in text")
    return json.loads(js)

# ---------------------------
# Schema validation (simple)
# ---------------------------

def validate_agent_output_schema(obj: Dict[str, Any]) -> List[str]:
    """
    Very small validator for the expected output schema:
    {
      "thought": str,
      "action": null | {"tool": str, "input": dict},
      "answer": null | str,
      "confidence": float in [0,1]
    }
    Returns a list of validation error strings (empty if valid).
    """
    errs = []
    if not isinstance(obj, dict):
        errs.append("output must be a JSON object")
        return errs
    # thought
    if "thought" not in obj:
        errs.append("missing 'thought'")
    elif not isinstance(obj["thought"], str):
        errs.append("'thought' must be a string")
    # action
    if "action" not in obj:
        errs.append("missing 'action'")
    else:
        a = obj["action"]
        if a is not None:
            if not isinstance(a, dict):
                errs.append("'action' must be null or an object")
            else:
                if "tool" not in a:
                    errs.append("action missing 'tool'")
                elif not isinstance(a["tool"], str):
                    errs.append("action.tool must be a string")
                if "input" not in a:
                    errs.append("action missing 'input'")
                elif not isinstance(a["input"], dict):
                    errs.append("action.input must be an object")
    # answer
    if "answer" not in obj:
        errs.append("missing 'answer'")
    else:
        if obj["answer"] is not None and not isinstance(obj["answer"], str):
            errs.append("'answer' must be null or a string")
    # confidence
    if "confidence" not in obj:
        errs.append("missing 'confidence'")
    else:
        conf = obj["confidence"]
        if isinstance(conf, bool) or not isinstance(conf, (int, float)):  # bool is an int subclass; reject True/False
            errs.append("'confidence' must be a number")
        else:
            if not (0.0 <= float(conf) <= 1.0):
                errs.append("'confidence' must be between 0 and 1")
    return errs

# ---------------------------
# LLM interface (pluggable)
# ---------------------------

class LLMInterface:
    """
    Minimal LLM interface wrapper.
    Implement generate(prompt, max_tokens) for synchronous LLM,
    and generate_async(...) for async usage.
    """
    def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        raise NotImplementedError()

    async def generate_async(self, prompt: str, max_tokens: int = 1024) -> str:
        # default implementation wraps sync generate
        return self.generate(prompt, max_tokens=max_tokens)

class DummyLLM(LLMInterface):
    """
    Deterministic dummy LLM for unit tests and local debugging.

    Behavior:
      - If constructed with a mapping `responses` (ordered dict like),
        when a key is substring of the prompt, the first unused matching key
        will be returned and marked used. This allows sequential different
        responses for similar prompts across multiple calls.
      - If prompt contains the special marker <<FIX-RETURN>>...<<END-FIX>>, returns the enclosed text.
      - Otherwise returns a minimal valid no-op JSON object.
    """
    def __init__(self, responses: Optional[Dict[str, str]] = None):
        self.responses = responses or {}
        self._used_keys: List[str] = []

    def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        # exact mapping or substring mapping with first unused match
        for k, v in self.responses.items():
            if k in prompt and k not in self._used_keys:
                self._used_keys.append(k)
                return v
        # if all matches used, allow reuse: return first matching key
        for k, v in self.responses.items():
            if k in prompt:
                return v
        # special marker support
        m = re.search(r'<<FIX-RETURN>>(.*?)<<END-FIX>>', prompt, flags=re.S)
        if m:
            return m.group(1).strip()
        # fallback: return a polite empty valid object (no action)
        return json.dumps({"thought":"no-op","action":None,"answer":None,"confidence":0.0})

# ---------------------------
# Parser with LLM-assisted correction loop
# ---------------------------

def build_fix_prompt(original: str, validation_errors: List[str], schema_note: str = "") -> str:
    """
    Build a short prompt to ask the LLM to only output corrected JSON
    that satisfies the agent output schema.
    """
    errs = "\n".join(validation_errors)
    prompt = (
        "The system attempted to parse the following LLM output but failed schema validation.\n"
        "Original output:\n"
        "----BEGIN ORIGINAL----\n"
        f"{original}\n"
        "----END ORIGINAL----\n"
        "Validation errors:\n"
        f"{errs}\n"
        "Please output **only** a single JSON object that matches the schema: "
        '{"thought": "string", "action": null | {"tool":"string","input":{}}, "answer": null | "string", "confidence": 0.0} '
        "Do not include any prose. If you intend no action, set action to null.\n"
    )
    if schema_note:
        prompt += "\n" + schema_note
    return prompt

def try_parse_json(raw: str) -> Any:
    """
    Try strict parse first; if fails, try lenient extraction. Raises json.JSONDecodeError or ValueError.
    """
    try:
        return strict_json_parse(raw)
    except Exception:
        # lenient attempt
        return lenient_extract_and_parse(raw)

def parse_and_validate(raw_text: str, llm: Optional[LLMInterface] = None, max_retries: int = 2) -> Dict[str, Any]:
    """
    Parse the raw LLM output into JSON and validate against the agent output schema.
    If validation fails and an LLM is provided, ask the LLM to correct the output.
    """
    last_raw = raw_text
    attempt = 0
    while True:
        attempt += 1
        parsed = None
        try:
            parsed = try_parse_json(last_raw)
            errs = validate_agent_output_schema(parsed)
            if not errs:
                return parsed  # success
            else:
                parse_err = "schema errors: " + "; ".join(errs)
        except Exception as e:
            parsed = None
            parse_err = f"parse error: {repr(e)}"
        # at this point parsed is None or has errors
        if llm is None or attempt > max_retries:
            # give up
            raise ValueError(f"Unparsable or invalid output after {attempt} attempts. Last error: {parse_err}. Raw: {last_raw[:1000]}")
        # build fix prompt
        fix_prompt = build_fix_prompt(last_raw, errs if parsed is not None else [parse_err])
        fixed = llm.generate(fix_prompt, max_tokens=512)
        # guard against identical response
        if fixed.strip() == last_raw.strip():
            raise ValueError("LLM returned identical output when asked to fix; aborting")
        last_raw = fixed
        # loop to attempt parse again

# ---------------------------
# Tool registry and Tool abstraction
# ---------------------------

class Tool:
    """
    Tool metadata and a callable interface. In production, call_fn should
    make network calls or perform the actual tool behavior.
    """
    def __init__(self, 
                 id: str,
                 name: str,
                 inputs_schema: Optional[Dict] = None,
                 outputs_schema: Optional[Dict] = None,
                 cost_estimate: Optional[Dict] = None,
                 idempotent: bool = True,
                 timeout_ms: int = 3000,
                 call_fn: Optional[Callable] = None,
                 call_async_fn: Optional[Callable] = None):
        self.id = id
        self.name = name
        self.inputs_schema = inputs_schema or {}
        self.outputs_schema = outputs_schema or {}
        self.cost_estimate = cost_estimate or {}
        self.idempotent = idempotent
        self.timeout_ms = timeout_ms
        self._call_fn = call_fn
        self._call_async_fn = call_async_fn

    def call(self, inp: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synchronous call wrapper. Raises exceptions on error.
        """
        if self._call_fn is None:
            raise RuntimeError("No call_fn configured for tool " + self.name)
        start = time.time()
        out = self._call_fn(inp)
        out_meta = dict(out)
        out_meta.setdefault("_meta", {})
        out_meta["_meta"]["latency_ms"] = int((time.time() - start) * 1000)
        return out_meta

    async def call_async(self, inp: Dict[str, Any]) -> Dict[str, Any]:
        if self._call_async_fn is not None:
            start = time.time()
            out = await self._call_async_fn(inp)
            out_meta = dict(out)
            out_meta.setdefault("_meta", {})
            out_meta["_meta"]["latency_ms"] = int((time.time() - start) * 1000)
            return out_meta
        # wrap sync in threadpool to not block event loop
        loop = asyncio.get_running_loop()  # preferred inside a coroutine (Python 3.7+)
        return await loop.run_in_executor(None, self.call, inp)

class ToolRegistry:
    def __init__(self):
        self._tools = {}

    def register(self, tool: Tool):
        if tool.name in self._tools:
            raise KeyError("Tool already registered: " + tool.name)
        self._tools[tool.name] = tool

    def get(self, name: str) -> Optional[Tool]:
        return self._tools.get(name)

# ---------------------------
# Cost counter
# ---------------------------

class CostCounter:
    def __init__(self):
        self.tokens = 0
        self.monetary = 0.0
        self.latency_ms = 0

    def add_tokens(self, n: int):
        self.tokens += int(n)

    def add_monetary(self, usd: float):
        self.monetary += float(usd)

    def add_latency(self, ms: int):
        self.latency_ms += int(ms)

    def add_from_tool_obs(self, obs: Dict[str, Any]):
        meta = obs.get("_meta", {})
        self.add_latency(meta.get("latency_ms", 0))
        # if tool provides monetary_cost in obs, add it
        if "monetary_cost" in obs:
            try:
                self.add_monetary(float(obs["monetary_cost"]))
            except Exception:
                pass

    def exceeded(self, budget: Dict[str, Any]) -> bool:
        """
        budget example: {"tokens":10000, "monetary": 1.0}
        """
        if not budget:
            return False
        if "tokens" in budget and self.tokens > budget["tokens"]:
            return True
        if "monetary" in budget and self.monetary > budget["monetary"]:
            return True
        return False

# ---------------------------
# Minimal single-threaded react loop
# ---------------------------

def build_prompt_from_state(state: Dict[str, Any]) -> str:
    """
    Simple prompt builder. For production, replace with system prompt + few-shot + state summary.
    """
    # Allow tests to inject explicit prompt override
    if "prompt_override" in state:
        return state["prompt_override"]
    return "State summary: " + json.dumps(state.get("summary", {})) + "\nProduce JSON output."

def react_single_thread(initial_state: Dict[str, Any],
                        llm: LLMInterface,
                        tool_registry: ToolRegistry,
                        budget: Optional[Dict[str, Any]] = None,
                        max_steps: int = 10) -> Dict[str, Any]:
    S = dict(initial_state)  # shallow copy
    costs = CostCounter()
    for step in range(max_steps):
        prompt = build_prompt_from_state(S)
        raw = llm.generate(prompt, max_tokens=1024)
        costs.add_tokens(len(raw.split()))
        try:
            parsed = parse_and_validate(raw, llm=llm, max_retries=1)
        except Exception as e:
            return {"status":"error", "reason":"parse_failure", "detail": str(e), "state": S}
        # record thought
        S.setdefault("history", []).append({"thought": parsed.get("thought")})
        # if answer -> verify (here we accept as final)
        if parsed.get("answer") is not None:
            return {"status":"success", "answer": parsed["answer"], "steps": step+1, "costs": {"tokens": costs.tokens, "monetary": costs.monetary, "latency_ms": costs.latency_ms}}
        act = parsed.get("action")
        if act is None:
            # nothing to do - terminate as no-op
            return {"status":"no_action", "steps": step+1, "state": S}
        # validate tool exists
        tool = tool_registry.get(act.get("tool"))
        if tool is None:
            S["history"].append({"observation": {"error":"unknown_tool", "tool": act.get("tool")}})
            continue
        # call tool (sync)
        try:
            obs = tool.call(act.get("input", {}))
        except Exception as e:
            obs = {"error": "tool_exception", "detail": str(e)}
        S["history"].append({"action": act, "observation": obs})
        costs.add_from_tool_obs(obs)
        # check budget
        if costs.exceeded(budget or {}):
            return {"status":"budget_exceeded", "costs": {"tokens": costs.tokens, "monetary": costs.monetary}}
    return {"status":"max_steps_reached", "steps": max_steps, "costs": {"tokens": costs.tokens, "monetary": costs.monetary}}
test_react_agent.py
"""
Unit tests for react_agent.py
"""

import unittest
import json
import asyncio
import time
from react_agent import (
    extract_first_json_object, strict_json_parse, lenient_extract_and_parse,
    validate_agent_output_schema, DummyLLM, parse_and_validate,
    Tool, ToolRegistry, CostCounter, react_single_thread
)

class TestJSONExtraction(unittest.TestCase):
    def test_extract_balanced_json(self):
        text = "prefix {\"a\":1, \"b\":{\"c\":2}} suffix"
        js = extract_first_json_object(text)
        self.assertIsNotNone(js)
        self.assertEqual(json.loads(js), {"a":1,"b":{"c":2}})
    def test_no_json(self):
        text = "no json here"
        js = extract_first_json_object(text)
        self.assertIsNone(js)

class TestParsingAndValidation(unittest.TestCase):
    def test_strict_parse_valid(self):
        raw = json.dumps({"thought":"t","action":None,"answer":None,"confidence":0.5})
        parsed = parse_and_validate(raw, llm=DummyLLM(), max_retries=0)
        self.assertEqual(parsed["thought"], "t")
    def test_lenient_extract_and_parse(self):
        raw = "Some explanation... {\"thought\":\"lenient\",\"action\":null,\"answer\":null,\"confidence\":0.1} end"
        parsed = parse_and_validate(raw, llm=DummyLLM(), max_retries=0)
        self.assertEqual(parsed["thought"], "lenient")
    def test_llm_assist_fix(self):
        # raw is invalid; DummyLLM will return corrected JSON if requests contains special mapping
        raw = "I will not output JSON"
        fixed = json.dumps({"thought":"fixed","action":None,"answer":"done","confidence":0.9})
        llm = DummyLLM(responses={"Please output": fixed})
        # We craft parse_and_validate to call llm when max_retries>0 -- here single retry
        parsed = parse_and_validate(raw, llm=llm, max_retries=1)
        self.assertEqual(parsed["answer"], "done")

class TestToolRegistryAndCost(unittest.TestCase):
    def test_tool_call_and_cost_accumulation(self):
        def search_fn(inp):
            q = inp.get("q","")
            return {"results":[{"title":"r","url":"u","snippet":"s"}], "monetary_cost": 0.005}
        search_tool = Tool(id="search.v1", name="search", call_fn=search_fn)
        reg = ToolRegistry()
        reg.register(search_tool)
        tool = reg.get("search")
        obs = tool.call({"q":"test"})
        self.assertIn("results", obs)
        cost = CostCounter()
        cost.add_from_tool_obs(obs)
        self.assertAlmostEqual(cost.monetary, 0.005, places=6)

class TestReactLoop(unittest.TestCase):
    def test_react_single_thread_answer(self):
        # LLM will return a direct valid JSON with an answer
        fixed = json.dumps({"thought":"final","action":None,"answer":"42","confidence":1.0})
        llm = DummyLLM(responses={"Produce JSON output.": fixed})
        reg = ToolRegistry()
        state = {"summary":{"goal":"compute 6*7"}}
        res = react_single_thread(state, llm, reg, budget={"tokens":1000, "monetary":1.0}, max_steps=3)
        self.assertEqual(res["status"], "success")
        self.assertEqual(res["answer"], "42")
    def test_react_single_thread_tool_call(self):
        # LLM returns an action calling 'search', then second LLM call returns answer
        first = json.dumps({"thought":"do search","action":{"tool":"search","input":{"q":"x"}},"answer":None,"confidence":0.6})
        second = json.dumps({"thought":"got it","action":None,"answer":"found","confidence":0.9})
        llm = DummyLLM(responses={"Produce JSON output.": first, "State summary": second})
        def search_fn(inp):
            return {"results":[{"title":"r"}], "monetary_cost": 0.002}
        reg = ToolRegistry()
        reg.register(Tool(id="search.v1", name="search", call_fn=search_fn))
        state = {"summary":{"goal":"find x"}}
        res = react_single_thread(state, llm, reg, budget={"tokens":1000, "monetary":1.0}, max_steps=5)
        self.assertEqual(res["status"], "success")
        self.assertEqual(res["answer"], "found")

if __name__ == '__main__':
    unittest.main(verbosity=2)
