1. 概述:大脑树(Tree of Thoughts, ToT)算法
大脑树(ToT) 由 Yao 等人提出,是一种通用的大型语言模型(LLM)代理搜索算法。它结合了反思/评估和简单搜索(如广度优先搜索 BFS,当然也可以应用深度优先搜索 DFS 或其他算法)。ToT 主要包括三个步骤:
- 扩展(Expand):生成一个或多个问题的候选解决方案。
- 评分(Score):衡量这些解决方案的质量。
- 修剪(Prune):保留评分最高的前 K 个候选方案。
如果没有找到解决方案或解决方案质量不足,则返回“扩展”步骤继续迭代。
2. 前提条件(Prerequisites)
首先,我们需要安装教程所依赖的包,并设置所选择的 LLM 提供商的 API 密钥。
%%capture --no-stderr
%pip install -U langgraph langchain-openai
import getpass
import os
def _set_env(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"{var}: ")
_set_env("OPENAI_API_KEY")
# 如果需要可视化算法
trace = True
if trace:
_set_env("LANGSMITH_API_KEY")
os.environ["LANGSMITH_PROJECT"] = "ToT Tutorial"
解释:
- 安装依赖包:使用
%pip install安装最新版本的langgraph和langchain-openai。 - 设置环境变量:通过
getpass.getpass安全地获取用户输入的 API 密钥,并将其设置为环境变量。这里需要设置OPENAI_API_KEY,如果启用了追踪(trace),还需要设置LANGSMITH_API_KEY和LANGSMITH_PROJECT。
3. 任务定义(Task Definition)
我们的代理将尝试玩“24 点游戏”。给定 4 个数字,生成一个数学方程式,使用每个数字恰好一次,结果为 24。
import operator
from typing import List, Literal, Union, NamedTuple, Optional
from pydantic import BaseModel, Field
OperatorType = Literal["+", "-", "*", "/"]
TokenType = Union[float, OperatorType]
解释:
- 导入必要的模块:包括
operator模块用于数学运算,typing模块用于类型提示,pydantic用于数据验证。 - 定义类型:
OperatorType:仅限于加、减、乘、除四种运算符。TokenType:可以是浮点数或运算符。
定义方程式的模型:
class Equation(BaseModel):
"""结合提供的数字以达到 24 的公式。"""
tokens: List[TokenType] = Field(
description="逆波兰表示法的令牌栈。例如:[3, 4, '+', -1, '*'] 将计算为 (3 + 4) * -1 = -7。",
)
def compute(self) -> float:
op_funcs = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.truediv,
}
stack = []
for token in self.tokens:
if isinstance(token, float):
stack.append(token)
else:
b, a = stack.pop(), stack.pop()
stack.append(op_funcs[token](a, b))
return stack[0]
解释:
Equation类:tokens:使用逆波兰表示法(后缀表示法)存储方程的令牌,例如[3, 4, '+', -1, '*']代表(3 + 4) * -1。compute方法:计算方程的值,使用栈来解析逆波兰表达式。
class GuessEquations(BaseModel):
"""提交多个方程作为猜测。"""
reasoning: str = Field(
description="提交的猜测背后的推理。解释如何得出这些方程。",
)
equations: List[Equation] = Field(
description="要提交的方程列表。",
)
解释:
GuessEquations类:reasoning:描述提交的方程背后的思路。equations:提交的方程列表。
定义候选者模型:
class Candidate(NamedTuple):
candidate: Equation
score: Optional[float] = None
feedback: Optional[str] = None
def __str__(self):
try:
computed = self.candidate.compute()
except Exception as e:
computed = f"无效的方程: {self.candidate.tokens}; 错误: {repr(e)}"
return f"Equation({self.candidate.tokens}) = {computed} (奖励: {self.score})"
class ScoredCandidate(Candidate):
candidate: Equation
score: float
feedback: str
解释:
-
Candidate类:- 表示一个候选解决方案,包括
Equation对象、评分和反馈。 __str__方法用于打印候选者的信息。
- 表示一个候选解决方案,包括
-
ScoredCandidate类:- 继承自
Candidate,确保score和feedback是必填项。
- 继承自
4. 获取数据(Fetch data)
我们将使用“24 点游戏”数据集中的一个示例。
import requests
import csv
csv_data = requests.get(
"https://storage.googleapis.com/benchmarks-artifacts/game-of-24/24.csv"
).content.decode("utf-8")
# 获取“Puzzles”列(列索引为1)
puzzles = [row[1].strip() for row in csv.reader(csv_data.splitlines()[1:])]
print(f"示例谜题: {puzzles[:3]}")
解释:
- 获取 CSV 数据:从指定的 URL 下载 CSV 文件并解码为字符串。
- 解析 CSV:使用
csv.reader解析 CSV 数据,提取第二列(索引为1)的“Puzzles”数据。 - 打印示例谜题:展示前三个谜题,例如
['1 1 4 6', '1 1 11 11', '1 1 3 8']。
5. 扩展器(Expander)
ToT 算法是相对通用的。主要的两个任务特定组件是扩展器和评分器。扩展器(增强的 LLM)尝试生成一个或多个问题的解决方案。在后续尝试中,它会接收来自之前搜索的种子/候选值。
定义扩展器:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"你正在玩24点游戏。使用提供的数字,创建一个等于24的方程式。\n"
"本轮请提交恰好{k}个猜测。",
),
("user", "为这些数字解决24点游戏: {problem}.{candidate}"),
],
).partial(candidate="")
llm = ChatOpenAI(model="gpt-4o-mini")
bound_llm = llm.with_structured_output(GuessEquations)
solver = prompt | bound_llm
# API参考: ChatPromptTemplate | ChatOpenAI
解释:
-
ChatPromptTemplate:- 定义了与 LLM 交互的提示模板。
- 系统消息:指示 LLM 参与“24 点游戏”,并提交
{k}个猜测。 - 用户消息:提供具体的问题
{problem}和之前的候选{candidate}。
-
ChatOpenAI:- 使用 OpenAI 的 GPT-4 模型(这里是
gpt-4o-mini,可能是一个简化版本)。
- 使用 OpenAI 的 GPT-4 模型(这里是
-
绑定结构化输出:
with_structured_output(GuessEquations):确保 LLM 的输出符合GuessEquations模型的结构。
-
组合提示和 LLM:
- 使用管道操作符
|将提示模板和 LLM 连接起来,形成一个完整的求解器solver。
- 使用管道操作符
6. 评分器(Scorer)
在这个游戏中,评分器相对简单。我们需要确认两点:
- LLM 生成了一个有效的方程,使用每个数字恰好一次。
- 方程的结果等于 24。
定义评分函数:
def compute_score(problem: str, candidate: Candidate) -> ScoredCandidate:
numbers = list(map(int, problem.split()))
# 检查候选方程是否使用了所有4个数字且每个数字恰好一次
used_numbers = [
token for token in candidate.candidate.tokens if isinstance(token, float)
]
if sorted(used_numbers) != sorted(numbers):
score = 0
feedback = "方程必须使用所有4个数字且每个数字恰好一次。"
return ScoredCandidate(
candidate=candidate.candidate, score=score, feedback=feedback
)
try:
result = candidate.candidate.compute()
score = 1 / (1 + abs(24 - result))
feedback = f"结果: {result}"
except Exception as e:
score = 0
feedback = f"无效的方程。错误: {repr(e)}"
return ScoredCandidate(
candidate=candidate.candidate, score=score, feedback=feedback
)
解释:
-
参数:
problem:当前的谜题(四个数字)。candidate:候选的方程。
-
评分逻辑:
-
验证数字使用情况:
- 提取候选方程中使用的数字,检查是否恰好使用了问题中的所有数字。
- 如果不匹配,评分为 0,反馈为“方程必须使用所有4个数字且每个数字恰好一次。”
-
计算结果:
- 如果数字使用正确,尝试计算方程的结果。
- 评分公式:
1 / (1 + abs(24 - result)),即结果越接近 24,评分越高。 - 如果计算过程中出现错误(如除以零),评分为 0,反馈包含错误信息。
-
7. 构建图(Graph)
接下来,我们需要创建 LangGraph 图,该图将包括扩展、评分和修剪三个节点,并定义它们之间的连接关系。
定义图的状态和配置:
import operator
from typing import Optional, Dict, Any
from typing_extensions import Annotated, TypedDict
from langgraph.graph import StateGraph
from langchain_core.runnables import RunnableConfig
from langgraph.constants import Send
from langgraph.checkpoint.memory import MemorySaver
def update_candidates(
existing: Optional[list] = None,
updates: Optional[Union[list, Literal["clear"]]] = None,
) -> List[str]:
if existing is None:
existing = []
if updates is None:
return existing
if updates == "clear":
return []
# 连接现有列表和更新列表
return existing + updates
class ToTState(TypedDict):
problem: str
candidates: Annotated[List[Candidate], update_candidates]
scored_candidates: Annotated[List[ScoredCandidate], update_candidates]
depth: Annotated[int, operator.add]
class Configuration(TypedDict, total=False):
max_depth: int
threshold: float
k: int
beam_size: int
解释:
-
ToTState:problem:当前的谜题。candidates:当前的候选方程列表,使用update_candidates函数进行更新。scored_candidates:评分后的候选方程列表,同样使用update_candidates。depth:当前的搜索深度,每次迭代增加。
-
Configuration:max_depth:搜索的最大深度,默认 10。threshold:评分阈值,默认 0.9。k:每轮生成的猜测数量,默认 5。beam_size:修剪时保留的候选数量,默认 3。
辅助函数:
def _ensure_configurable(config: RunnableConfig) -> Configuration:
"""获取配置搜索算法的参数。"""
configurable = config.get("configurable", {})
return {
**configurable,
"max_depth": configurable.get("max_depth", 10),
"threshold": config.get("threshold", 0.9),
"k": configurable.get("k", 5),
"beam_size": configurable.get("beam_size", 3),
}
解释:
_ensure_configurable:- 从
RunnableConfig中提取配置参数,提供默认值。
- 从
定义扩展函数:
class ExpansionState(ToTState):
seed: Optional[Candidate]
def expand(state: ExpansionState, *, config: RunnableConfig) -> Dict[str, List[str]]:
"""生成下一个状态。"""
configurable = _ensure_configurable(config)
if not state.get("seed"):
candidate_str = ""
else:
candidate_str = "\n\n" + str(state["seed"])
try:
equation_submission = solver.invoke(
{
"problem": state["problem"],
"candidate": candidate_str,
"k": configurable["k"],
},
config=config,
)
except Exception:
return {"candidates": []}
new_candidates = [
Candidate(candidate=equation) for equation in equation_submission.equations
]
return {"candidates": new_candidates}
解释:
expand函数:- 使用
solver(之前定义的扩展器)生成新的候选方程。 - 如果有
seed候选者,将其转换为字符串并附加到提示中,以指导 LLM 生成更好的猜测。 - 返回新生成的候选方程列表。
- 使用
定义评分函数:
def score(state: ToTState) -> Dict[str, List[float]]:
"""评估候选生成。"""
candidates = state["candidates"]
scored = []
for candidate in candidates:
scored.append(compute_score(state["problem"], candidate))
return {"scored_candidates": scored, "candidates": "clear"}
解释:
score函数:- 遍历当前候选方程,对每个候选进行评分。
- 将评分结果存储在
scored_candidates中,并清空candidates列表。
定义修剪函数:
def prune(
state: ToTState, *, config: RunnableConfig
) -> Dict[str, List[Dict[str, Any]]]:
scored_candidates = state["scored_candidates"]
beam_size = _ensure_configurable(config)["beam_size"]
organized = sorted(
scored_candidates, key=lambda candidate: candidate.score, reverse=True
)
pruned = organized[:beam_size]
return {
# 更新下一次迭代的起点
"candidates": pruned,
# 清除旧的评分结果
"scored_candidates": "clear",
# 深度增加 1
"depth": 1,
}
解释:
prune函数:- 根据评分对候选方程进行排序,保留评分最高的
beam_size个。 - 更新
candidates为修剪后的候选者,清空scored_candidates,并增加搜索深度。
- 根据评分对候选方程进行排序,保留评分最高的
定义终止条件:
def should_terminate(
state: ToTState, config: RunnableConfig
) -> Union[Literal["__end__"], Send]:
configurable = _ensure_configurable(config)
solved = state["candidates"][0].score >= configurable["threshold"]
if solved or state["depth"] >= configurable["max_depth"]:
return "__end__"
return [
Send("expand", {**state, "seed": candidate})
for candidate in state["candidates"]
]
解释:
should_terminate函数:- 检查是否满足终止条件:
- 是否找到评分超过阈值的解决方案。
- 是否达到最大搜索深度。
- 如果满足其中任一条件,返回
"__end__",终止搜索。 - 否则,继续进行扩展操作,对每个候选者发送
expand请求。
- 检查是否满足终止条件:
构建图:
# 创建图构建器
builder = StateGraph(state_schema=ToTState, config_schema=Configuration)
# 添加节点
builder.add_node(expand)
builder.add_node(score)
builder.add_node(prune)
# 添加边
builder.add_edge("expand", "score")
builder.add_edge("score", "prune")
builder.add_conditional_edges("prune", should_terminate, path_map=["expand", "__end__"])
# 设置入口点
builder.add_edge("__start__", "expand")
# 编译图
graph = builder.compile(checkpointer=MemorySaver())
# API参考: RunnableConfig | StateGraph | Send | MemorySaver
解释:
-
创建
StateGraph:- 定义图的状态模式
ToTState和配置模式Configuration。
- 定义图的状态模式
-
添加节点:
expand:扩展节点。score:评分节点。prune:修剪节点。
-
添加边:
- 从
expand到score。 - 从
score到prune。 - 从
prune根据should_terminate函数的结果决定下一步:- 如果终止,结束图执行。
- 否则,返回
expand继续迭代。
- 从
-
设置入口点:
- 从
__start__到expand。
- 从
-
编译图:
- 使用
MemorySaver作为检查点,编译完成后的图对象为graph。
- 使用
8. 运行(Run)
现在,我们可以尝试在其中一个谜题上运行这个图。
可视化图:
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))
解释:
- 可视化图:
- 使用
draw_mermaid_png方法绘制图的结构,并在 Jupyter Notebook 中显示。
- 使用
执行图:
config = {
"configurable": {
"thread_id": "test_1",
"depth": 10,
}
}
for step in graph.stream({"problem": puzzles[42]}, config):
print(step)
解释:
-
配置参数:
thread_id:标识当前线程,方便跟踪。depth:最大搜索深度,设置为 10。
-
运行图:
- 使用
graph.stream方法逐步执行图,并打印每一步的状态。
- 使用
示例输出解析:
{'expand': {'candidates': [Candidate(...), ...]}}
{'score': {'candidates': 'clear', 'scored_candidates': [ScoredCandidate(...), ...]}}
{'prune': {'candidates': [ScoredCandidate(...), ...], 'scored_candidates': 'clear', 'depth': 1}}
...
{'prune': {'candidates': [ScoredCandidate(...), ...], 'scored_candidates': 'clear', 'depth': 1}}
final_state = graph.get_state(config)
winning_solution = final_state.values["candidates"][0]
search_depth = final_state.values["depth"]
if winning_solution.score == 1:
print(f"在 {search_depth} 步内找到一个获胜的解决方案: {winning_solution}")
else:
print(
f"在 {search_depth} 步内未能找到获胜的解决方案。最佳猜测: {winning_solution}"
)
解释:
-
逐步执行:
- 每一步
expand、score和prune都会打印当前的状态,包括生成的候选方程、评分结果和修剪后的候选者。
- 每一步
-
获取最终状态:
- 使用
graph.get_state(config)获取最终的搜索状态。 - 提取
winning_solution和search_depth。
- 使用
-
输出结果:
- 如果找到评分为 1 的解决方案(即结果精确等于 24),则打印成功信息。
- 否则,打印失败信息及最佳猜测。
示例输出:
在 2 步内找到一个获胜的解决方案: [Equation(tokens=[1.0, 5.0, 7.0, '*', 12.0, '-', '+']), 1.0, '结果: 24.0']
解释:
- 成功找到解决方案:
- 在第 2 步搜索中,找到一个评分为 1 的解决方案,即
(1 + (5 * 7)) - 12 = 24。
- 在第 2 步搜索中,找到一个评分为 1 的解决方案,即
9. 总结
本案例展示了如何使用 LangGraph 和 ToT 算法构建一个智能代理,解决“24 点游戏”。通过以下几个关键步骤:
- 定义任务和数据模型:包括方程表示、候选者和评分模型。
- 数据获取:从外部数据源加载游戏谜题。
- 构建扩展器和评分器:使用 LLM 生成候选方程,并评估其质量。
- 构建搜索图:定义扩展、评分和修剪节点及其连接关系。
- 执行搜索:运行图,迭代生成和评估候选,直到找到满意的解决方案或达到最大搜索深度。
通过这种方式,您可以利用 LLM 的强大能力,结合结构化的搜索策略,有效地解决复杂的组合优化问题。
汇总
当然,以下是将您提供的 LangGraph 入门案例的相关代码汇总到一个完整的 tot_game_of_24.py 文件中。所有提示词和注释均已翻译为中文,以便您更好地理解和执行代码。
请确保在运行此脚本之前,已按照以下步骤操作:
-
安装必要的包:
在终端或命令提示符中运行以下命令,以安装所需的 Python 包:
pip install -U langgraph langchain-openai pydantic requests typing-extensions -
设置环境变量:
运行脚本时,系统会提示您输入
OPENAI_API_KEY。如果启用了追踪功能(trace = True),还需要输入LANGSMITH_API_KEY。
以下是完整的 tot_game_of_24.py 文件内容:
# tot_game_of_24.py
import getpass
import os
import operator
from typing import List, Literal, Union, NamedTuple, Optional, Dict, Any
from typing_extensions import Annotated, TypedDict
import requests
import csv
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph
from langchain_core.runnables import RunnableConfig
from langgraph.constants import Send
from langgraph.checkpoint.memory import MemorySaver
# 设置环境变量的函数
def _set_env(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"{var}:")
# 设置 OPENAI_API_KEY
_set_env("OPENAI_API_KEY")
# 如果需要可视化算法
trace = True
if trace:
_set_env("LANGSMITH_API_KEY")
os.environ["LANGSMITH_PROJECT"] = "ToT Tutorial"
# 定义运算符类型和令牌类型
OperatorType = Literal["+", "-", "*", "/"]
TokenType = Union[float, OperatorType]
# 定义方程式的模型
class Equation(BaseModel):
"""结合提供的数字以达到 24 的公式。"""
tokens: List[TokenType] = Field(
description="逆波兰表示法的令牌栈。例如:[3, 4, '+', -1, '*'] 将计算为 (3 + 4) * -1 = -7。",
)
def compute(self) -> float:
op_funcs = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.truediv,
}
stack = []
for token in self.tokens:
if isinstance(token, float):
stack.append(token)
else:
if len(stack) < 2:
raise ValueError("栈中元素不足以进行运算。")
b, a = stack.pop(), stack.pop()
try:
result = op_funcs[token](a, b)
except ZeroDivisionError:
raise ValueError("除以零错误。")
stack.append(result)
if len(stack) != 1:
raise ValueError("最终栈中元素数量不正确。")
return stack[0]
# 定义猜测方程式的模型
class GuessEquations(BaseModel):
"""提交多个方程作为猜测。"""
reasoning: str = Field(
description="提交的猜测背后的推理。解释如何得出这些方程。",
)
equations: List[Equation] = Field(
description="要提交的方程列表。",
)
# 定义候选者模型
class Candidate(NamedTuple):
candidate: Equation
score: Optional[float] = None
feedback: Optional[str] = None
def __str__(self):
try:
computed = self.candidate.compute()
except Exception as e:
computed = f"无效的方程: {self.candidate.tokens}; 错误: {repr(e)}"
return f"Equation({self.candidate.tokens}) = {computed} (奖励: {self.score})"
# 定义评分后的候选者模型
class ScoredCandidate(Candidate):
candidate: Equation
score: float
feedback: str
# 获取数据函数
def fetch_puzzles() -> List[str]:
"""从指定的URL获取24点游戏的谜题数据。"""
csv_url = "https://storage.googleapis.com/benchmarks-artifacts/game-of-24/24.csv"
response = requests.get(csv_url)
if response.status_code != 200:
raise ValueError(f"无法获取数据,状态码:{response.status_code}")
csv_data = response.content.decode("utf-8")
# 获取“Puzzles”列(列索引为1)
puzzles = [row[1].strip() for row in csv.reader(csv_data.splitlines()[1:])]
return puzzles
# 调用获取数据函数并打印示例谜题
puzzles = fetch_puzzles()
print(f"示例谜题: {puzzles[:3]}")
# 定义扩展器(Expander)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"你正在玩24点游戏。使用提供的数字,创建一个等于24的方程式。\n"
"本轮请提交恰好{k}个猜测。",
),
("user", "为这些数字解决24点游戏: {problem}.{candidate}"),
],
).partial(candidate="")
llm = ChatOpenAI(model="gpt-4o-mini")
bound_llm = llm.with_structured_output(GuessEquations)
solver = prompt | bound_llm
# API参考: ChatPromptTemplate | ChatOpenAI
# 定义评分器(Scorer)
def compute_score(problem: str, candidate: Candidate) -> ScoredCandidate:
numbers = list(map(int, problem.split()))
# 检查候选方程是否使用了所有4个数字且每个数字恰好一次
used_numbers = [
token for token in candidate.candidate.tokens if isinstance(token, float)
]
if sorted(used_numbers) != sorted(numbers):
score = 0
feedback = "方程必须使用所有4个数字且每个数字恰好一次。"
return ScoredCandidate(
candidate=candidate.candidate, score=score, feedback=feedback
)
try:
result = candidate.candidate.compute()
score = 1 / (1 + abs(24 - result))
feedback = f"结果: {result}"
except Exception as e:
score = 0
feedback = f"无效的方程。错误: {repr(e)}"
return ScoredCandidate(
candidate=candidate.candidate, score=score, feedback=feedback
)
# 定义图的状态和配置
class ToTState(TypedDict):
problem: str
candidates: Annotated[List[Candidate], "update_candidates"]
scored_candidates: Annotated[List[ScoredCandidate], "update_candidates"]
depth: Annotated[int, "operator.add"]
class Configuration(TypedDict, total=False):
max_depth: int
threshold: float
k: int
beam_size: int
# 辅助函数:更新候选者列表
def update_candidates(
existing: Optional[List[Union[Candidate, ScoredCandidate]]] = None,
updates: Optional[Union[List[Union[Candidate, ScoredCandidate]], Literal["clear"]]] = None,
) -> List[Union[Candidate, ScoredCandidate]]:
if existing is None:
existing = []
if updates is None:
return existing
if updates == "clear":
return []
# 连接现有列表和更新列表
return existing + updates
# 辅助函数:确保配置参数
def _ensure_configurable(config: RunnableConfig) -> Configuration:
"""获取配置搜索算法的参数。"""
configurable = config.get("configurable", {})
return {
**configurable,
"max_depth": configurable.get("max_depth", 10),
"threshold": configurable.get("threshold", 0.9),
"k": configurable.get("k", 5),
"beam_size": configurable.get("beam_size", 3),
}
# 定义扩展状态
class ExpansionState(ToTState):
seed: Optional[Candidate]
# 定义扩展函数
def expand(state: ExpansionState, *, config: RunnableConfig) -> Dict[str, List[Candidate]]:
"""生成下一个状态。"""
configurable = _ensure_configurable(config)
if not state.get("seed"):
candidate_str = ""
else:
candidate_str = "\n\n" + str(state["seed"])
try:
equation_submission = solver.invoke(
{
"problem": state["problem"],
"candidate": candidate_str,
"k": configurable["k"],
},
config=config,
)
except Exception as e:
print(f"扩展过程中出现错误: {e}")
return {"candidates": []}
new_candidates = [
Candidate(candidate=equation) for equation in equation_submission.equations
]
return {"candidates": new_candidates}
# 定义评分函数
def score(state: ToTState) -> Dict[str, List[ScoredCandidate]]:
"""评估候选生成。"""
candidates = state["candidates"]
scored = []
for candidate in candidates:
scored_candidate = compute_score(state["problem"], candidate)
scored.append(scored_candidate)
return {"scored_candidates": scored, "candidates": "clear"}
# 定义修剪函数
def prune(
state: ToTState, *, config: RunnableConfig
) -> Dict[str, List[Dict[str, Any]]]:
"""修剪候选者列表,保留评分最高的beam_size个。"""
scored_candidates = state["scored_candidates"]
beam_size = _ensure_configurable(config)["beam_size"]
organized = sorted(
scored_candidates, key=lambda candidate: candidate.score, reverse=True
)
pruned = organized[:beam_size]
return {
# 更新下一次迭代的起点
"candidates": pruned,
# 清除旧的评分结果
"scored_candidates": "clear",
# 深度增加 1
"depth": 1,
}
# 定义终止条件函数
def should_terminate(
state: ToTState, config: RunnableConfig
) -> Union[Literal["__end__"], Send]:
"""判断是否满足终止条件。"""
configurable = _ensure_configurable(config)
if not state["candidates"]:
return "__end__"
solved = state["candidates"][0].score >= configurable["threshold"]
if solved or state["depth"] >= configurable["max_depth"]:
return "__end__"
return [
Send("expand", {**state, "seed": candidate})
for candidate in state["candidates"]
]
# 创建图构建器
builder = StateGraph(state_schema=ToTState, config_schema=Configuration)
# 添加节点
builder.add_node(expand)
builder.add_node(score)
builder.add_node(prune)
# 添加边
builder.add_edge("expand", "score")
builder.add_edge("score", "prune")
builder.add_conditional_edges("prune", should_terminate, path_map=["expand", "__end__"])
# 设置入口点
builder.add_edge("__start__", "expand")
# 编译图
graph = builder.compile(checkpointer=MemorySaver())
# API参考: RunnableConfig | StateGraph | Send | MemorySaver
# 可视化图(需要在支持的环境中运行,如Jupyter Notebook)
# from IPython.display import Image, display
# display(Image(graph.get_graph().draw_mermaid_png()))
# 运行图
def run_tot(puzzle_index: int):
"""在指定的谜题上运行ToT算法。"""
if puzzle_index >= len(puzzles):
print(f"索引 {puzzle_index} 超出谜题列表范围。")
return
config = {
"configurable": {
"thread_id": "test_1",
"max_depth": 10,
}
}
initial_state = {"problem": puzzles[puzzle_index]}
print(f"正在解决的谜题: {initial_state['problem']}")
try:
for step in graph.stream(initial_state, config):
print(step)
except Exception as e:
print(f"运行过程中出现错误: {e}")
final_state = graph.get_state(config)
if not final_state:
print("未能获取最终状态。")
return
if "candidates" not in final_state.values:
print("最终状态中缺少候选者信息。")
return
if not final_state.values["candidates"]:
print("未找到任何候选解决方案。")
return
winning_solution = final_state.values["candidates"][0]
search_depth = final_state.values["depth"]
if winning_solution.score == 1:
print(f"在 {search_depth} 步内找到一个获胜的解决方案: {winning_solution}")
else:
print(
f"在 {search_depth} 步内未能找到获胜的解决方案。最佳猜测: {winning_solution}"
)
# 示例运行
if __name__ == "__main__":
# 选择要解决的谜题索引,例如第43个谜题(索引42)
puzzle_index = 42
run_tot(puzzle_index)
代码说明
-
环境设置:
- 安装包:确保已安装
langgraph,langchain-openai,pydantic,requests, 和typing-extensions。 - 环境变量:脚本会提示您输入
OPENAI_API_KEY,如果启用了追踪功能,还需要输入LANGSMITH_API_KEY。这些密钥将被安全地存储为环境变量。
- 安装包:确保已安装
-
数据模型:
Equation:表示一个数学方程,使用逆波兰表示法(后缀表示法)存储令牌,并提供计算方法。GuessEquations:用于提交多个方程作为猜测,包括推理说明。Candidate和ScoredCandidate:表示候选解决方案及其评分和反馈。
-
数据获取:
fetch_puzzles()函数从指定的 URL 下载 CSV 数据,并提取“Puzzles”列中的谜题。
-
扩展器(Expander):
- 使用
ChatPromptTemplate定义与 LLM 的交互模板,指示其生成等于24的方程式。 - 使用
ChatOpenAI的 GPT-4 模型(gpt-4o-mini)作为语言模型,并绑定结构化输出以符合GuessEquations模型。
- 使用
-
评分器(Scorer):
compute_score()函数验证候选方程是否使用了所有提供的数字,并计算结果与24的接近程度。评分越高,候选方案越接近正确答案。
-
图的构建(Graph Building):
- 定义图的状态 (
ToTState) 和配置 (Configuration)。 - 定义
expand,score, 和prune函数,分别用于生成候选方案、评分候选方案和修剪候选方案列表。 - 使用
StateGraph创建图,添加节点和边,并设置入口点。 - 编译图以生成可执行的
graph对象。
- 定义图的状态 (
-
运行图(Running the Graph):
run_tot()函数接受一个谜题索引,运行 ToT 算法解决该谜题,并打印每一步的状态。- 最终输出是否找到满足条件的解决方案。
执行脚本
-
保存脚本:
将上述代码保存为
tot_game_of_24.py。 -
运行脚本:
在终端或命令提示符中,导航到脚本所在的目录,并运行:
python tot_game_of_24.py脚本将提示您输入
OPENAI_API_KEY,并在启用了追踪功能的情况下,提示输入LANGSMITH_API_KEY。 -
查看结果:
脚本将打印解决过程中的每一步状态,并在最终输出中指示是否找到一个获胜的解决方案。
注意事项
- 模型名称:请确保您拥有访问
gpt-4o-mini模型的权限。如果没有,您可能需要更改为您拥有访问权限的其他模型名称,例如gpt-4。 - 环境支持:如果您希望可视化图结构,可以在支持的环境(如 Jupyter Notebook)中取消注释相关代码部分。
- 错误处理:脚本包含基本的错误处理,但在实际应用中,您可能需要根据具体情况添加更多的异常处理逻辑。
通过运行此脚本,您可以全面理解并体验 LangGraph 和 ToT 算法如何协同工作,以解决“24 点游戏”这一具体任务。如果有任何问题或需要进一步的帮助,请随时提问!
2573

被折叠的 条评论
为什么被折叠?



