概述
Language Agent Tree Search (LATS) 是由Zhou等人提出的一种通用的基于大型语言模型(LLM)的代理搜索算法。LATS结合了反思/评估和搜索(特别是蒙特卡洛树搜索),以实现比类似技术(如ReACT、Reflexion或Tree of Thoughts)更好的整体任务性能。

LATS的四个主要步骤
- 选择(Select):基于步骤2的累计奖励选择最佳的下一步行动。如果找到了解决方案或达到最大搜索深度,则响应;否则继续搜索。
- 扩展和模拟(Expand and Simulate):选择“最佳”的5个潜在行动并并行执行。
- 反思与评估(Reflect + Evaluate):观察这些行动的结果,并基于反思(可能包括外部反馈)对决策进行评分。
- 回传(Backpropagate):根据结果更新根轨迹的分数。
环境搭建
首先,我们需要安装必要的库和工具,包括LangGraph框架、LangChain OpenAI接口,以及用于搜索引擎的Tavily Python。
安装依赖
%%capture --no-stderr
%pip install -U --quiet langchain langgraph langchain_openai
%pip install -U --quiet tavily-python
配置API密钥
我们需要设置OpenAI和Tavily的API密钥。以下代码会检查环境变量中是否已定义这些密钥,如果未定义,则提示输入。
import getpass
import os
def _set_if_undefined(var: str) -> None:
if os.environ.get(var):
return
os.environ[var] = getpass.getpass(var)
_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")
设置LangSmith
LangSmith用于LangGraph开发,可以帮助快速发现问题并提升LangGraph项目的性能。通过跟踪数据,可以调试、测试和监控使用LangGraph构建的LLM应用。
图状态(Graph State)
LATS基于一种(贪心的)蒙特卡洛树搜索。每个搜索步骤中,它选择具有最高“上置信界”(Upper Confidence Bound,UCB)的节点,该指标平衡了开发(最高平均奖励)和探索(最低访问次数)。从该节点开始,生成N(此处为5)个新的候选行动,并将其添加到树中。搜索在生成有效解决方案或达到最大回滚(搜索树深度)时停止。
树的组成
我们的LangGraph状态将由两个部分组成:
- 搜索树的根节点(root)
- 用户输入(input)
代码实现
定义反思模型
反思用于对代理的输出进行评分。
from pydantic import BaseModel, Field
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
class Reflection(BaseModel):
reflections: str = Field(
description="对回应的充分性、冗余性和整体质量的批评和反思"
)
score: int = Field(
description="候选回应质量的评分,范围从0到10。",
ge=0,
le=10,
)
found_solution: bool = Field(
description="回应是否完全解决了问题或任务。"
)
def as_message(self):
return HumanMessage(
content=f"推理: {self.reflections}\n评分: {self.score}"
)
@property
def normalized_score(self) -> float:
return self.score / 10.0
定义树节点
每个节点代表搜索树中的一个状态,包括消息、反思、价值、访问次数等信息。
import math
from collections import deque
from typing import Optional
class Node:
def __init__(
self,
messages: list[BaseMessage],
reflection: Reflection,
parent: Optional["Node"] = None,
):
self.messages = messages
self.parent = parent
self.children = []
self.value = 0
self.visits = 0
self.reflection = reflection
self.depth = parent.depth + 1 if parent is not None else 1
self._is_solved = reflection.found_solution if reflection else False
if self._is_solved:
self._mark_tree_as_solved()
self.backpropagate(reflection.normalized_score)
def __repr__(self) -> str:
return (
f"<Node value={self.value}, visits={self.visits},"
f" solution={self.messages} reflection={self.reflection}/>"
)
@property
def is_solved(self):
"""如果存在任何解决方案,我们可以结束搜索。"""
return self._is_solved
@property
def is_terminal(self):
return not self.children
@property
def best_child_score(self):
"""返回具有最高价值的子节点。"""
if not self.children:
return None
return max(self.children, key=lambda child: int(child.is_solved) * child.value)
@property
def height(self) -> int:
"""检查我们已经展开了多深的树。"""
if self.children:
return 1 + max([child.height for child in self.children])
return 1
def upper_confidence_bound(self, exploration_weight=1.0):
"""返回UCT评分。这有助于平衡分支的探索与开发。"""
if self.parent is None:
raise ValueError("无法从根节点获取UCT评分")
if self.visits == 0:
return self.value
# 鼓励对高价值轨迹的开发
average_reward = self.value / self.visits
# 鼓励对访问次数较少的轨迹进行探索
exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
return average_reward + exploration_weight * exploration_term
def backpropagate(self, reward: float):
"""更新此节点及其父节点的分数。"""
node = self
while node:
node.visits += 1
node.value = (node.value * (node.visits - 1) + reward) / node.visits
node = node.parent
def get_messages(self, include_reflections: bool = True):
if include_reflections:
return self.messages + [self.reflection.as_message()]
return self.messages
def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:
"""获取表示此搜索分支的消息。"""
messages = []
node = self
while node:
messages.extend(
node.get_messages(include_reflections=include_reflections)[::-1]
)
node = node.parent
# 反转最终回溯的轨迹,以正确的顺序返回
return messages[::-1] # 根解决方案,反思,子节点1,...
def _get_all_children(self):
all_nodes = []
nodes = deque()
nodes.append(self)
while nodes:
node = nodes.popleft()
all_nodes.extend(node.children)
for n in node.children:
nodes.append(n)
return all_nodes
def get_best_solution(self):
"""从当前子树中返回最佳解决方案。"""
all_nodes = [self] + self._get_all_children()
best_node = max(
all_nodes,
# 过滤掉所有非终端、非解决方案的轨迹
key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
)
return best_node
def _mark_tree_as_solved(self):
parent = self.parent
while parent:
parent._is_solved = True
parent = parent.parent
定义树状态类型
from typing_extensions import TypedDict
class TreeState(TypedDict):
# 完整的树
root: Node
# 原始输入
input: str
定义语言代理(Language Agent)
我们的代理将包含三个主要的LLM驱动流程:
- 反思(Reflect):基于工具的响应对行动进行评分。
- 初始响应(Initial Response):创建根节点并启动搜索。
- 扩展(Expand):从当前树的最佳位置生成5个候选“下一步”。
初始化LLM
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o")
工具(Tools)
在这个示例中,我们将为语言代理提供一个搜索引擎工具,使用Tavily搜索作为工具。
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langgraph.prebuilt import ToolNode
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
tools = [tavily_tool]
tool_node = ToolNode(tools=tools)
反思(Reflection)
反思链将基于决策和工具响应对代理输出进行评分。我们将在其他两个节点中调用它。
定义反思链
from langchain_core.output_parsers.openai_tools import (
JsonOutputToolsParser,
PydanticToolsParser,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import chain as as_runnable
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"反思并评分以下用户问题的助手回应。",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="candidate"),
]
)
reflection_llm_chain = (
prompt
| llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(
run_name="Reflection"
)
| PydanticToolsParser(tools=[Reflection])
)
@as_runnable
def reflection_chain(inputs) -> Reflection:
tool_choices = reflection_llm_chain.invoke(inputs)
reflection = tool_choices[0]
if not isinstance(inputs["candidate"][-1], AIMessage):
reflection.found_solution = False
return reflection
初始响应(Initial Response)
我们从单一的根节点开始,该节点由此步骤生成。它根据用户输入以工具调用或响应进行回应。
定义初始响应链
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import RunnableConfig
prompt_template = ChatPromptTemplate.from_messages(
[
(
"system",
"你是一个AI助手。",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="messages", optional=True),
]
)
initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(
run_name="GenerateInitialCandidate"
)
parser = JsonOutputToolsParser(return_id=True)
生成初始响应
initial_response = initial_answer_chain.invoke(
{"input": "编写一份关于锂污染的研究报告。"}
)
initial_response
起始节点(Starting Node)
将候选生成和反思封装到图的单个节点中。
定义生成初始响应的函数
def generate_initial_response(state: TreeState) -> dict:
"""生成初始候选响应。"""
res = initial_answer_chain.invoke({"input": state["input"]})
parsed = parser.invoke(res)
tool_responses = [
tool_node.invoke(
{
"messages": [
AIMessage(
content="",
tool_calls=[
{"name": r["type"], "args": r["args"], "id": r["id"]}
],
)
]
}
)
for r in parsed
]
output_messages = [res] + [tr["messages"][0] for tr in tool_responses]
reflection = reflection_chain.invoke(
{"input": state["input"], "candidate": output_messages}
)
root = Node(output_messages, reflection=reflection)
return {
**state,
"root": root,
}
候选生成(Candidate Generation)
生成N个额外的候选以进行检查。
定义候选生成函数
# 生成N个候选值
# 用于从环境中采样行动
def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
n = config["configurable"].get("N", 5)
bound_kwargs = llm.bind_tools(tools=tools).kwargs
chat_result = llm.generate(
[messages.to_messages()],
n=n,
callbacks=config["callbacks"],
run_name="GenerateCandidates",
**bound_kwargs,
)
return [gen.message for gen in chat_result.generations[0]]
expansion_chain = prompt_template | generate_candidates
执行候选生成
res = expansion_chain.invoke({"input": "编写一份关于锂污染的研究报告。"})
res
候选生成节点(Candidate Generation Node)
将候选生成和反思步骤封装到“扩展”节点中。我们将所有操作作为批处理来加速执行。
定义选择和扩展函数
from collections import defaultdict
def select(root: Node) -> dict:
"""从根节点开始,每个树层选择一个子节点,直到达到叶节点。"""
if not root.children:
return root
node = root
while node.children:
max_child = max(node.children, key=lambda child: child.upper_confidence_bound())
node = max_child
return node
def expand(state: TreeState, config: RunnableConfig) -> dict:
"""从树中“最佳”节点开始,生成N个下一步的候选。"""
root = state["root"]
best_candidate: Node = select(root)
messages = best_candidate.get_trajectory()
# 从单个子候选生成N个候选
new_candidates = expansion_chain.invoke(
{"input": state["input"], "messages": messages}, config
)
parsed = parser.batch(new_candidates)
flattened = [
(i, tool_call)
for i, tool_calls in enumerate(parsed)
for tool_call in tool_calls
]
tool_responses = [
(
i,
tool_node.invoke(
{
"messages": [
AIMessage(
content="",
tool_calls=[
{
"name": tool_call["type"],
"args": tool_call["args"],
"id": tool_call["id"],
}
],
)
]
}
),
)
for i, tool_call in flattened
]
collected_responses = defaultdict(list)
for i, resp in tool_responses:
collected_responses[i].append(resp["messages"][0])
output_messages = []
for i, candidate in enumerate(new_candidates):
output_messages.append([candidate] + collected_responses[i])
# 对每个候选进行反思
# 对于有外部验证的任务,您可以在此处添加
reflections = reflection_chain.batch(
[{"input": state["input"], "candidate": msges} for msges in output_messages],
config,
)
# 扩展树
child_nodes = [
Node(cand, parent=best_candidate, reflection=reflection)
for cand, reflection in zip(output_messages, reflections)
]
best_candidate.children.extend(child_nodes)
# 我们已经直接扩展了树,因此只需返回状态
return state
创建图(Create Graph)
定义图的结构,包括节点和边。每个代理步骤后,我们可以选择是否结束搜索。
定义图的条件边
from typing import Literal
from langgraph.graph import END, StateGraph, START
def should_loop(state: TreeState):
"""确定是否继续树搜索。"""
root = state["root"]
if root.is_solved:
return END
if root.height > 5:
return END
return "expand"
builder = StateGraph(TreeState)
builder.add_node("start", generate_initial_response)
builder.add_node("expand", expand)
builder.add_edge(START, "start")
builder.add_conditional_edges(
"start",
# 要么扩展/回滚,要么结束
should_loop,
["expand", END],
)
builder.add_conditional_edges(
"expand",
# 要么继续回滚,要么结束
should_loop,
["expand", END],
)
graph = builder.compile()
可视化图结构
from IPython.display import Image
Image(graph.get_graph().draw_mermaid_png())
调用图(Invoke)
使用定义好的图来处理具体的问题。
示例1:生成鸟类信息表
question = "生成一个表格,列出前5种最常见鸟类的平均大小、重量以及最古老的记录实例。"
last_step = None
for step in graph.stream({"input": question}):
last_step = step
step_name, step_state = next(iter(step.items()))
print(step_name)
print("展开高度: ", step_state["root"].height)
print("---")
输出示例
start
展开高度: 1
---
expand
展开高度: 2
---
获取并显示最佳解决方案
solution_node = last_step["expand"]["root"].get_best_solution()
best_trajectory = solution_node.get_trajectory(include_reflections=False)
print(best_trajectory[-1].content)
示例输出
让我们将信息综合成一个连贯的表格,总结前5种最常见鸟类的平均大小、重量以及最古老的记录实例。
...
### Top 5 Most Common Birds
根据搜索结果,前5种最常见的鸟类是:
1. 家鸡
2. 家麻雀
3. 欧洲椋鸟
4. 环颈鸥
5. 燕子
### 表格:平均大小、重量及最古老的记录实例
| 鸟类 | 平均大小 (cm) | 平均重量 (g) | 最古老的记录实例 |
|--------------------|---------------|--------------|-------------------------|
| 家鸡 | 40-50 | 1,200-2,500 | ~16年(宠物记录) |
| 家麻雀 | 14-18 | 24-40 | 13年 |
| 欧洲椋鸟 | 20-23 | 58-100 | 15年 |
| 环颈鸥 | 48-53 | 300-700 | 23年 |
| 燕子 | 15-20 | 17-20 | 16年 |
### 其他详情
- **家鸡**:根据品种和饮食习惯,平均大小和重量可能有显著差异。记录中最年长的宠物鸡活到了16岁。
- **家麻雀**:常见于城市地区,野外的平均寿命显著较短。
- **欧洲椋鸟**:以其适应能力著称,椋鸟在没有天敌或恶劣条件下具有显著的寿命。
- **环颈鸥**:这些鸥类在北美常见,寿命相对较长。
- **燕子**:以其迁徙习性闻名,考虑到其体型,这些鸟类具有相对较高的寿命。
这个表格现在提供了一个结构化且全面的总结,涵盖了前5种最常见鸟类的平均大小、重量以及最古老的记录实例。
示例2:棋局分析
question = "写出Magnus Carlsen在对阵Alireza Firouzja的比赛中的一系列走子,并提出一个替代策略。"
last_step = None
for step in graph.stream({"input": question}):
last_step = step
step_name, step_state = next(iter(step.items()))
print(step_name)
print("展开高度: ", step_state["root"].height)
print("---")
输出示例
start
展开高度: 1
---
expand
展开高度: 2
---
expand
展开高度: 3
---
expand
展开高度: 3
---
expand
展开高度: 3
---
获取并显示最佳解决方案
solution_node = last_step["expand"]["root"].get_best_solution()
best_trajectory = solution_node.get_trajectory(include_reflections=False)
print(best_trajectory[-1].content)
示例输出
看起来Magnus Carlsen和Alireza Firouzja之间具体的比赛走子在搜索结果中并不容易获取。然而,我可以提供一个典型的高水平选手如Carlsen和Firouzja之间比赛的走子示例,并基于常见的国际象棋原则提出一个替代策略。
### 示例比赛走子(假设性)
以下是Magnus Carlsen和Alireza Firouzja之间一场比赛的假设性走子序列:
1. e4 e5
2. Nf3 Nc6
3. Bb5 a6
4. Ba4 Nf6
5. O-O Be7
6. Re1 b5
7. Bb3 d6
8. c3 O-O
9. h3 Nb8
10. d4 Nbd7
11. Nbd2 Bb7
12. Bc2 Re8
13. Nf1 Bf8
14. Ng3 g6
15. a4 c5
16. d5 c4
17. Be3 Qc7
18. Qd2 Nc5
19. Nh2 Bg7
20. Ng4 Nxg4
21. hxg4 Qd7
22. f3 f6
23. Kf2 Qf7
24. Rh1 Rad8
25. Rh3 Bc8
26. Rah1 h6
27. Bxh6 Bxh6
28. Rxh6 Qg7
29. g5 f5
30. exf5 Bxf5
31. Bxf5 gxf5
32. Nh5 Qf7
33. Nf6+ Kf8
34. Rh8+ Ke7
35. Rxe8+ Rxe8
36. Nxe8 Qxe8
37. Rh7+ Kd8
38. g6 Qg8
39. Qg5+ Kc8
40. Qe7 Qd8
41. Qxd8+ Kxd8
42. g7 Kc7
43. g8=Q+ Kb6
44. Qb8+ Ka5
45. Qd8+ Kxa4
46. g4 fxg4
47. fxg4 Kb3
48. g5 Kxb2
49. Qb6 Kxc3
50. Qxc5 dxc5
51. d6 b4
52. d7 b3
53. d8=Q b2
54. Qd1 b1=Q
55. Rxb1 Kxc4
56. Qc1+ Kd5
57. Qxc3 c4
58. Ke3 Kc6
59. Kd4 Kc7
60. Qxc4+ Kd6
61. Qc5+ Ke6
62. Rb6+ Kf7
63. Qc7+ Ke8
64. Rb8#
### 替代策略
如果考虑Magnus Carlsen使用白棋并采用典型的西班牙开局,可以考虑使用不同的开局或在西班牙开局内的变体。例如:
1. **替代开局:意大利开局**
- 1. e4 e5
- 2. Nf3 Nc6
- 3. Bc4 Bc5
- 4. c3 Nf6
- 5. d4 exd4
- 6. cxd4 Bb4+
- 7. Nc3 Nxe4
- 8. O-O Bxc3
- 9. d5 Ne7
- 10. Qd3 f5
- 11. bxc3 d6
- 12. Nd4 O-O
- 13. f3 Nc5
- 14. Qc2 f4
- 15. Re1 Ng6
- 16. Ba3 Qg5
- 17. Bxc5 dxc5
- 18. Ne6 Bxe6
- 19. dxe6 Ne7
- 20. Rad1 Rad8
- 21. Rd7 Rxd7
- 22. exd7+ Kh8
- 23. Qe4 Nc6
- 24. Bd3 g6
- 25. Qe8 Kg7
- 26. Bb5 Nd8
- 27. Re7+ Kh6
- 28. Qxf8+ Kh5
- 29. Rxh7#
2. **西班牙开局的变体:**
- 白棋可以选择“舒适变体”或“延迟斯坦因防御”。
- 例如,在初始走子后:
- 1. e4 e5
- 2. Nf3 Nc6
- 3. Bb5 a6
- 4. Ba4 d6(延迟斯坦因防御)
- 5. c3 Bg4
- 6. h3 Bh5
- 7. d4 exd4
- 8. cxd4 Be7
- 9. Nc3 Nf6
- 10. O-O O-O
通过改变开局或在特定开局内的策略,Carlsen可以潜在地避免Firouzja的深度准备,并将比赛引入对手不太熟悉的领域。
结论
恭喜您成功实现了LATS!这是一种在解决复杂推理任务时相对快速且有效的技术。以下是一些注意事项:
- 树的展开:虽然有效,但树的展开可能需要额外的计算时间。如果您想将其用于生产应用,建议确保中间步骤能够流式传输(使用户能够看到思考过程或访问中间结果),或者将其用于微调数据以提高单次响应的准确性,避免长时间展开。
- 候选选择过程:候选选择的质量取决于生成的奖励。在此示例中,我们仅使用自我反思,但如果您有外部反馈来源(如代码测试执行),应将其整合到上述步骤中。
通过上述步骤,您可以使用LangGraph实现LATS,并应用于各种复杂任务中,提升LLM代理的性能和准确性。
汇总
当然,我会将上述内容中的所有相关代码汇总到一个完整的 lats_example.py 文件中,并确保所有注释和提示词均已翻译为中文。这将有助于您更好地理解和执行整个示例。
以下是完整的 lats_example.py 文件内容:
# lats_example.py
import math
import os
import getpass
from collections import deque, defaultdict
from typing import Optional, List, Dict
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langgraph.prebuilt import ToolNode
from langchain_core.output_parsers.openai_tools import (
JsonOutputToolsParser,
PydanticToolsParser,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import chain as as_runnable
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph, START
# 设置API密钥
def set_api_keys():
"""设置OpenAI和Tavily的API密钥。如果环境变量中未定义,则提示用户输入。"""
if not os.environ.get("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = getpass.getpass("请输入您的 OPENAI_API_KEY: ")
if not os.environ.get("TAVILY_API_KEY"):
os.environ["TAVILY_API_KEY"] = getpass.getpass("请输入您的 TAVILY_API_KEY: ")
set_api_keys()
# 定义反思模型
class Reflection(BaseModel):
reflections: str = Field(
description="对回应的充分性、冗余性和整体质量的批评和反思"
)
score: int = Field(
description="候选回应质量的评分,范围从0到10。",
ge=0,
le=10,
)
found_solution: bool = Field(
description="回应是否完全解决了问题或任务。"
)
def as_message(self):
return HumanMessage(
content=f"推理: {self.reflections}\n评分: {self.score}"
)
@property
def normalized_score(self) -> float:
return self.score / 10.0
# 定义树节点
class Node:
def __init__(
self,
messages: List[BaseMessage],
reflection: Reflection,
parent: Optional["Node"] = None,
):
self.messages = messages
self.parent = parent
self.children: List["Node"] = []
self.value = 0
self.visits = 0
self.reflection = reflection
self.depth = parent.depth + 1 if parent is not None else 1
self._is_solved = reflection.found_solution if reflection else False
if self._is_solved:
self._mark_tree_as_solved()
self.backpropagate(reflection.normalized_score)
def __repr__(self) -> str:
return (
f"<Node value={self.value}, visits={self.visits},"
f" solution={self.messages} reflection={self.reflection}/>"
)
@property
def is_solved(self):
"""如果存在任何解决方案,我们可以结束搜索。"""
return self._is_solved
@property
def is_terminal(self):
return not self.children
@property
def best_child_score(self):
"""返回具有最高价值的子节点。"""
if not self.children:
return None
return max(self.children, key=lambda child: int(child.is_solved) * child.value)
@property
def height(self) -> int:
"""检查我们已经展开了多深的树。"""
if self.children:
return 1 + max([child.height for child in self.children])
return 1
def upper_confidence_bound(self, exploration_weight=1.0):
"""返回UCT评分。这有助于平衡分支的探索与开发。"""
if self.parent is None:
raise ValueError("无法从根节点获取UCT评分")
if self.visits == 0:
return self.value
# 鼓励对高价值轨迹的开发
average_reward = self.value / self.visits
# 鼓励对访问次数较少的轨迹进行探索
exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
return average_reward + exploration_weight * exploration_term
def backpropagate(self, reward: float):
"""更新此节点及其父节点的分数。"""
node = self
while node:
node.visits += 1
node.value = (node.value * (node.visits - 1) + reward) / node.visits
node = node.parent
def get_messages(self, include_reflections: bool = True):
if include_reflections:
return self.messages + [self.reflection.as_message()]
return self.messages
def get_trajectory(self, include_reflections: bool = True) -> List[BaseMessage]:
"""获取表示此搜索分支的消息。"""
messages = []
node = self
while node:
messages.extend(
node.get_messages(include_reflections=include_reflections)[::-1]
)
node = node.parent
# 反转最终回溯的轨迹,以正确的顺序返回
return messages[::-1] # 根解决方案,反思,子节点1,...
def _get_all_children(self):
all_nodes = []
nodes = deque()
nodes.append(self)
while nodes:
node = nodes.popleft()
all_nodes.extend(node.children)
for n in node.children:
nodes.append(n)
return all_nodes
def get_best_solution(self):
"""从当前子树中返回最佳解决方案。"""
all_nodes = [self] + self._get_all_children()
best_node = max(
all_nodes,
# 过滤掉所有非终端、非解决方案的轨迹
key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
)
return best_node
def _mark_tree_as_solved(self):
parent = self.parent
while parent:
parent._is_solved = True
parent = parent.parent
# 定义树状态类型
class TreeState(TypedDict):
# 完整的树
root: Node
# 原始输入
input: str
# 初始化LLM
llm = ChatOpenAI(model="gpt-4o")
# 定义工具
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
tools = [tavily_tool]
tool_node = ToolNode(tools=tools)
# 定义反思链
prompt_reflection = ChatPromptTemplate.from_messages(
[
(
"system",
"反思并评分以下用户问题的助手回应。",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="candidate"),
]
)
reflection_llm_chain = (
prompt_reflection
| llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(
run_name="Reflection"
)
| PydanticToolsParser(tools=[Reflection])
)
@as_runnable
def reflection_chain(inputs) -> Reflection:
tool_choices = reflection_llm_chain.invoke(inputs)
reflection = tool_choices[0]
if not isinstance(inputs["candidate"][-1], AIMessage):
reflection.found_solution = False
return reflection
# 定义初始响应链
prompt_initial = ChatPromptTemplate.from_messages(
[
(
"system",
"你是一个AI助手。",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="messages", optional=True),
]
)
initial_answer_chain = prompt_initial | llm.bind_tools(tools=tools).with_config(
run_name="GenerateInitialCandidate"
)
parser = JsonOutputToolsParser(return_id=True)
# 生成初始响应的函数
def generate_initial_response(state: TreeState) -> Dict:
"""生成初始候选响应。"""
res = initial_answer_chain.invoke({"input": state["input"]})
parsed = parser.invoke(res)
tool_responses = [
tool_node.invoke(
{
"messages": [
AIMessage(
content="",
tool_calls=[
{"name": r["type"], "args": r["args"], "id": r["id"]}
],
)
]
}
)
for r in parsed
]
output_messages = [res] + [tr["messages"][0] for tr in tool_responses]
reflection = reflection_chain.invoke(
{"input": state["input"], "candidate": output_messages}
)
root = Node(output_messages, reflection=reflection)
return {
**state,
"root": root,
}
# 定义候选生成函数
def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
"""生成N个候选值,用于从环境中采样行动。"""
n = config["configurable"].get("N", 5)
bound_kwargs = llm.bind_tools(tools=tools).kwargs
chat_result = llm.generate(
[messages.to_messages()],
n=n,
callbacks=config["callbacks"],
run_name="GenerateCandidates",
**bound_kwargs,
)
return [gen.message for gen in chat_result.generations[0]]
expansion_chain = prompt_initial | generate_candidates
# 批量定义扩展节点
def expand(state: TreeState, config: RunnableConfig) -> Dict:
"""从树中“最佳”节点开始,生成N个下一步的候选。"""
root = state["root"]
best_candidate: Node = select(root)
messages = best_candidate.get_trajectory()
# 从单个子候选生成N个候选
new_candidates = expansion_chain.invoke(
{"input": state["input"], "messages": messages}, config
)
parsed = parser.batch(new_candidates)
flattened = [
(i, tool_call)
for i, tool_calls in enumerate(parsed)
for tool_call in tool_calls
]
tool_responses = [
(
i,
tool_node.invoke(
{
"messages": [
AIMessage(
content="",
tool_calls=[
{
"name": tool_call["type"],
"args": tool_call["args"],
"id": tool_call["id"],
}
],
)
]
}
),
)
for i, tool_call in flattened
]
collected_responses = defaultdict(list)
for i, resp in tool_responses:
collected_responses[i].append(resp["messages"][0])
output_messages = []
for i, candidate in enumerate(new_candidates):
output_messages.append([candidate] + collected_responses[i])
# 对每个候选进行反思
# 对于有外部验证的任务,您可以在此处添加
reflections = reflection_chain.batch(
[{"input": state["input"], "candidate": msges} for msges in output_messages],
config,
)
# 扩展树
child_nodes = [
Node(cand, parent=best_candidate, reflection=reflection)
for cand, reflection in zip(output_messages, reflections)
]
best_candidate.children.extend(child_nodes)
# 我们已经直接扩展了树,因此只需返回状态
return state
# 定义选择函数
def select(root: Node) -> Node:
"""从根节点开始,每个树层选择一个子节点,直到达到叶节点。"""
if not root.children:
return root
node = root
while node.children:
max_child = max(node.children, key=lambda child: child.upper_confidence_bound())
node = max_child
return node
# 创建图
def create_graph() -> StateGraph:
"""创建LangGraph图结构。"""
def should_loop(state: TreeState):
"""确定是否继续树搜索。"""
root = state["root"]
if root.is_solved:
return END
if root.height > 5:
return END
return "expand"
builder = StateGraph(TreeState)
builder.add_node("start", generate_initial_response)
builder.add_node("expand", expand)
builder.add_edge(START, "start")
builder.add_conditional_edges(
"start",
# 要么扩展/回滚,要么结束
should_loop,
["expand", END],
)
builder.add_conditional_edges(
"expand",
# 要么继续回滚,要么结束
should_loop,
["expand", END],
)
graph = builder.compile()
return graph
# 主函数
def main():
"""主函数,执行LATS示例。"""
graph = create_graph()
# 示例1:生成鸟类信息表
question1 = "生成一个表格,列出前5种最常见鸟类的平均大小、重量以及最古老的记录实例。"
last_step1 = None
print("### 示例1:生成鸟类信息表 ###")
for step in graph.stream({"input": question1}):
last_step1 = step
step_name, step_state = next(iter(step.items()))
print(step_name)
print("展开高度: ", step_state["root"].height)
print("---")
solution_node1 = last_step1["expand"]["root"].get_best_solution()
best_trajectory1 = solution_node1.get_trajectory(include_reflections=False)
print(best_trajectory1[-1].content)
print("\n")
# 示例2:棋局分析
question2 = "写出Magnus Carlsen在对阵Alireza Firouzja的比赛中的一系列走子,并提出一个替代策略。"
last_step2 = None
print("### 示例2:棋局分析 ###")
for step in graph.stream({"input": question2}):
last_step2 = step
step_name, step_state = next(iter(step.items()))
print(step_name)
print("展开高度: ", step_state["root"].height)
print("---")
solution_node2 = last_step2["expand"]["root"].get_best_solution()
best_trajectory2 = solution_node2.get_trajectory(include_reflections=False)
print(best_trajectory2[-1].content)
print("\n")
if __name__ == "__main__":
main()
说明与执行步骤
-
依赖安装
在运行
lats_example.py之前,请确保已安装所有必要的依赖。您可以在终端或命令提示符中运行以下命令来安装所需的库:pip install -U langchain langgraph langchain_openai tavily-python -
配置API密钥
确保您拥有OpenAI和Tavily的API密钥。运行脚本时,如果环境变量中未定义这些密钥,脚本会提示您输入它们。
-
运行脚本
在终端或命令提示符中,导航到包含
lats_example.py文件的目录,然后运行:python lats_example.py -
脚本功能
-
示例1:生成鸟类信息表
脚本将根据您的输入生成一个表格,列出前5种最常见鸟类的平均大小、重量以及最古老的记录实例。
-
示例2:棋局分析
脚本将尝试提供Magnus Carlsen和Alireza Firouzja之间比赛的走子序列,并提出一个替代策略。
-
注意事项
-
计算资源
LATS(Language Agent Tree Search)涉及复杂的搜索树展开,可能需要较多的计算资源和时间,具体取决于任务的复杂性和搜索深度。
-
API限制
使用OpenAI和Tavily的API可能受到速率限制和配额限制,请确保您的API密钥具有足够的权限和配额。
-
错误处理
脚本中未包含详细的错误处理机制。在实际应用中,建议添加适当的异常处理,以应对可能出现的API调用失败或其他运行时错误。
-
可视化
如果您希望可视化搜索树,可以使用适当的图形库(如Graphviz)进一步扩展脚本功能。
当然,我很高兴为您详细讲解蒙特卡洛树搜索(Monte Carlo Tree Search, MCTS)。MCTS是一种用于决策过程的启发式搜索算法,广泛应用于博弈论、人工智能以及其他需要在巨大搜索空间中进行决策的领域。以下内容将从理论基础、算法步骤、关键组件、应用领域、优缺点等方面进行详细介绍。
一、概述
蒙特卡洛树搜索(MCTS) 是一种基于随机采样的搜索算法,用于在复杂的决策空间中找到最优策略。它通过在决策树中进行随机模拟(模拟游走),并根据模拟结果不断扩展和优化树结构,从而逐步逼近最优解。
MCTS特别适用于具有大规模状态空间和复杂决策结构的问题,如围棋、国际象棋、电子游戏等。在这些领域,传统的搜索算法(如α-β剪枝)由于状态空间过大而难以应用,而MCTS能够通过有效的探索和利用策略在有限的计算资源下取得良好的效果。
二、MCTS的基本原理
MCTS通过构建一个决策树来模拟和评估不同的行动序列。其核心思想是通过反复进行四个阶段的操作,不断完善树结构,从而找到最有希望的行动路径。这四个阶段分别是:
- 选择(Selection)
- 扩展(Expansion)
- 模拟(Simulation)
- 反向传播(Backpropagation)
1. 选择(Selection)
从根节点开始,沿着树中已经存在的节点路径向下选择,直到到达一个未完全扩展的节点。选择路径的策略通常基于平衡探索(exploration)和利用(exploitation),常用的选择策略是上置信界(Upper Confidence Bound, UCB)。
UCB公式:
UCB=平均胜率+c×lnNnUCB = \text{平均胜率} + c \times \sqrt{\frac{\ln N}{n}}UCB=平均胜率+c×nlnN
- 平均胜率:节点的累计奖励除以访问次数。
- c:探索参数,控制探索与利用的平衡。
- N:父节点的总访问次数。
- n:当前节点的访问次数。
2. 扩展(Expansion)
在选择阶段到达的节点,如果该节点不是终局状态且尚未完全展开,则从该节点生成一个或多个子节点,通常是选择一个未被探索过的合法行动来扩展树。
3. 模拟(Simulation)
从刚扩展的节点开始,进行一次随机模拟,即随机选择行动直至游戏结束(或达到某个预设深度),以评估该路径的潜在价值。模拟的结果通常是一个胜负或得分。
4. 反向传播(Backpropagation)
将模拟的结果沿着选择和扩展路径向上传播,更新路径上所有节点的统计数据(如访问次数和累计奖励)。这些统计数据将用于后续的选择阶段,以指导搜索方向。
三、MCTS的关键组件
1. 树结构
MCTS构建的是一棵由决策节点组成的树,每个节点代表一个游戏状态或决策点。节点之间通过行动(动作)连接,表示从一个状态采取某个行动转移到另一个状态。
2. 选择策略
选择策略决定了在选择阶段如何选择路径上的下一个节点。常用的选择策略包括:
- UCB1(上置信界1):通过平衡探索和利用来选择节点。
- 改进的UCB:如PUCB(Progressive UCB)等,针对特定应用进行优化。
3. 扩展策略
扩展策略决定了在扩展阶段如何选择子节点进行扩展。通常会选择未被探索过的合法行动,以增加树的多样性和覆盖范围。
4. 模拟策略
模拟策略决定了在模拟阶段如何进行模拟。最简单的是纯随机模拟,但也可以采用启发式方法,如基于评估函数的模拟,以提高模拟的质量和准确性。
5. 反向传播策略
反向传播策略决定了如何更新路径上的节点信息。通常是简单地将模拟结果(如胜负)传递给路径上的所有节点,更新它们的访问次数和累计奖励。
四、MCTS的算法步骤详解
以下是MCTS的详细算法步骤,以围棋游戏为例:
1. 初始化
- 创建根节点,表示当前游戏状态。
- 设置循环次数或时间限制,决定搜索的深度和广度。
2. 迭代搜索
在每次迭代中,执行以下四个阶段:
a. 选择(Selection)
- 从根节点开始,使用选择策略(如UCB1)沿着树路径向下选择子节点,直到到达一个未完全扩展的节点或终局节点。
b. 扩展(Expansion)
- 如果选择的节点不是终局且尚未完全扩展,则从该节点生成一个新的子节点,表示一个新的行动。
c. 模拟(Simulation)
- 从新扩展的子节点开始,进行一次模拟(随机或启发式),直到游戏结束或达到预设深度。
- 记录模拟的结果(如胜负或得分)。
d. 反向传播(Backpropagation)
- 将模拟结果沿着选择和扩展路径向上传播,更新路径上所有节点的访问次数和累计奖励。
3. 决策
- 当达到预设的迭代次数或时间限制后,选择根节点下访问次数最多或累计奖励最高的子节点作为最终的行动决策。
五、MCTS的应用领域
MCTS广泛应用于各种需要在庞大决策空间中进行高效搜索的领域,主要包括:
1. 博弈类游戏
- 围棋:AlphaGo等围棋AI使用MCTS与深度学习结合,取得了突破性的成果。
- 国际象棋:MCTS可以与传统的搜索算法结合,提升决策质量。
- 电子游戏:在复杂的策略游戏中,MCTS用于AI决策和路径规划。
2. 优化问题
- 资源分配:在资源有限的情况下,优化资源分配策略。
- 调度问题:如生产调度、任务分配等,利用MCTS寻找最佳调度方案。
3. 自动化决策系统
- 机器人导航:MCTS用于路径规划和决策,使机器人能够在复杂环境中找到最优路径。
- 推荐系统:在推荐过程中,利用MCTS进行多步决策,提升推荐效果。
4. 自然语言处理
- 对话系统:在多轮对话中,利用MCTS优化对话策略,提高对话的连贯性和有效性。
六、MCTS的优势与劣势
优势
- 适应性强:MCTS不依赖于具体的游戏或问题规则,具有很强的通用性。
- 渐进优化:随着搜索时间的增加,MCTS能够逐步优化决策质量。
- 处理高维度问题:适用于状态空间巨大且复杂的问题,传统搜索算法难以应用的场景。
- 平衡探索与利用:通过选择策略(如UCB1),MCTS能够有效地在探索新行动和利用已知好行动之间找到平衡。
劣势
- 计算资源消耗大:尤其是在状态空间极其庞大的情况下,MCTS可能需要大量的计算时间和资源。
- 依赖模拟质量:MCTS的性能高度依赖于模拟阶段的策略和质量,纯随机模拟可能导致效果不佳。
- 缺乏全局视野:MCTS主要通过局部搜索逐步优化,可能会错过全局最优解。
- 参数敏感:选择策略中的探索参数(如UCB1中的c值)对算法性能有较大影响,需进行调优。
七、MCTS的改进与扩展
为了克服MCTS的局限性,研究人员提出了多种改进和扩展方法,包括:
1. 并行化
通过并行化搜索过程,加快MCTS的搜索速度,提升算法的效率。
2. 混合策略
将MCTS与其他搜索算法(如深度优先搜索、启发式搜索)结合,利用不同算法的优势,提升搜索效果。
3. 学习增强
结合机器学习方法,利用神经网络等模型提升模拟策略和选择策略的质量。例如,AlphaGo将MCTS与深度神经网络结合,实现了围棋AI的突破性进展。
4. 自适应探索
根据搜索过程中的反馈,自适应调整探索与利用的平衡参数,提高搜索的效率和效果。
5. 局部优化
在搜索过程中,利用局部优化方法提升搜索路径的质量,减少无效的搜索分支。
八、实际案例:AlphaGo中的MCTS
AlphaGo是MCTS应用的经典案例之一。它结合了深度神经网络和MCTS,成功击败了多位顶级围棋选手,标志着人工智能在复杂博弈领域取得了重大突破。
AlphaGo的MCTS实现
- 策略网络(Policy Network):用于指导MCTS的行动选择,提高选择的效率和质量。
- 价值网络(Value Network):用于评估局面价值,减少模拟阶段的随机性,提升模拟结果的准确性。
- 蒙特卡洛树搜索:
- 选择阶段:使用策略网络指导选择路径,结合UCB1策略平衡探索与利用。
- 扩展阶段:生成新的子节点,结合策略网络进行行动选择。
- 模拟阶段:使用价值网络评估局面,而非纯随机模拟。
- 反向传播阶段:将价值网络的评估结果反向传播,更新节点的统计数据。
AlphaGo的优势
- 结合深度学习:通过策略网络和价值网络,提升了MCTS的决策质量和效率。
- 高效搜索:利用策略网络减少了无效的搜索分支,加快了搜索速度。
- 准确评估:价值网络提供了更准确的局面评估,提升了模拟结果的可靠性。
九、总结
蒙特卡洛树搜索(MCTS)作为一种强大的启发式搜索算法,在处理复杂的决策问题中表现出色。其通过结合随机模拟和决策树的构建,能够在庞大的搜索空间中高效地找到优质的决策路径。尽管MCTS存在计算资源消耗大、依赖模拟质量等劣势,但通过各种改进和扩展方法,其应用范围和效果得到了显著提升。
MCTS的成功应用,如AlphaGo,展示了其在人工智能领域的巨大潜力。随着计算能力的提升和算法的不断优化,MCTS将在更多复杂的决策问题中发挥关键作用。
如果您对MCTS有更多的疑问或需要进一步的深入探讨,请随时提出!
2142

被折叠的 条评论
为什么被折叠?



