tot学习

部署运行你感兴趣的模型镜像

1. 概述:大脑树(Tree of Thoughts, ToT)算法

大脑树(ToT) 由 Yao 等人提出,是一种通用的大型语言模型(LLM)代理搜索算法。它结合了反思/评估和简单搜索(如广度优先搜索 BFS,当然也可以应用深度优先搜索 DFS 或其他算法)。ToT 主要包括三个步骤:

  1. 扩展(Expand):生成一个或多个问题的候选解决方案。
  2. 评分(Score):衡量这些解决方案的质量。
  3. 修剪(Prune):保留评分最高的前 K 个候选方案。

如果没有找到解决方案或解决方案质量不足,则返回“扩展”步骤继续迭代。

2. 前提条件(Prerequisites)

首先,我们需要安装教程所依赖的包,并设置所选择的 LLM 提供商的 API 密钥。

%%capture --no-stderr
%pip install -U langgraph langchain-openai
import getpass
import os

def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")

_set_env("OPENAI_API_KEY")
# 如果需要可视化算法
trace = True
if trace:
    _set_env("LANGSMITH_API_KEY")
    os.environ["LANGSMITH_PROJECT"] = "ToT Tutorial"

解释:

  • 安装依赖包:使用 %pip install 安装最新版本的 langgraphlangchain-openai
  • 设置环境变量:通过 getpass.getpass 安全地获取用户输入的 API 密钥,并将其设置为环境变量。这里需要设置 OPENAI_API_KEY,如果启用了追踪(trace),还需要设置 LANGSMITH_API_KEYLANGSMITH_PROJECT

3. 任务定义(Task Definition)

我们的代理将尝试玩“24 点游戏”。给定 4 个数字,生成一个数学方程式,使用每个数字恰好一次,结果为 24。

import operator
from typing import List, Literal, Union, NamedTuple, Optional
from pydantic import BaseModel, Field

OperatorType = Literal["+", "-", "*", "/"]
TokenType = Union[float, OperatorType]

解释:

  • 导入必要的模块:包括 operator 模块用于数学运算,typing 模块用于类型提示,pydantic 用于数据验证。
  • 定义类型
    • OperatorType:仅限于加、减、乘、除四种运算符。
    • TokenType:可以是浮点数或运算符。

定义方程式的模型:

class Equation(BaseModel):
    """结合提供的数字以达到 24 的公式。"""

    tokens: List[TokenType] = Field(
        description="逆波兰表示法的令牌栈。例如:[3, 4, '+', -1, '*'] 将计算为 (3 + 4) * -1 = -7。",
    )

    def compute(self) -> float:
        op_funcs = {
            "+": operator.add,
            "-": operator.sub,
            "*": operator.mul,
            "/": operator.truediv,
        }
        stack = []
        for token in self.tokens:
            if isinstance(token, float):
                stack.append(token)
            else:
                b, a = stack.pop(), stack.pop()
                stack.append(op_funcs[token](a, b))
        return stack[0]

解释:

  • Equation
    • tokens:使用逆波兰表示法(后缀表示法)存储方程的令牌,例如 [3, 4, '+', -1, '*'] 代表 (3 + 4) * -1
    • compute 方法:计算方程的值,使用栈来解析逆波兰表达式。
class GuessEquations(BaseModel):
    """提交多个方程作为猜测。"""

    reasoning: str = Field(
        description="提交的猜测背后的推理。解释如何得出这些方程。",
    )
    equations: List[Equation] = Field(
        description="要提交的方程列表。",
    )

解释:

  • GuessEquations
    • reasoning:描述提交的方程背后的思路。
    • equations:提交的方程列表。

定义候选者模型:

class Candidate(NamedTuple):
    candidate: Equation
    score: Optional[float] = None
    feedback: Optional[str] = None

    def __str__(self):
        try:
            computed = self.candidate.compute()
        except Exception as e:
            computed = f"无效的方程: {self.candidate.tokens}; 错误: {repr(e)}"
        return f"Equation({self.candidate.tokens}) = {computed} (奖励: {self.score})"

class ScoredCandidate(Candidate):
    candidate: Equation
    score: float
    feedback: str

解释:

  • Candidate

    • 表示一个候选解决方案,包括 Equation 对象、评分和反馈。
    • __str__ 方法用于打印候选者的信息。
  • ScoredCandidate

    • 继承自 Candidate,确保 scorefeedback 是必填项。

4. 获取数据(Fetch data)

我们将使用“24 点游戏”数据集中的一个示例。

import requests
import csv

csv_data = requests.get(
    "https://storage.googleapis.com/benchmarks-artifacts/game-of-24/24.csv"
).content.decode("utf-8")
# 获取“Puzzles”列(列索引为1)
puzzles = [row[1].strip() for row in csv.reader(csv_data.splitlines()[1:])]

print(f"示例谜题: {puzzles[:3]}")

解释:

  • 获取 CSV 数据:从指定的 URL 下载 CSV 文件并解码为字符串。
  • 解析 CSV:使用 csv.reader 解析 CSV 数据,提取第二列(索引为1)的“Puzzles”数据。
  • 打印示例谜题:展示前三个谜题,例如 ['1 1 4 6', '1 1 11 11', '1 1 3 8']

5. 扩展器(Expander)

ToT 算法是相对通用的。主要的两个任务特定组件是扩展器和评分器。扩展器(增强的 LLM)尝试生成一个或多个问题的解决方案。在后续尝试中,它会接收来自之前搜索的种子/候选值。

定义扩展器:

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "你正在玩24点游戏。使用提供的数字,创建一个等于24的方程式。\n"
            "本轮请提交恰好{k}个猜测。",
        ),
        ("user", "为这些数字解决24点游戏: {problem}.{candidate}"),
    ],
).partial(candidate="")
llm = ChatOpenAI(model="gpt-4o-mini")

bound_llm = llm.with_structured_output(GuessEquations)
solver = prompt | bound_llm
# API参考: ChatPromptTemplate | ChatOpenAI

解释:

  • ChatPromptTemplate

    • 定义了与 LLM 交互的提示模板。
    • 系统消息:指示 LLM 参与“24 点游戏”,并提交 {k} 个猜测。
    • 用户消息:提供具体的问题 {problem} 和之前的候选 {candidate}
  • ChatOpenAI

    • 使用 OpenAI 的 GPT-4 模型(这里是 gpt-4o-mini,可能是一个简化版本)。
  • 绑定结构化输出

    • with_structured_output(GuessEquations):确保 LLM 的输出符合 GuessEquations 模型的结构。
  • 组合提示和 LLM

    • 使用管道操作符 | 将提示模板和 LLM 连接起来,形成一个完整的求解器 solver

6. 评分器(Scorer)

在这个游戏中,评分器相对简单。我们需要确认两点:

  1. LLM 生成了一个有效的方程,使用每个数字恰好一次。
  2. 方程的结果等于 24。

定义评分函数:

def compute_score(problem: str, candidate: Candidate) -> ScoredCandidate:
    numbers = list(map(int, problem.split()))
    # 检查候选方程是否使用了所有4个数字且每个数字恰好一次
    used_numbers = [
        token for token in candidate.candidate.tokens if isinstance(token, float)
    ]
    if sorted(used_numbers) != sorted(numbers):
        score = 0
        feedback = "方程必须使用所有4个数字且每个数字恰好一次。"
        return ScoredCandidate(
            candidate=candidate.candidate, score=score, feedback=feedback
        )
    try:
        result = candidate.candidate.compute()
        score = 1 / (1 + abs(24 - result))
        feedback = f"结果: {result}"
    except Exception as e:
        score = 0
        feedback = f"无效的方程。错误: {repr(e)}"
    return ScoredCandidate(
        candidate=candidate.candidate, score=score, feedback=feedback
    )

解释:

  • 参数

    • problem:当前的谜题(四个数字)。
    • candidate:候选的方程。
  • 评分逻辑

    1. 验证数字使用情况

      • 提取候选方程中使用的数字,检查是否恰好使用了问题中的所有数字。
      • 如果不匹配,评分为 0,反馈为“方程必须使用所有4个数字且每个数字恰好一次。”
    2. 计算结果

      • 如果数字使用正确,尝试计算方程的结果。
      • 评分公式:1 / (1 + abs(24 - result)),即结果越接近 24,评分越高。
      • 如果计算过程中出现错误(如除以零),评分为 0,反馈包含错误信息。

7. 构建图(Graph)

接下来,我们需要创建 LangGraph 图,该图将包括扩展、评分和修剪三个节点,并定义它们之间的连接关系。

定义图的状态和配置:

import operator
from typing import Optional, Dict, Any
from typing_extensions import Annotated, TypedDict
from langgraph.graph import StateGraph

from langchain_core.runnables import RunnableConfig
from langgraph.constants import Send
from langgraph.checkpoint.memory import MemorySaver

def update_candidates(
    existing: Optional[list] = None,
    updates: Optional[Union[list, Literal["clear"]]] = None,
) -> List[str]:
    if existing is None:
        existing = []
    if updates is None:
        return existing
    if updates == "clear":
        return []
    # 连接现有列表和更新列表
    return existing + updates

class ToTState(TypedDict):
    problem: str
    candidates: Annotated[List[Candidate], update_candidates]
    scored_candidates: Annotated[List[ScoredCandidate], update_candidates]
    depth: Annotated[int, operator.add]

class Configuration(TypedDict, total=False):
    max_depth: int
    threshold: float
    k: int
    beam_size: int

解释:

  • ToTState

    • problem:当前的谜题。
    • candidates:当前的候选方程列表,使用 update_candidates 函数进行更新。
    • scored_candidates:评分后的候选方程列表,同样使用 update_candidates
    • depth:当前的搜索深度,每次迭代增加。
  • Configuration

    • max_depth:搜索的最大深度,默认 10。
    • threshold:评分阈值,默认 0.9。
    • k:每轮生成的猜测数量,默认 5。
    • beam_size:修剪时保留的候选数量,默认 3。

辅助函数:

def _ensure_configurable(config: RunnableConfig) -> Configuration:
    """获取配置搜索算法的参数。"""
    configurable = config.get("configurable", {})
    return {
        **configurable,
        "max_depth": configurable.get("max_depth", 10),
        "threshold": config.get("threshold", 0.9),
        "k": configurable.get("k", 5),
        "beam_size": configurable.get("beam_size", 3),
    }

解释:

  • _ensure_configurable
    • RunnableConfig 中提取配置参数,提供默认值。

定义扩展函数:

class ExpansionState(ToTState):
    seed: Optional[Candidate]

def expand(state: ExpansionState, *, config: RunnableConfig) -> Dict[str, List[str]]:
    """生成下一个状态。"""
    configurable = _ensure_configurable(config)
    if not state.get("seed"):
        candidate_str = ""
    else:
        candidate_str = "\n\n" + str(state["seed"])
    try:
        equation_submission = solver.invoke(
            {
                "problem": state["problem"],
                "candidate": candidate_str,
                "k": configurable["k"],
            },
            config=config,
        )
    except Exception:
        return {"candidates": []}
    new_candidates = [
        Candidate(candidate=equation) for equation in equation_submission.equations
    ]
    return {"candidates": new_candidates}

解释:

  • expand 函数
    • 使用 solver(之前定义的扩展器)生成新的候选方程。
    • 如果有 seed 候选者,将其转换为字符串并附加到提示中,以指导 LLM 生成更好的猜测。
    • 返回新生成的候选方程列表。

定义评分函数:

def score(state: ToTState) -> Dict[str, List[float]]:
    """评估候选生成。"""
    candidates = state["candidates"]
    scored = []
    for candidate in candidates:
        scored.append(compute_score(state["problem"], candidate))
    return {"scored_candidates": scored, "candidates": "clear"}

解释:

  • score 函数
    • 遍历当前候选方程,对每个候选进行评分。
    • 将评分结果存储在 scored_candidates 中,并清空 candidates 列表。

定义修剪函数:

def prune(
    state: ToTState, *, config: RunnableConfig
) -> Dict[str, List[Dict[str, Any]]]:
    scored_candidates = state["scored_candidates"]
    beam_size = _ensure_configurable(config)["beam_size"]
    organized = sorted(
        scored_candidates, key=lambda candidate: candidate.score, reverse=True
    )
    pruned = organized[:beam_size]
    return {
        # 更新下一次迭代的起点
        "candidates": pruned,
        # 清除旧的评分结果
        "scored_candidates": "clear",
        # 深度增加 1
        "depth": 1,
    }

解释:

  • prune 函数
    • 根据评分对候选方程进行排序,保留评分最高的 beam_size 个。
    • 更新 candidates 为修剪后的候选者,清空 scored_candidates,并增加搜索深度。

定义终止条件:

def should_terminate(
    state: ToTState, config: RunnableConfig
) -> Union[Literal["__end__"], Send]:
    configurable = _ensure_configurable(config)
    solved = state["candidates"][0].score >= configurable["threshold"]
    if solved or state["depth"] >= configurable["max_depth"]:
        return "__end__"
    return [
        Send("expand", {**state, "seed": candidate})
        for candidate in state["candidates"]
    ]

解释:

  • should_terminate 函数
    • 检查是否满足终止条件:
      • 是否找到评分超过阈值的解决方案。
      • 是否达到最大搜索深度。
    • 如果满足其中任一条件,返回 "__end__",终止搜索。
    • 否则,继续进行扩展操作,对每个候选者发送 expand 请求。

构建图:

# 创建图构建器
builder = StateGraph(state_schema=ToTState, config_schema=Configuration)

# 添加节点
builder.add_node(expand)
builder.add_node(score)
builder.add_node(prune)

# 添加边
builder.add_edge("expand", "score")
builder.add_edge("score", "prune")
builder.add_conditional_edges("prune", should_terminate, path_map=["expand", "__end__"])

# 设置入口点
builder.add_edge("__start__", "expand")

# 编译图
graph = builder.compile(checkpointer=MemorySaver())
# API参考: RunnableConfig | StateGraph | Send | MemorySaver

解释:

  • 创建 StateGraph

    • 定义图的状态模式 ToTState 和配置模式 Configuration
  • 添加节点

    • expand:扩展节点。
    • score:评分节点。
    • prune:修剪节点。
  • 添加边

    • expandscore
    • scoreprune
    • prune 根据 should_terminate 函数的结果决定下一步:
      • 如果终止,结束图执行。
      • 否则,返回 expand 继续迭代。
  • 设置入口点

    • __start__expand
  • 编译图

    • 使用 MemorySaver 作为检查点,编译完成后的图对象为 graph

8. 运行(Run)

现在,我们可以尝试在其中一个谜题上运行这个图。

可视化图:

from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))

解释:

  • 可视化图
    • 使用 draw_mermaid_png 方法绘制图的结构,并在 Jupyter Notebook 中显示。

执行图:

config = {
    "configurable": {
        "thread_id": "test_1",
        "depth": 10,
    }
}
for step in graph.stream({"problem": puzzles[42]}, config):
    print(step)

解释:

  • 配置参数

    • thread_id:标识当前线程,方便跟踪。
    • depth:最大搜索深度,设置为 10。
  • 运行图

    • 使用 graph.stream 方法逐步执行图,并打印每一步的状态。

示例输出解析:

{'expand': {'candidates': [Candidate(...), ...]}}
{'score': {'candidates': 'clear', 'scored_candidates': [ScoredCandidate(...), ...]}}
{'prune': {'candidates': [ScoredCandidate(...), ...], 'scored_candidates': 'clear', 'depth': 1}}
...
{'prune': {'candidates': [ScoredCandidate(...), ...], 'scored_candidates': 'clear', 'depth': 1}}

final_state = graph.get_state(config)
winning_solution = final_state.values["candidates"][0]
search_depth = final_state.values["depth"]
if winning_solution.score == 1:
    print(f"在 {search_depth} 步内找到一个获胜的解决方案: {winning_solution}")
else:
    print(
        f"在 {search_depth} 步内未能找到获胜的解决方案。最佳猜测: {winning_solution}"
    )

解释:

  • 逐步执行

    • 每一步 expandscoreprune 都会打印当前的状态,包括生成的候选方程、评分结果和修剪后的候选者。
  • 获取最终状态

    • 使用 graph.get_state(config) 获取最终的搜索状态。
    • 提取 winning_solutionsearch_depth
  • 输出结果

    • 如果找到评分为 1 的解决方案(即结果精确等于 24),则打印成功信息。
    • 否则,打印失败信息及最佳猜测。

示例输出:

在 2 步内找到一个获胜的解决方案: [Equation(tokens=[1.0, 5.0, 7.0, '*', 12.0, '-', '+']), 1.0, '结果: 24.0']

解释:

  • 成功找到解决方案
    • 在第 2 步搜索中,找到一个评分为 1 的解决方案,即 (1 + (5 * 7)) - 12 = 24

9. 总结

本案例展示了如何使用 LangGraph 和 ToT 算法构建一个智能代理,解决“24 点游戏”。通过以下几个关键步骤:

  1. 定义任务和数据模型:包括方程表示、候选者和评分模型。
  2. 数据获取:从外部数据源加载游戏谜题。
  3. 构建扩展器和评分器:使用 LLM 生成候选方程,并评估其质量。
  4. 构建搜索图:定义扩展、评分和修剪节点及其连接关系。
  5. 执行搜索:运行图,迭代生成和评估候选,直到找到满意的解决方案或达到最大搜索深度。

通过这种方式,您可以利用 LLM 的强大能力,结合结构化的搜索策略,有效地解决复杂的组合优化问题。

汇总

当然,以下是将您提供的 LangGraph 入门案例的相关代码汇总到一个完整的 tot_game_of_24.py 文件中。所有提示词和注释均已翻译为中文,以便您更好地理解和执行代码。

请确保在运行此脚本之前,已按照以下步骤操作:

  1. 安装必要的包

    在终端或命令提示符中运行以下命令,以安装所需的 Python 包:

    pip install -U langgraph langchain-openai pydantic requests typing-extensions
    
  2. 设置环境变量

    运行脚本时,系统会提示您输入 OPENAI_API_KEY。如果启用了追踪功能(trace = True),还需要输入 LANGSMITH_API_KEY

以下是完整的 tot_game_of_24.py 文件内容:

# tot_game_of_24.py

import getpass
import os
import operator
from typing import List, Literal, Union, NamedTuple, Optional, Dict, Any
from typing_extensions import Annotated, TypedDict

import requests
import csv
from pydantic import BaseModel, Field

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from langgraph.graph import StateGraph
from langchain_core.runnables import RunnableConfig
from langgraph.constants import Send
from langgraph.checkpoint.memory import MemorySaver

# 设置环境变量的函数
def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}:")

# 设置 OPENAI_API_KEY
_set_env("OPENAI_API_KEY")

# 如果需要可视化算法
trace = True
if trace:
    _set_env("LANGSMITH_API_KEY")
    os.environ["LANGSMITH_PROJECT"] = "ToT Tutorial"

# 定义运算符类型和令牌类型
OperatorType = Literal["+", "-", "*", "/"]
TokenType = Union[float, OperatorType]

# 定义方程式的模型
class Equation(BaseModel):
    """结合提供的数字以达到 24 的公式。"""

    tokens: List[TokenType] = Field(
        description="逆波兰表示法的令牌栈。例如:[3, 4, '+', -1, '*'] 将计算为 (3 + 4) * -1 = -7。",
    )

    def compute(self) -> float:
        op_funcs = {
            "+": operator.add,
            "-": operator.sub,
            "*": operator.mul,
            "/": operator.truediv,
        }
        stack = []
        for token in self.tokens:
            if isinstance(token, float):
                stack.append(token)
            else:
                if len(stack) < 2:
                    raise ValueError("栈中元素不足以进行运算。")
                b, a = stack.pop(), stack.pop()
                try:
                    result = op_funcs[token](a, b)
                except ZeroDivisionError:
                    raise ValueError("除以零错误。")
                stack.append(result)
        if len(stack) != 1:
            raise ValueError("最终栈中元素数量不正确。")
        return stack[0]

# 定义猜测方程式的模型
class GuessEquations(BaseModel):
    """提交多个方程作为猜测。"""

    reasoning: str = Field(
        description="提交的猜测背后的推理。解释如何得出这些方程。",
    )
    equations: List[Equation] = Field(
        description="要提交的方程列表。",
    )

# 定义候选者模型
class Candidate(NamedTuple):
    candidate: Equation
    score: Optional[float] = None
    feedback: Optional[str] = None

    def __str__(self):
        try:
            computed = self.candidate.compute()
        except Exception as e:
            computed = f"无效的方程: {self.candidate.tokens}; 错误: {repr(e)}"
        return f"Equation({self.candidate.tokens}) = {computed} (奖励: {self.score})"

# 定义评分后的候选者模型
class ScoredCandidate(Candidate):
    candidate: Equation
    score: float
    feedback: str

# 获取数据函数
def fetch_puzzles() -> List[str]:
    """从指定的URL获取24点游戏的谜题数据。"""
    csv_url = "https://storage.googleapis.com/benchmarks-artifacts/game-of-24/24.csv"
    response = requests.get(csv_url)
    if response.status_code != 200:
        raise ValueError(f"无法获取数据,状态码:{response.status_code}")
    csv_data = response.content.decode("utf-8")
    # 获取“Puzzles”列(列索引为1)
    puzzles = [row[1].strip() for row in csv.reader(csv_data.splitlines()[1:])]
    return puzzles

# 调用获取数据函数并打印示例谜题
puzzles = fetch_puzzles()
print(f"示例谜题: {puzzles[:3]}")

# 定义扩展器(Expander)
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "你正在玩24点游戏。使用提供的数字,创建一个等于24的方程式。\n"
            "本轮请提交恰好{k}个猜测。",
        ),
        ("user", "为这些数字解决24点游戏: {problem}.{candidate}"),
    ],
).partial(candidate="")

llm = ChatOpenAI(model="gpt-4o-mini")

bound_llm = llm.with_structured_output(GuessEquations)
solver = prompt | bound_llm
# API参考: ChatPromptTemplate | ChatOpenAI

# 定义评分器(Scorer)
def compute_score(problem: str, candidate: Candidate) -> ScoredCandidate:
    numbers = list(map(int, problem.split()))
    # 检查候选方程是否使用了所有4个数字且每个数字恰好一次
    used_numbers = [
        token for token in candidate.candidate.tokens if isinstance(token, float)
    ]
    if sorted(used_numbers) != sorted(numbers):
        score = 0
        feedback = "方程必须使用所有4个数字且每个数字恰好一次。"
        return ScoredCandidate(
            candidate=candidate.candidate, score=score, feedback=feedback
        )
    try:
        result = candidate.candidate.compute()
        score = 1 / (1 + abs(24 - result))
        feedback = f"结果: {result}"
    except Exception as e:
        score = 0
        feedback = f"无效的方程。错误: {repr(e)}"
    return ScoredCandidate(
        candidate=candidate.candidate, score=score, feedback=feedback
    )

# 定义图的状态和配置
class ToTState(TypedDict):
    problem: str
    candidates: Annotated[List[Candidate], "update_candidates"]
    scored_candidates: Annotated[List[ScoredCandidate], "update_candidates"]
    depth: Annotated[int, "operator.add"]

class Configuration(TypedDict, total=False):
    max_depth: int
    threshold: float
    k: int
    beam_size: int

# 辅助函数:更新候选者列表
def update_candidates(
    existing: Optional[List[Union[Candidate, ScoredCandidate]]] = None,
    updates: Optional[Union[List[Union[Candidate, ScoredCandidate]], Literal["clear"]]] = None,
) -> List[Union[Candidate, ScoredCandidate]]:
    if existing is None:
        existing = []
    if updates is None:
        return existing
    if updates == "clear":
        return []
    # 连接现有列表和更新列表
    return existing + updates

# 辅助函数:确保配置参数
def _ensure_configurable(config: RunnableConfig) -> Configuration:
    """获取配置搜索算法的参数。"""
    configurable = config.get("configurable", {})
    return {
        **configurable,
        "max_depth": configurable.get("max_depth", 10),
        "threshold": configurable.get("threshold", 0.9),
        "k": configurable.get("k", 5),
        "beam_size": configurable.get("beam_size", 3),
    }

# 定义扩展状态
class ExpansionState(ToTState):
    seed: Optional[Candidate]

# 定义扩展函数
def expand(state: ExpansionState, *, config: RunnableConfig) -> Dict[str, List[Candidate]]:
    """生成下一个状态。"""
    configurable = _ensure_configurable(config)
    if not state.get("seed"):
        candidate_str = ""
    else:
        candidate_str = "\n\n" + str(state["seed"])
    try:
        equation_submission = solver.invoke(
            {
                "problem": state["problem"],
                "candidate": candidate_str,
                "k": configurable["k"],
            },
            config=config,
        )
    except Exception as e:
        print(f"扩展过程中出现错误: {e}")
        return {"candidates": []}
    new_candidates = [
        Candidate(candidate=equation) for equation in equation_submission.equations
    ]
    return {"candidates": new_candidates}

# 定义评分函数
def score(state: ToTState) -> Dict[str, List[ScoredCandidate]]:
    """评估候选生成。"""
    candidates = state["candidates"]
    scored = []
    for candidate in candidates:
        scored_candidate = compute_score(state["problem"], candidate)
        scored.append(scored_candidate)
    return {"scored_candidates": scored, "candidates": "clear"}

# 定义修剪函数
def prune(
    state: ToTState, *, config: RunnableConfig
) -> Dict[str, List[Dict[str, Any]]]:
    """修剪候选者列表,保留评分最高的beam_size个。"""
    scored_candidates = state["scored_candidates"]
    beam_size = _ensure_configurable(config)["beam_size"]
    organized = sorted(
        scored_candidates, key=lambda candidate: candidate.score, reverse=True
    )
    pruned = organized[:beam_size]
    return {
        # 更新下一次迭代的起点
        "candidates": pruned,
        # 清除旧的评分结果
        "scored_candidates": "clear",
        # 深度增加 1
        "depth": 1,
    }

# 定义终止条件函数
def should_terminate(
    state: ToTState, config: RunnableConfig
) -> Union[Literal["__end__"], Send]:
    """判断是否满足终止条件。"""
    configurable = _ensure_configurable(config)
    if not state["candidates"]:
        return "__end__"
    solved = state["candidates"][0].score >= configurable["threshold"]
    if solved or state["depth"] >= configurable["max_depth"]:
        return "__end__"
    return [
        Send("expand", {**state, "seed": candidate})
        for candidate in state["candidates"]
    ]

# 创建图构建器
builder = StateGraph(state_schema=ToTState, config_schema=Configuration)

# 添加节点
builder.add_node(expand)
builder.add_node(score)
builder.add_node(prune)

# 添加边
builder.add_edge("expand", "score")
builder.add_edge("score", "prune")
builder.add_conditional_edges("prune", should_terminate, path_map=["expand", "__end__"])

# 设置入口点
builder.add_edge("__start__", "expand")

# 编译图
graph = builder.compile(checkpointer=MemorySaver())
# API参考: RunnableConfig | StateGraph | Send | MemorySaver

# 可视化图(需要在支持的环境中运行,如Jupyter Notebook)
# from IPython.display import Image, display
# display(Image(graph.get_graph().draw_mermaid_png()))

# 运行图
def run_tot(puzzle_index: int):
    """在指定的谜题上运行ToT算法。"""
    if puzzle_index >= len(puzzles):
        print(f"索引 {puzzle_index} 超出谜题列表范围。")
        return

    config = {
        "configurable": {
            "thread_id": "test_1",
            "max_depth": 10,
        }
    }

    initial_state = {"problem": puzzles[puzzle_index]}
    print(f"正在解决的谜题: {initial_state['problem']}")

    try:
        for step in graph.stream(initial_state, config):
            print(step)
    except Exception as e:
        print(f"运行过程中出现错误: {e}")

    final_state = graph.get_state(config)
    if not final_state:
        print("未能获取最终状态。")
        return

    if "candidates" not in final_state.values:
        print("最终状态中缺少候选者信息。")
        return

    if not final_state.values["candidates"]:
        print("未找到任何候选解决方案。")
        return

    winning_solution = final_state.values["candidates"][0]
    search_depth = final_state.values["depth"]
    if winning_solution.score == 1:
        print(f"在 {search_depth} 步内找到一个获胜的解决方案: {winning_solution}")
    else:
        print(
            f"在 {search_depth} 步内未能找到获胜的解决方案。最佳猜测: {winning_solution}"
        )

# 示例运行
if __name__ == "__main__":
    # 选择要解决的谜题索引,例如第43个谜题(索引42)
    puzzle_index = 42
    run_tot(puzzle_index)

代码说明

  1. 环境设置

    • 安装包:确保已安装 langgraph, langchain-openai, pydantic, requests, 和 typing-extensions
    • 环境变量:脚本会提示您输入 OPENAI_API_KEY,如果启用了追踪功能,还需要输入 LANGSMITH_API_KEY。这些密钥将被安全地存储为环境变量。
  2. 数据模型

    • Equation:表示一个数学方程,使用逆波兰表示法(后缀表示法)存储令牌,并提供计算方法。
    • GuessEquations:用于提交多个方程作为猜测,包括推理说明。
    • CandidateScoredCandidate:表示候选解决方案及其评分和反馈。
  3. 数据获取

    • fetch_puzzles() 函数从指定的 URL 下载 CSV 数据,并提取“Puzzles”列中的谜题。
  4. 扩展器(Expander)

    • 使用 ChatPromptTemplate 定义与 LLM 的交互模板,指示其生成等于24的方程式。
    • 使用 ChatOpenAI 的 GPT-4 模型(gpt-4o-mini)作为语言模型,并绑定结构化输出以符合 GuessEquations 模型。
  5. 评分器(Scorer)

    • compute_score() 函数验证候选方程是否使用了所有提供的数字,并计算结果与24的接近程度。评分越高,候选方案越接近正确答案。
  6. 图的构建(Graph Building)

    • 定义图的状态 (ToTState) 和配置 (Configuration)。
    • 定义 expand, score, 和 prune 函数,分别用于生成候选方案、评分候选方案和修剪候选方案列表。
    • 使用 StateGraph 创建图,添加节点和边,并设置入口点。
    • 编译图以生成可执行的 graph 对象。
  7. 运行图(Running the Graph)

    • run_tot() 函数接受一个谜题索引,运行 ToT 算法解决该谜题,并打印每一步的状态。
    • 最终输出是否找到满足条件的解决方案。

执行脚本

  1. 保存脚本

    将上述代码保存为 tot_game_of_24.py

  2. 运行脚本

    在终端或命令提示符中,导航到脚本所在的目录,并运行:

    python tot_game_of_24.py
    

    脚本将提示您输入 OPENAI_API_KEY,并在启用了追踪功能的情况下,提示输入 LANGSMITH_API_KEY

  3. 查看结果

    脚本将打印解决过程中的每一步状态,并在最终输出中指示是否找到一个获胜的解决方案。

注意事项

  • 模型名称:请确保您拥有访问 gpt-4o-mini 模型的权限。如果没有,您可能需要更改为您拥有访问权限的其他模型名称,例如 gpt-4
  • 环境支持:如果您希望可视化图结构,可以在支持的环境(如 Jupyter Notebook)中取消注释相关代码部分。
  • 错误处理:脚本包含基本的错误处理,但在实际应用中,您可能需要根据具体情况添加更多的异常处理逻辑。

通过运行此脚本,您可以全面理解并体验 LangGraph 和 ToT 算法如何协同工作,以解决“24 点游戏”这一具体任务。如果有任何问题或需要进一步的帮助,请随时提问!

您可能感兴趣的与本文相关的镜像

Qwen3-VL-8B

Qwen3-VL-8B

图文对话
Qwen3-VL

Qwen3-VL是迄今为止 Qwen 系列中最强大的视觉-语言模型,这一代在各个方面都进行了全面升级:更优秀的文本理解和生成、更深入的视觉感知和推理、扩展的上下文长度、增强的空间和视频动态理解能力,以及更强的代理交互能力

ToT(Tree of Thoughts,思维树)是一种用于增强大语言模型推理能力的提示词技术,它通过维护一棵包含多个思维路径的树结构,使模型能够生成并评估多个可能的解决方案[^5]。ToT 的实现方法主要包括问题分解、想法生成、状态评价和搜索算法的选择四个主要环节[^1]。 ### 问题分解 ToT 的第一步是将复杂的问题拆解成多个小问题,以便逐个解决。这种分解使得模型能够专注于解决每一个子问题,从而提高解决问题的整体效率。例如,在数学推理任务中,一个复杂的问题可以被分解为多个步骤,每个步骤对应一个子问题[^5]。 ### 想法生成 在每个子问题上,模型生成可能的想法或解决方案。这一步骤是通过提示词设计来引导大语言模型生成多个候选项,每个候选项代表一个可能的解决方案[^5]。这些候选项构成了思维树的一个节点。 ### 状态评价 对生成的想法进行评估,选择最佳的候选方案。评估的标准包括与之前步骤的一致性和合理性。这一步骤通常需要定义一个评估函数或策略,该策略可以是基于规则的,也可以是基于机器学习的[^5]。 ### 搜索算法的选择 根据评估结果,选择合适的搜索算法来进一步探索和优化解决方案。ToT 框架的优势在于它能够同时考虑到多个可能的思维路径,并能够根据输入序列中的不同语义信息来动态调整生成输出的策略[^2]。常见的搜索算法包括广度优先搜索(BFS)和深度优先搜索(DFS)等。 以下是一个简单的 ToT 实现示例,展示了如何使用 Python 和一个假设的大语言模型 API 来实现 ToT 的基本流程: ```python class TreeOfThought: def __init__(self, problem): self.problem = problem self.tree = {} def decompose_problem(self): # 分解问题为子问题 # 这里只是一个示例,实际分解需要根据具体问题来实现 self.sub_problems = ["Sub-problem 1", "Sub-problem 2", "Sub-problem 3"] def generate_ideas(self, sub_problem): # 使用大语言模型生成多个候选项 # 假设调用一个大语言模型API来生成想法 ideas = model_api_call(sub_problem) return ideas def evaluate_ideas(self, ideas): # 评估想法,返回最佳候选 # 这里只是一个简单的评估示例 best_idea = max(ideas, key=lambda x: self.evaluate_idea(x)) return best_idea def evaluate_idea(self, idea): # 实现具体的评估逻辑 # 返回一个评分 return score def search_algorithm(self): # 选择搜索算法,如BFS或DFS # 这里实现BFS作为示例 for sub_problem in self.sub_problems: ideas = self.generate_ideas(sub_problem) best_idea = self.evaluate_ideas(ideas) self.tree[sub_problem] = best_idea def solve(self): self.decompose_problem() self.search_algorithm() return self.tree # 假设的大语言模型API调用 def model_api_call(prompt): # 这里模拟API调用,返回生成的想法 return ["Idea 1", "Idea 2", "Idea 3"] # 使用示例 tot = TreeOfThought("Example Problem") solution_tree = tot.solve() print(solution_tree) ``` ### 相关问题 ToT 的实现方法不仅限于上述内容,还可以结合不同的搜索算法和评估策略来优化解决方案。此外,ToT 框架的应用领域也在不断扩展,包括数学推理、自然语言处理和机器翻译等[^5]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值