Goal: with mathematical rigor and engineering practicality as the guiding principles, rebuild the core components of a ReAct system (the T-A-O loop, prompt schema, output parser, tool contracts, and concurrency strategy) for production deployments that demand high reliability, low cost, and full auditability.
4.1 Formal Definition of the T-A-O Loop
4.1.0 Overview (Expert Statement)

4.1.1 Formal Model: ReAct from a POMDP Perspective
Definition 1: POMDP Notation

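A minimal sketch in standard POMDP notation, together with one illustrative (assumed) mapping of the T-A-O loop onto it:

$$\mathcal{M} = (\mathcal{S}, \mathcal{A}, T, R, \Omega, O, \gamma)$$

- $\mathcal{S}$: latent environment state (index contents, database state, user intent); never observed directly.
- $\mathcal{A}$: action space, i.e. the set of registered tool invocations plus the terminal "answer" action.
- $T(s' \mid s, a)$: transition kernel, covering tool side effects.
- $\Omega$, $O(o \mid s', a)$: observation space and observation model, i.e. the tool outputs returned to the agent.
- $R$, $\gamma$: reward and discount; in practice replaced by the cost/budget machinery of Definition 3.
- The LLM acts as a non-stationary policy $\pi_\theta(\tau_t, a_t \mid h_t)$ emitting a thought $\tau_t$ and an action $a_t$ from the visible history $h_t = (\tau_{1:t-1}, a_{1:t-1}, o_{1:t-1})$.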
Design trade-off notes:
- Although the loop can in principle be modeled as a full POMDP, in practice the LLM's output is too unreliable for strict Bayesian belief updates, so we treat it as a simulator driven by a non-stationary policy rather than a rigorous Bayesian inference system.
- The emphasis therefore shifts to designing the state manager (UpdateState) and the parser (Parser) so that the policy's execution path stays stable.
4.1.2 Episode Termination Conditions and Cost Metrics (Precise Definitions)
Definition 2: Termination Condition Predicate Set

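A minimal sketch of one possible predicate set, aligned with the outcome codes used by the main loop later in this section (the exact membership is an assumption):

$$\text{Stop}(h_t, c_t) \;=\; \text{Answered}(h_t) \;\lor\; \text{BudgetExceeded}(c_t) \;\lor\; \text{NoProgress}(h_t) \;\lor\; t \ge t_{\max}$$

The episode ends as soon as any predicate holds; each one corresponds to a loop outcome below (success, BUDGET_EXCEEDED, NO_PROGRESS_DETECTED, MAX_STEPS_REACHED).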
Definition 3: Cost Vector and Scalarization Function

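A minimal sketch of the cost vector and a weighted scalarization, matching the cost dimensions tracked by CostCounter in the reference code below (the weights are deployment-specific assumptions):

$$c_t = \big(c^{\text{token}}_t,\; c^{\text{usd}}_t,\; c^{\text{latency}}_t\big), \qquad J(c_t) = w^{\top} c_t$$

A budget $B$ is a componentwise bound, and $\text{BudgetExceeded}(c_t) \iff \exists i:\; c_{t,i} > B_i$; this is exactly the check that CostCounter.exceeded performs.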
⚠️ Engineering tip: inject the remaining budget into the prompt (e.g. {budget_left: {token: 2000, usd: 0.1}}) and instruct the LLM to factor this constraint into its decisions. A small sketch follows.
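A small sketch of such budget-aware prompt injection; build_prompt and the budget structure are illustrative assumptions:

import json

def inject_budget(prompt: str, budget_left: dict) -> str:
    """Append the remaining budget so the model can reason about it explicitly.

    budget_left is assumed to look like {"token": 2000, "usd": 0.1}.
    """
    note = (
        "\nBudget remaining: " + json.dumps(budget_left) +
        "\nPrefer cached or local tools when the budget is low; "
        "answer directly if another tool call would exceed it."
    )
    return prompt + note

# usage (hypothetical): prompt = inject_budget(build_prompt(state), {"token": 2000, "usd": 0.1})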
4.2 Prompt Schema and Instruction Engineering (Prompt Engineering)
4.2.1 Design Elements of the Instruction Segment (System Prompt), Treated as a Contract
The system prompt should serve as a "policy interface contract" and include the following:
| Category | Example content |
|---|---|
| Output format requirements | "Output only valid JSON conforming to the specified schema." |
| Tool list declaration | "available_tools": [{"name": "search", "signature": "q:string"}] |
| Safety and compliance constraints | "Do not access private fields"; "request user confirmation when sensitive terms appear" |
| Utility and budget guidance | "2000 tokens remain; prefer the locally cached tool." |
| Self-check and correction instructions | "If parsing fails, return corrected JSON within 100 tokens." |
✅ Engineering practice: store the system prompt under version control (e.g. system_v1.2.txt) and run regression tests against it on every deployment. A minimal builder sketch follows.
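A minimal sketch of assembling a versioned system prompt from the contract elements in the table above; the prompts/ directory layout and helper signature are assumptions, not part of the original text:

import json
from pathlib import Path

def build_system_prompt(version: str, tools: list, budget_left: dict,
                        prompt_dir: Path = Path("prompts")) -> str:
    """Load a versioned base prompt (e.g. prompts/system_v1.2.txt) and append the contract."""
    base = (prompt_dir / f"system_{version}.txt").read_text(encoding="utf-8")
    contract = {
        "output_format": "Output only valid JSON matching the agent output schema.",
        "available_tools": tools,            # e.g. [{"name": "search", "signature": "q:string"}]
        "constraints": ["Do not access private fields."],
        "budget_left": budget_left,          # e.g. {"token": 2000, "usd": 0.1}
        "on_parse_failure": "Return corrected JSON within 100 tokens.",
    }
    return base + "\n\n" + json.dumps(contract, ensure_ascii=False, indent=2)

# usage (hypothetical): build_system_prompt("v1.2", tools=tool_metadata_list, budget_left={"token": 2000, "usd": 0.1})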
4.2.2 Few-Shot Examples: Selection Principles and Dynamic Retrieval
Principles:
- Coverage first: include the normal flow, edge cases, and error-handling paths.
- Minimal context overhead: avoid long examples that crowd out the context window.
- Teach error correction: provide "original mistake → corrected sample" pairs so the model learns to self-repair.
- Version sync: keep examples in lockstep with the current schema / output contract.
Recommended practice:
- Use an embedding retrieval engine (e.g. FAISS) to dynamically match the most relevant few-shot examples (a dependency-free sketch follows this list).
- Cache high-frequency examples in memory to speed up retrieval.
- Have a human annotation team review example quality periodically.
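A dependency-free sketch of dynamic few-shot selection by embedding similarity; in production a FAISS index would replace the linear scan, and embed() is a placeholder for a real embedding model:

import math
from typing import Callable, List

def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a)) or 1e-9
    nb = math.sqrt(sum(y * y for y in b)) or 1e-9
    return dot / (na * nb)

def select_few_shot(query: str, examples: List[str],
                    embed: Callable[[str], List[float]], k: int = 3) -> List[str]:
    """Return the k examples most similar to the query (linear scan; swap in FAISS at scale)."""
    qv = embed(query)
    return sorted(examples, key=lambda ex: cosine(qv, embed(ex)), reverse=True)[:k]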
4.2.3 Output Constraints (Strict Schema + Field Semantics)
Recommended minimal field set (JSON Schema):
{
  "type": "object",
  "properties": {
    "thought": { "type": "string", "maxLength": 500 },
    "action": {
      "oneOf": [
        {
          "type": "object",
          "properties": {
            "tool": { "enum": ["search", "calculator", "..."] },
            "input": { "type": "object" }
          },
          "required": ["tool"]
        },
        { "type": "null" }
      ]
    },
    "answer": { "type": ["string", "null"] },
    "confidence": { "type": "number", "minimum": 0, "maximum": 1 },
    "debug": { "type": ["object", "null"], "additionalProperties": true }
  },
  "required": ["thought"]
}
🔐 Engineering reminder: every action.tool must be checked against a whitelist, and action.input must be schema-validated a second time to prevent injection attacks. A small validation sketch follows.
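A small sketch of the whitelist and re-validation step before execution; the whitelist contents and the required-field check are simplified assumptions (a real deployment would run a full JSON Schema validation here):

ALLOWED_TOOLS = {"search", "calculator"}   # assumed whitelist

def check_action(action: dict, input_schemas: dict) -> dict:
    """Reject non-whitelisted tools and re-check required input fields before execution."""
    tool = action.get("tool")
    if tool not in ALLOWED_TOOLS:
        raise ValueError(f"tool '{tool}' is not whitelisted")
    required = input_schemas.get(tool, {}).get("required", [])
    missing = [f for f in required if f not in action.get("input", {})]
    if missing:
        raise ValueError(f"invalid input for '{tool}': missing fields {missing}")
    return action

# check_action({"tool": "search", "input": {"q": "react"}}, {"search": {"required": ["q"]}})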
4.3 Output Parser Design (Engineering-Grade, Demonstrably Robust)
4.3.1 Layered Parsing Strategy (Recommended Order)
| Layer | Method | Description |
|---|---|---|
| 1 | Strict JSON Parse | Validate raw output as pure JSON (happy path) |
| 2 | Lenient Extractor | Extract the first JSON block, discarding surrounding prose |
| 3 | Regex Fallback | Fixed formats only, minimal assumptions (see the sketch below) |
| 4 | LLM-assisted Correction Loop | Final fallback, with retry control |
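A minimal sketch of the layer-3 regex fallback; it only applies to a rigid, pre-agreed output format (the "Action: tool({...})" convention shown here is an assumption):

import json
import re

# Matches fixed-format lines such as:  Action: search({"q": "react agent"})
ACTION_RE = re.compile(r'^Action:\s*(?P<tool>\w+)\((?P<args>\{.*\})\)\s*$', re.MULTILINE)

def regex_fallback(raw: str):
    """Layer-3 fallback: recover an action only from a strictly fixed line format."""
    m = ACTION_RE.search(raw)
    if m is None:
        return None
    try:
        args = json.loads(m.group("args"))
    except json.JSONDecodeError:
        return None
    return {"thought": "", "action": {"tool": m.group("tool"), "input": args},
            "answer": None, "confidence": 0.0}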
4.3.2 LLM-Assisted Correction: Algorithm and Retry Limits
Pseudocode:
def parse_and_validate(raw_text, schema, llm, max_retries=2):
    # Layers 1-2: strict parse, then lenient extraction of the first JSON block.
    try:
        parsed = json.loads(raw_text)
    except Exception:
        parsed = lenient_json_extract(raw_text)
    errors = validate_against_schema(parsed, schema)   # returns [] when valid
    if not errors:
        return parsed
    # Layer 4: bounded LLM-assisted correction loop.
    for _ in range(max_retries):
        fix_prompt = build_fix_prompt(raw_text, last_errors=errors)
        corrected = llm.generate(fix_prompt, max_tokens=100)
        try:
            fixed_parsed = json.loads(corrected)
        except Exception:
            continue
        errors = validate_against_schema(fixed_parsed, schema)
        if not errors:
            return fixed_parsed
    raise ParseFailureException(raw_text)
Design points:
- Error messages must be specific (e.g. field missing, invalid enum value).
- Retry at most 1~2 times to avoid cost blow-up.
- Log every attempt so failed outputs can be recycled later as training samples.
4.3.3 JSON Schema and Type Safety (Practical Recommendations)
- Use JSON Schema Draft-7 or later.
- Apply field coercion (e.g. string → int, float, datetime); a small sketch follows this list.
- Run all external input through anti-injection checks.
- Add schema regression tests to CI so new changes cannot silently alter existing behavior.
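A small sketch of field coercion driven by a simple type map; the coercer set (int, float, ISO-8601 datetime) is illustrative, and a production system would derive it from the JSON Schema itself:

from datetime import datetime

COERCERS = {
    "int": int,
    "float": float,
    "datetime": datetime.fromisoformat,   # expects ISO-8601 strings
}

def coerce_fields(obj: dict, field_types: dict) -> dict:
    """Coerce string-valued fields to their declared types; raise ValueError on failure."""
    out = dict(obj)
    for field, type_name in field_types.items():
        if field in out and isinstance(out[field], str):
            try:
                out[field] = COERCERS[type_name](out[field])
            except Exception as exc:
                raise ValueError(f"cannot coerce '{field}' to {type_name}: {exc}")
    return out

# coerce_fields({"top_k": "5"}, {"top_k": "int"})  ->  {"top_k": 5}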
4.4 Tool Abstraction and Interface Specification (Tool API Contract)
4.4.1 Tool Declaration Template (JSON/YAML)
id: search.v1
name: web_search
version: 1.0.2
description: Query web index and return top-k results.
inputs_schema:
  type: object
  properties:
    q: { type: string }
  required: [q]
outputs_schema:
  type: array
  items:
    type: object
    properties:
      title: { type: string }
      url: { type: string }
cost:
  usd_per_call: 0.002
  token_est: 80
idempotent: true
timeout_ms: 3000
retries:
  max: 2
  backoff: exponential
side_effects: none
auth:
  required: true
  scopes: [search:read]
Engineering requirements:
- The tool registry must support runtime lookup and version control.
- Any tool change must trigger a schema compatibility check.
- Tool metadata is injected into the system prompt so the agent can query it.
4.4.2 Idempotency, Transaction Semantics, and the Failure Model
Failure classification:
| Class | Characteristics | Handling |
|---|---|---|
| Transient | Recoverable faults (timeouts, network jitter) | Retry with exponential backoff |
| Permanent | Unrecoverable errors (invalid arguments) | Abort the plan and report the error |
| Partial | Partial success | Mark as partial and merge the results |
Transaction semantics:
- For side-effecting operations (e.g. database writes, initiating payments), use two-phase commit or the Saga pattern.
- Tool interfaces must state explicitly whether they are idempotent. A retry-policy sketch for the failure classes above follows.
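A minimal sketch of the retry policy implied by the failure classification above: transient errors are retried with exponential backoff, and only for idempotent tools; permanent errors propagate immediately. The exception classes and parameters here are assumptions:

import time

class TransientError(Exception): ...
class PermanentError(Exception): ...

def call_with_retry(tool_call, payload, idempotent: bool,
                    max_retries: int = 2, base_delay_s: float = 0.5):
    """Retry transient failures with exponential backoff; never retry non-idempotent calls."""
    attempts = max_retries if idempotent else 0
    for attempt in range(attempts + 1):
        try:
            return tool_call(payload)
        except TransientError:
            if attempt == attempts:
                raise                                   # exhausted retries: surface the fault
            time.sleep(base_delay_s * (2 ** attempt))   # 0.5s, 1s, 2s, ...
        except PermanentError:
            raise                                       # invalid arguments etc.: abort and report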
4.5 Pseudocode: An Engineering-Grade ReAct Main Loop (Single-Threaded and Concurrent)
4.5.1 Single-Threaded Main Loop (Baseline)
def react_loop(state, llm, tools, schema, budget, max_steps=20):
    for step in range(max_steps):
        prompt = build_prompt(state)
        raw_output = llm.generate(prompt, max_tokens=2048)
        try:
            parsed = parse_and_validate(raw_output, schema, llm)
        except ParseFailureException:
            return failure("PARSE_FAILED")
        # Token cost estimation
        state.costs.token += estimate_token_count(raw_output)
        if parsed.answer:
            if verify_answer(parsed.answer, state):
                return success(parsed.answer, meta=state)
            else:
                state = update_state(state, thought=parsed.thought, action=None,
                                     observation={"error": "verification_failed"})
                continue
        if not parsed.action:
            if should_ask_clarification(state):
                state = update_state(state, thought=parsed.thought, action=None, observation=None)
                continue
            else:
                return failure("NO_ACTION_NO_ANSWER")
        tool = tools.get(parsed.action.tool)
        if not tool:
            state = update_state(state, thought=parsed.thought, action=parsed.action,
                                 observation={"error": "UNKNOWN_TOOL"})
            continue
        try:
            normalized_input = coerce_and_validate(parsed.action.input, tool.schema.input)
        except ValidationError as ve:
            state = update_state(state, thought=parsed.thought, action=parsed.action,
                                 observation={"error": "INPUT_INVALID", "details": str(ve)})
            continue
        try:
            obs = tool.call(normalized_input, timeout=tool.timeout_ms)
        except TransientError as te:
            obs = {"status": "TRANSIENT_ERROR", "details": str(te)}
        except PermanentError as pe:
            obs = {"status": "PERMANENT_ERROR", "details": str(pe)}
        state = update_state(state, thought=parsed.thought, action=parsed.action, observation=obs)
        state.costs.monetary += obs.get("monetary_cost", 0)
        state.costs.latency += obs.get("latency_ms", 0)
        if budget.exceeded(state.costs):
            return failure("BUDGET_EXCEEDED")
        if detect_no_progress(state):
            return failure("NO_PROGRESS_DETECTED")
    return failure("MAX_STEPS_REACHED")
4.5.2 Concurrent / Batched Call Extension (Design and Pseudocode)
Concurrency is appropriate when:
- the actions have no dependencies on each other;
- tool latency is high;
- overall latency is the bottleneck.
Concurrent pseudocode:
async def concurrent_react_loop(state, llm, tools, schema, budget, max_steps=20):
    for step in range(max_steps):
        prompt = build_prompt(state)
        raw = await llm.generate_async(prompt)
        parsed = parse_and_validate(raw, schema, llm)
        actions = parsed.actions or [parsed.action]
        grouped_actions = partition_by_independence(actions, tools)
        for group in grouped_actions:
            semaphore = asyncio.Semaphore(group.max_concurrency)

            async def exec_one(action):
                async with semaphore:
                    tool = tools[action.tool]
                    inp = coerce_and_validate(action.input, tool.schema.input)
                    return await tool.call_async(inp, timeout=tool.timeout_ms)

            futures = [asyncio.create_task(exec_one(a)) for a in group.actions]
            # gather preserves order and wraps failures as exceptions; the whole group shares one timeout
            results = await asyncio.wait_for(
                asyncio.gather(*futures, return_exceptions=True),
                timeout=group.group_timeout_ms / 1000,   # wait_for expects seconds
            )
            observations = [res for res in results if not isinstance(res, Exception)]
            merged_obs = group.merge_fn(observations)
            state = update_state(state, thought=parsed.thought,
                                 actions=group.actions, observation=merged_obs)
            estimated_cost = sum(obs.get("estimated_cost", 0) for obs in observations)
            state.costs.monetary += estimated_cost
            if budget.exceeded(state.costs):
                return failure("BUDGET_EXCEEDED")
    return failure("MAX_STEPS_REACHED")
Design Trade-offs and Engineering Recommendations (Expert Checklist)
| Component | Key trade-off | Recommendation |
|---|---|---|
| Schema-first | Whether to trade flexibility for stability | Strongly recommended: every interface is schema-validated |
| Principle of least trust | Balancing LLM output freedom against safety | Every output must pass schema validation before any action executes |
| Parse correction strategy | Containing the cost of LLM-assisted correction | Cap corrections at 1~2 attempts; log them for future training |
| Concurrency control | Balancing performance gains against complexity | Enable concurrency only for idempotent tools; bound each group with a timeout |
| Permissions and sensitive operations | Keeping sensitive operations safe | Add a manual-approval flag or a human-in-the-loop step |
| Logging and monitoring | Tracing failure causes and improvement directions | Keep full pipeline logs: raw / parsed / observation |
| Regression testing | Preventing behavioral regressions on upgrades | Add few-shot / schema regression suites to CI |
Conclusion: Key Dimensions of Building an Industrial-Grade ReAct Agent
| Dimension | Technical essentials |
|---|---|
| State modeling | Clearly separate visible state from internal state |
| Policy control | Combine LLM + Parser + Schema to obtain a reliable policy |
| Cost control | Cost vector + weighted scalarization + budget-driven prompting |
| Tool governance | Tool registry + schema-first + idempotency guarantees |
| Parsing robustness | Layered parser + LLM-assisted correction loop |
| Concurrent scheduling | Group-wise concurrency + merge strategy + timeout control |
| Safety and compliance | Permission checks + interception of sensitive operations + human-in-the-loop support |
| Auditability | Full raw → parsed → executed → observed pipeline logs |
| Iterative improvement | Recycle parse-failure samples → fine-tune the LLM → refresh few-shot examples |
Source Code in Practice
- File 1: react_agent.py, the main module (standard library only, no external dependencies; plug in a real LLM interface for production).
- File 2: test_react_agent.py, the unit tests (runnable with either unittest or pytest).
- How to run: place both files in the same directory, then run python -m pytest test_react_agent.py, or python test_react_agent.py if using unittest.
react_agent.py
"""
react_agent.py
Engineering-grade Python module draft for a ReAct-style agent:
- Parser with LLM-assisted correction loop
- Tool registry with metadata and sync/async call helpers
- CostCounter for tracking token/monetary/latency costs
- A minimal single-threaded react_loop demonstrating usage
This module is intentionally dependency-free (standard library only).
Replace DummyLLM with a real LLM wrapper for production use.
"""
from typing import Any, Dict, Optional, Callable, List
import json
import time
import asyncio
import re
# ---------------------------
# Utilities: JSON extraction
# ---------------------------
def extract_first_json_object(text: str) -> Optional[str]:
    """
    Extract the first balanced JSON object from text by scanning braces.
    Returns the JSON substring or None if not found.
    """
    start = text.find('{')
    if start == -1:
        return None
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return text[start:i+1]
    return None


def strict_json_parse(text: str) -> Any:
    """
    Strictly parse JSON -- raises json.JSONDecodeError on failure.
    """
    return json.loads(text)


def lenient_extract_and_parse(text: str) -> Any:
    """
    Attempt to extract a JSON object and parse it. Raises json.JSONDecodeError
    if parsing the extracted JSON fails, or ValueError if none found.
    """
    js = extract_first_json_object(text)
    if js is None:
        raise ValueError("No JSON object found in text")
    return json.loads(js)
# ---------------------------
# Schema validation (simple)
# ---------------------------
def validate_agent_output_schema(obj: Dict[str, Any]) -> List[str]:
    """
    Very small validator for the expected output schema:
    {
      "thought": str,
      "action": null | {"tool": str, "input": dict},
      "answer": null | str,
      "confidence": float in [0,1]
    }
    Returns a list of validation error strings (empty if valid).
    """
    errs = []
    if not isinstance(obj, dict):
        errs.append("output must be a JSON object")
        return errs
    # thought
    if "thought" not in obj:
        errs.append("missing 'thought'")
    elif not isinstance(obj["thought"], str):
        errs.append("'thought' must be a string")
    # action
    if "action" not in obj:
        errs.append("missing 'action'")
    else:
        a = obj["action"]
        if a is not None:
            if not isinstance(a, dict):
                errs.append("'action' must be null or an object")
            else:
                if "tool" not in a:
                    errs.append("action missing 'tool'")
                elif not isinstance(a["tool"], str):
                    errs.append("action.tool must be a string")
                if "input" not in a:
                    errs.append("action missing 'input'")
                elif not isinstance(a["input"], dict):
                    errs.append("action.input must be an object")
    # answer
    if "answer" not in obj:
        errs.append("missing 'answer'")
    else:
        if obj["answer"] is not None and not isinstance(obj["answer"], str):
            errs.append("'answer' must be null or a string")
    # confidence
    if "confidence" not in obj:
        errs.append("missing 'confidence'")
    else:
        conf = obj["confidence"]
        if not (isinstance(conf, float) or isinstance(conf, int)):
            errs.append("'confidence' must be a number")
        else:
            if not (0.0 <= float(conf) <= 1.0):
                errs.append("'confidence' must be between 0 and 1")
    return errs
# ---------------------------
# LLM interface (pluggable)
# ---------------------------
class LLMInterface:
    """
    Minimal LLM interface wrapper.
    Implement generate(prompt, max_tokens) for synchronous LLM,
    and generate_async(...) for async usage.
    """
    def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        raise NotImplementedError()

    async def generate_async(self, prompt: str, max_tokens: int = 1024) -> str:
        # default implementation wraps sync generate
        return self.generate(prompt, max_tokens=max_tokens)


class DummyLLM(LLMInterface):
    """
    Deterministic dummy LLM for unit tests and local debugging.
    Behavior:
      - If constructed with a mapping `responses` (ordered dict like),
        when a key is a substring of the prompt, the first unused matching key
        will be returned and marked used. This allows sequential different
        responses for similar prompts across multiple calls.
      - If prompt contains the special marker <<FIX-RETURN>>...<<END-FIX>>, returns the enclosed text.
      - Otherwise returns a minimal valid no-op JSON object.
    """
    def __init__(self, responses: Optional[Dict[str, str]] = None):
        self.responses = responses or {}
        self._used_keys: List[str] = []

    def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        # exact mapping or substring mapping with first unused match
        for k, v in self.responses.items():
            if k in prompt and k not in self._used_keys:
                self._used_keys.append(k)
                return v
        # if all matches used, allow reuse: return first matching key
        for k, v in self.responses.items():
            if k in prompt:
                return v
        # special marker support
        m = re.search(r'<<FIX-RETURN>>(.*?)<<END-FIX>>', prompt, flags=re.S)
        if m:
            return m.group(1).strip()
        # fallback: return a polite empty valid object (no action)
        return json.dumps({"thought": "no-op", "action": None, "answer": None, "confidence": 0.0})
# ---------------------------
# Parser with LLM-assisted correction loop
# ---------------------------
def build_fix_prompt(original: str, validation_errors: List[str], schema_note: str = "") -> str:
    """
    Build a short prompt to ask the LLM to only output corrected JSON
    that satisfies the agent output schema.
    """
    errs = "\n".join(validation_errors)
    prompt = (
        "The system attempted to parse the following LLM output but failed schema validation.\n"
        "Original output:\n"
        "----BEGIN ORIGINAL----\n"
        f"{original}\n"
        "----END ORIGINAL----\n"
        "Validation errors:\n"
        f"{errs}\n"
        "Please output **only** a single JSON object that matches the schema: "
        '{"thought": "string", "action": null | {"tool":"string","input":{}}, "answer": null | "string", "confidence": 0.0} '
        "Do not include any prose. If you intend no action, set action to null.\n"
    )
    if schema_note:
        prompt += "\n" + schema_note
    return prompt


def try_parse_json(raw: str) -> Any:
    """
    Try strict parse first; if fails, try lenient extraction. Raises json.JSONDecodeError or ValueError.
    """
    try:
        return strict_json_parse(raw)
    except Exception:
        # lenient attempt
        return lenient_extract_and_parse(raw)


def parse_and_validate(raw_text: str, llm: Optional[LLMInterface] = None, max_retries: int = 2) -> Dict[str, Any]:
    """
    Parse the raw LLM output into JSON and validate against the agent output schema.
    If validation fails and an LLM is provided, ask the LLM to correct the output.
    """
    last_raw = raw_text
    attempt = 0
    while True:
        attempt += 1
        parsed = None
        try:
            parsed = try_parse_json(last_raw)
            errs = validate_agent_output_schema(parsed)
            if not errs:
                return parsed  # success
            else:
                parse_err = "schema errors: " + "; ".join(errs)
        except Exception as e:
            parsed = None
            parse_err = f"parse error: {repr(e)}"
        # at this point parsed is None or has errors
        if llm is None or attempt > max_retries:
            # give up
            raise ValueError(f"Unparsable or invalid output after {attempt} attempts. Last error: {parse_err}. Raw: {last_raw[:1000]}")
        # build fix prompt
        fix_prompt = build_fix_prompt(last_raw, errs if parsed is not None else [parse_err])
        fixed = llm.generate(fix_prompt, max_tokens=512)
        # guard against identical response
        if fixed.strip() == last_raw.strip():
            raise ValueError("LLM returned identical output when asked to fix; aborting")
        last_raw = fixed
        # loop to attempt parse again
# ---------------------------
# Tool registry and Tool abstraction
# ---------------------------
class Tool:
    """
    Tool metadata and a callable interface. In production, call_fn should
    make network calls or perform the actual tool behavior.
    """
    def __init__(self,
                 id: str,
                 name: str,
                 inputs_schema: Optional[Dict] = None,
                 outputs_schema: Optional[Dict] = None,
                 cost_estimate: Optional[Dict] = None,
                 idempotent: bool = True,
                 timeout_ms: int = 3000,
                 call_fn: Optional[Callable] = None,
                 call_async_fn: Optional[Callable] = None):
        self.id = id
        self.name = name
        self.inputs_schema = inputs_schema or {}
        self.outputs_schema = outputs_schema or {}
        self.cost_estimate = cost_estimate or {}
        self.idempotent = idempotent
        self.timeout_ms = timeout_ms
        self._call_fn = call_fn
        self._call_async_fn = call_async_fn

    def call(self, inp: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synchronous call wrapper. Raises exceptions on error.
        """
        if self._call_fn is None:
            raise RuntimeError("No call_fn configured for tool " + self.name)
        start = time.time()
        out = self._call_fn(inp)
        out_meta = dict(out)
        out_meta.setdefault("_meta", {})
        out_meta["_meta"]["latency_ms"] = int((time.time() - start) * 1000)
        return out_meta

    async def call_async(self, inp: Dict[str, Any]) -> Dict[str, Any]:
        if self._call_async_fn is not None:
            start = time.time()
            out = await self._call_async_fn(inp)
            out_meta = dict(out)
            out_meta.setdefault("_meta", {})
            out_meta["_meta"]["latency_ms"] = int((time.time() - start) * 1000)
            return out_meta
        # wrap sync in threadpool to not block event loop
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.call, inp)


class ToolRegistry:
    def __init__(self):
        self._tools = {}

    def register(self, tool: Tool):
        if tool.name in self._tools:
            raise KeyError("Tool already registered: " + tool.name)
        self._tools[tool.name] = tool

    def get(self, name: str) -> Optional[Tool]:
        return self._tools.get(name)
# ---------------------------
# Cost counter
# ---------------------------
class CostCounter:
    def __init__(self):
        self.tokens = 0
        self.monetary = 0.0
        self.latency_ms = 0

    def add_tokens(self, n: int):
        self.tokens += int(n)

    def add_monetary(self, usd: float):
        self.monetary += float(usd)

    def add_latency(self, ms: int):
        self.latency_ms += int(ms)

    def add_from_tool_obs(self, obs: Dict[str, Any]):
        meta = obs.get("_meta", {})
        self.add_latency(meta.get("latency_ms", 0))
        # if tool provides monetary_cost in obs, add it
        if "monetary_cost" in obs:
            try:
                self.add_monetary(float(obs["monetary_cost"]))
            except Exception:
                pass

    def exceeded(self, budget: Dict[str, Any]) -> bool:
        """
        budget example: {"tokens":10000, "monetary": 1.0}
        """
        if not budget:
            return False
        if "tokens" in budget and self.tokens > budget["tokens"]:
            return True
        if "monetary" in budget and self.monetary > budget["monetary"]:
            return True
        return False
# ---------------------------
# Minimal single-threaded react loop
# ---------------------------
def build_prompt_from_state(state: Dict[str, Any]) -> str:
    """
    Simple prompt builder. For production, replace with system prompt + few-shot + state summary.
    """
    # Allow tests to inject explicit prompt override
    if "prompt_override" in state:
        return state["prompt_override"]
    return "State summary: " + json.dumps(state.get("summary", {})) + "\nProduce JSON output."


def react_single_thread(initial_state: Dict[str, Any],
                        llm: LLMInterface,
                        tool_registry: ToolRegistry,
                        budget: Optional[Dict[str, Any]] = None,
                        max_steps: int = 10) -> Dict[str, Any]:
    S = dict(initial_state)  # shallow copy
    costs = CostCounter()
    for step in range(max_steps):
        prompt = build_prompt_from_state(S)
        raw = llm.generate(prompt, max_tokens=1024)
        costs.add_tokens(len(raw.split()))
        try:
            parsed = parse_and_validate(raw, llm=llm, max_retries=1)
        except Exception as e:
            return {"status": "error", "reason": "parse_failure", "detail": str(e), "state": S}
        # record thought
        S.setdefault("history", []).append({"thought": parsed.get("thought")})
        # if answer -> verify (here we accept as final)
        if parsed.get("answer") is not None:
            return {"status": "success", "answer": parsed["answer"], "steps": step + 1,
                    "costs": {"tokens": costs.tokens, "monetary": costs.monetary, "latency_ms": costs.latency_ms}}
        act = parsed.get("action")
        if act is None:
            # nothing to do - terminate as no-op
            return {"status": "no_action", "steps": step + 1, "state": S}
        # validate tool exists
        tool = tool_registry.get(act.get("tool"))
        if tool is None:
            S["history"].append({"observation": {"error": "unknown_tool", "tool": act.get("tool")}})
            continue
        # call tool (sync)
        try:
            obs = tool.call(act.get("input", {}))
        except Exception as e:
            obs = {"error": "tool_exception", "detail": str(e)}
        S["history"].append({"action": act, "observation": obs})
        costs.add_from_tool_obs(obs)
        # check budget
        if costs.exceeded(budget or {}):
            return {"status": "budget_exceeded", "costs": {"tokens": costs.tokens, "monetary": costs.monetary}}
    return {"status": "max_steps_reached", "steps": max_steps, "costs": {"tokens": costs.tokens, "monetary": costs.monetary}}
test_react_agent.py
"""
Unit tests for react_agent.py
"""
import unittest
import json
import asyncio
import time
from react_agent import (
    extract_first_json_object, strict_json_parse, lenient_extract_and_parse,
    validate_agent_output_schema, DummyLLM, parse_and_validate,
    Tool, ToolRegistry, CostCounter, react_single_thread
)
class TestJSONExtraction(unittest.TestCase):
    def test_extract_balanced_json(self):
        text = "prefix {\"a\":1, \"b\":{\"c\":2}} suffix"
        js = extract_first_json_object(text)
        self.assertIsNotNone(js)
        self.assertEqual(json.loads(js), {"a": 1, "b": {"c": 2}})

    def test_no_json(self):
        text = "no json here"
        js = extract_first_json_object(text)
        self.assertIsNone(js)
class TestParsingAndValidation(unittest.TestCase):
    def test_strict_parse_valid(self):
        raw = json.dumps({"thought": "t", "action": None, "answer": None, "confidence": 0.5})
        parsed = parse_and_validate(raw, llm=DummyLLM(), max_retries=0)
        self.assertEqual(parsed["thought"], "t")

    def test_lenient_extract_and_parse(self):
        raw = "Some explanation... {\"thought\":\"lenient\",\"action\":null,\"answer\":null,\"confidence\":0.1} end"
        parsed = parse_and_validate(raw, llm=DummyLLM(), max_retries=0)
        self.assertEqual(parsed["thought"], "lenient")

    def test_llm_assist_fix(self):
        # raw is invalid; DummyLLM returns corrected JSON when the fix prompt matches the mapping key
        raw = "I will not output JSON"
        fixed = json.dumps({"thought": "fixed", "action": None, "answer": "done", "confidence": 0.9})
        llm = DummyLLM(responses={"Please output": fixed})
        # parse_and_validate only consults the llm when max_retries > 0 -- here a single retry
        parsed = parse_and_validate(raw, llm=llm, max_retries=1)
        self.assertEqual(parsed["answer"], "done")
class TestToolRegistryAndCost(unittest.TestCase):
    def test_tool_call_and_cost_accumulation(self):
        def search_fn(inp):
            q = inp.get("q", "")
            return {"results": [{"title": "r", "url": "u", "snippet": "s"}], "monetary_cost": 0.005}

        search_tool = Tool(id="search.v1", name="search", call_fn=search_fn)
        reg = ToolRegistry()
        reg.register(search_tool)
        tool = reg.get("search")
        obs = tool.call({"q": "test"})
        self.assertIn("results", obs)
        cost = CostCounter()
        cost.add_from_tool_obs(obs)
        self.assertAlmostEqual(cost.monetary, 0.005, places=6)
class TestReactLoop(unittest.TestCase):
    def test_react_single_thread_answer(self):
        # LLM will return a direct valid JSON with an answer
        fixed = json.dumps({"thought": "final", "action": None, "answer": "42", "confidence": 1.0})
        llm = DummyLLM(responses={"Produce JSON output.": fixed})
        reg = ToolRegistry()
        state = {"summary": {"goal": "compute 6*7"}}
        res = react_single_thread(state, llm, reg, budget={"tokens": 1000, "monetary": 1.0}, max_steps=3)
        self.assertEqual(res["status"], "success")
        self.assertEqual(res["answer"], "42")

    def test_react_single_thread_tool_call(self):
        # LLM returns an action calling 'search', then the second LLM call returns the answer
        first = json.dumps({"thought": "do search", "action": {"tool": "search", "input": {"q": "x"}}, "answer": None, "confidence": 0.6})
        second = json.dumps({"thought": "got it", "action": None, "answer": "found", "confidence": 0.9})
        llm = DummyLLM(responses={"Produce JSON output.": first, "State summary": second})

        def search_fn(inp):
            return {"results": [{"title": "r"}], "monetary_cost": 0.002}

        reg = ToolRegistry()
        reg.register(Tool(id="search.v1", name="search", call_fn=search_fn))
        state = {"summary": {"goal": "find x"}}
        res = react_single_thread(state, llm, reg, budget={"tokens": 1000, "monetary": 1.0}, max_steps=5)
        self.assertEqual(res["status"], "success")
        self.assertEqual(res["answer"], "found")


if __name__ == '__main__':
    unittest.main(verbosity=2)