从理论到实践:树形思维探索(ToT)

  在开始前请确保您有一定的LLM基础和强化学习基础😊

如果您没有RL基础我推荐David Sliver的讲座(前三集即可)RL Course by David Silver - Lecture 1: Introduction to Reinforcement Learning - YouTube

如果您还不了解CoT,请看我的这篇文章:从理论到实践:思维链(CoT)提示-优快云博客

如果您还不了解多路径生成与自洽性,请看我的这篇文章:从理论到实践:CoT的多路径生成与自洽性-优快云博客


叠甲:我对文章中提到的所有算法的数学解析只是片面的,深入研究会在不久的将来发布(也许吧😔),敬请期待😊(欢迎各位大佬指出错误😊)

数学视角

树形结构的测度论定义

ToT的核心是树结构 T=(V,E),其中:

  • 节点集 V=\left \{ v_{i} \right \}:每个节点表示一个中间思考状态 s_{i}\in SS 为状态空间
  • 边集 E=\left \{ (v_{i},v_{j},a_{k}) \right \}:边表示状态转移动作  a_{k}\in AA 为动作空间

状态空间:

  • 将状态空间 S 视为可测空间 (S,\mathcal{F}),其中 \mathcal{F} 是状态集合的\sigma-代数
  • 每个节点 v_{i} 对应一个状态测度 \mu _{i}:\mathcal{F}\rightarrow \left [ 0,1 \right ],表示该状态的概率分布
  • 转移核:边 (v_{i},v_{j}) 对应转移概率  K(v_{i},v_{j})=P(v_{j}|v_{i}),满足:

        \sum_{v_{j}\in children(v_{i})}^{}K(v_{i},V_{j})=1

路径空间的积分表示:

  • 路径 P=(v_{0}\rightarrow v_{1}\rightarrow ...\rightarrow v_{n})  的概率可表示为转移核的乘积:     

        \mathbb{P}(P)=\prod_{k=0}^{n-1}K(v_{k},v_{k+1})

  • 目标函数 U(P) 的期望值为:

        \mathbb{E}\left [ U \right ]=\int_{\mathcal{P}}^{}U(P)d\mathbb{P}(P)

        其中积分在路径空间 \mathcal{P} 上进行

随机过程与马尔可夫性

1.马尔可夫链建模

  • ToT的搜索过程可建模为非其次马尔可夫链,状态转移满足:

        P(v_{t+1}|v_{t},v_{t-1},...,v_{0})=P(v_{t+1}|v_{t})

  • Chapman-Kolmogorov方程:多步转移概率满足:

        K^{(n)}(v_{i},v_{j})=\sum_{v_{k}}^{}K^{(m)}(v_{i},v_{k})K^{(n-m)}(v_{k},v_{j})

2.收敛性分析

  • 若树满足不可约性(所有状态互通)和非周期性,则存在唯一平稳分布 \pi,使得:

        \pi (v_{j})=\sum_{v_{i}}^{}\pi (v_{i})K(v_{i},v_{j})

  • Perron-Frobenius定理:转移矩阵的最大特征值为1,对应平稳分布

    搜索算法

    ToT的搜索过程可形式化为优化问题:      

            \underset{P\in \mathcal{P}}{max}U(P)

    其中 \mathcal{P} 为所有可能路径,U(P) 为路径效用函数

    1.启发式搜索(A*算法)

    • 定义评估函数 f(v)=g(v)+h(v),其中:

            g(v):从根节点到 v 的实际代价

            h(v):启发式函数,估计 v 到目标状态的代价

    • 可纳性条件:若 h(v)\leq h^{*}(v)(真实最小代价),则 A* 算法保证找到最优解

    2.蒙特卡洛树搜索(MCTS)

    • 节点选择基于Upper Confidence Bound(UCB)

            UCB(v_{i})=\frac{Q(v_{i})}{N(v_{i})}+c\sqrt{\frac{ln(N(parent(v_{i})))}{N(v_{i})}}

    其中 Q(v_{i}) 为累积奖励,N(v_{i}) 为访问次数,c 为探索参数

    概率推理与贝叶斯方法

    ToT中路径选择通常涉及不确定性,可通过概率模型量化:

    1.贝叶斯更新:

    • 节点 v_{i} 的先验概率 P(v_{i}) 通过似然函数 P(D|v_{i}) 更新为后验概率:

            P(v_{i}|D)=\frac{P(D|V_{i})P(v_{i})}{\sum_{j}^{}P(D|v_{j})P(v_{j})}

            其中 D 为新观测数据

    2.信息熵决策:

    • 选择最大化信息增益的路径: 

            IG(v_{i})=H(S)-H(S|v_{i})

            其中 H(S) 为状态空间的熵,H(S|v_{i}) 为条件熵 

    组合优化与剪枝策略

    1.动态规划

    • 利用最优子结构性质,递归定义节点价值:

            V(v)=\underset{a\in A}{max}[R(v,a)+\gamma V(next(v,a))]

            其中 R(v,a) 为即时奖励,\gamma 为折扣因子

    2.Alpha-Beta剪枝

    • 通过上下界 \alpha 和 \beta 剪除无效子树,最理想状态下将复杂度从 O(b^{d})  降至 O(b^{d/2})

    收敛性与复杂度分析

    1.收敛性:若搜索空间有限且满足完备性条件(所有解均可达),则ToT在有限步内收敛

    2.复杂性:

    • 时间:O(b^{d})(无剪枝),O(b^{d/2})(启发式剪枝)
    • 空间:O(bd)(深度优先),O(b^{d})(广度优先)

    扩展模型:随机过程与博弈论

    1.马尔可夫决策过程(MDP):

    • 定义五元组(S,A,P,R,\gamma ),其中 P(s'|s,a) 为状态转移概率
    • 最优策略 \pi ^{*} 满足Bellman方程:

            V^{*}(s)=\underset{a}{max}\left [ R(s,a)+\gamma \sum_{s'}^{}P(s'|s,a)V^{*}(s') \right ]

    2.多人树搜索的微分博弈模型:

    • 定义玩家集合 \mathcal{N},每个玩家策略 u_{i}(t)\in \mathcal{U}_{i}
    • 状态方程:

            \frac{ds}{dt}=f(s,u_{1},...,u_{N}),\quad s(0)=s_{0}

    • 玩家 i 的目标泛函:

            J_{i}(u_{1},...,u_{N})=\int_{0}^{T}\phi (s,u_{i})dt+\Psi _{i}(s(T))

    代码实现

    ToT树结构

    class ToTNode:
        def __init__(self, state, parent=None, depth=0):
            self.state = state      # 当前状态字符串
            self.parent = parent    # 父节点引用
            self.children = []      # 子节点列表
            self.value = 0.0        # 累计评估值
            self.visits = 0        # 访问次数
            self.depth = depth     # 当前深度
            self.is_terminal = False  # 是否为终止节点
            
        def ucb_score(self, exploration=1.414):
            """计算UCB选择分数"""
            if self.visits == 0:
                return float('inf')
            return (self.value / self.visits) + exploration * math.sqrt(math.log(self.parent.visits) / self.visits)
        
        def to_dict(self):
            """序列化节点信息"""
            return {
                "state": self.state[:50]+"..." if len(self.state)>50 else self.state,
                "depth": self.depth,
                "value": round(self.value,2),
                "visits": self.visits,
                "is_terminal": self.is_terminal
            }
    

    扩展树节点

    def expand_node(node, beam_width=2):
        """扩展单个节点"""
        candidates = []
        for method in ["algebra", "matrix", "default"]:
            # 生成提示
            prompt = generate_tot_prompt(
                node.state.split("<solution>")[-1],  # 仅传递推导步骤部分
                method, 
                node.depth+1
            )
                    
            # 模型生成
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=400,
                temperature=max(0.5, 1.0 - 0.05*node.depth),  # 深度越大生成越稳定
                top_p=0.9,
                repetition_penalty=1.3,  # 增强重复惩罚
                num_beams=7,        # 增加束搜索宽度
                no_repeat_ngram_size=3,  # 添加N-gram重复惩罚
                early_stopping=True,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            new_state = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # 创建子节点
            child = ToTNode(
                state=new_state,
                parent=node,
                depth=node.depth + 1
            )
            child.is_terminal = is_terminal_state(new_state) is not None
            candidates.append(child)
        
        # 评估并选择最优分支
        evaluated = [(c, evaluate_node(c)) for c in candidates]
        evaluated.sort(key=lambda x: x[1], reverse=True)
        node.children = [c for c, _ in evaluated[:beam_width]]
        
        return node

    蒙特卡洛树搜索

    def monte_carlo_tree_search(root, iterations=100):
        """蒙特卡洛树搜索主循环"""
        for i in range(iterations):
            # 选择阶段
            node = root
            while node.children:
                node = max(node.children, key=lambda x: x.ucb_score())
            
            # 扩展阶段
            if not node.is_terminal:
                expand_node(node)
            
            # 模拟阶段(此处简化使用评估值代替随机模拟)
            reward = evaluate_node(node)
            
            # 回溯更新
            while node is not None:
                node.visits += 1
                node.value += reward / (node.depth + 1)  # 深度归一化
                node = node.parent
            # 调试输出
            print(f"Iteration {i+1}: Root visits={root.visits}, value={root.value:.2f}")
        return root

    结果分析
     

    测试用例: <final_answer>x=2,y=4</final_a...
    提取结果: x=2,y=4 | 预期: True | 状态: 成功
    ------------------------------------------------------------
    测试用例: Final_Answer: X=2.5, Y=3...
    提取结果: x=2.5,y=3 | 预期: True | 状态: 成功
    ------------------------------------------------------------
    测试用例: 答案:x=3, y=2.8...
    提取结果: x=3,y=2.8 | 预期: True | 状态: 成功
    ------------------------------------------------------------
    测试用例: x=值,y=值...
    提取结果: None | 预期: False | 状态: 成功
    ------------------------------------------------------------
    测试用例: x=, y=4...
    提取结果: None | 预期: False | 状态: 成功
    ------------------------------------------------------------
    测试用例: 无效内容...
    提取结果: None | 预期: False | 状态: 成功
    ------------------------------------------------------------
    
    ============================================================
    处理问题:解方程组:
    方程1: 2x + 3y = 16
    方程2: x - 2y = -6
    [DEBUG] 首次扩展后的子节点:
     子节点1:{'state': "请严格遵循以下要求:\n1. 必须使用具体数值,禁止使用'值'等占位符\n2. 小数保留两位,例如:x=...", 'depth': 1, 'value': 0.0, 'visits': 0, 'is_terminal': True}
     子节点2:{'state': "请严格遵循以下要求:\n1. 必须使用具体数值,禁止使用'值'等占位符\n2. 小数保留两位,例如:x=...", 'depth': 1, 'value': 0.0, 'visits': 0, 'is_terminal': True}
     子节点3:{'state': "请严格遵循以下要求:\n1. 必须使用具体数值,禁止使用'值'等占位符\n2. 小数保留两位,例如:x=...", 'depth': 1, 'value': 0.0, 'visits': 0, 'is_terminal': True}
    Iteration 1: Root visits=1, value=0.56
    Iteration 2: Root visits=2, value=1.12
    Iteration 3: Root visits=3, value=1.68
    Iteration 4: Root visits=4, value=2.24
    Iteration 5: Root visits=5, value=2.80
    Iteration 6: Root visits=6, value=3.36
    Iteration 7: Root visits=7, value=3.92
    Iteration 8: Root visits=8, value=4.48
    Iteration 9: Root visits=9, value=5.04
    Iteration 10: Root visits=10, value=5.60
    Iteration 11: Root visits=11, value=6.16
    Iteration 12: Root visits=12, value=6.72
    Iteration 13: Root visits=13, value=7.28
    Iteration 14: Root visits=14, value=7.84
    Iteration 15: Root visits=15, value=8.40
    Iteration 16: Root visits=16, value=8.96
    Iteration 17: Root visits=17, value=9.52
    Iteration 18: Root visits=18, value=10.08
    Iteration 19: Root visits=19, value=10.64
    Iteration 20: Root visits=20, value=11.20
    Iteration 21: Root visits=21, value=11.76
    Iteration 22: Root visits=22, value=12.32
    Iteration 23: Root visits=23, value=12.88
    Iteration 24: Root visits=24, value=13.44
    Iteration 25: Root visits=25, value=14.00
    Iteration 26: Root visits=26, value=14.56
    Iteration 27: Root visits=27, value=15.12
    Iteration 28: Root visits=28, value=15.68
    Iteration 29: Root visits=29, value=16.24
    Iteration 30: Root visits=30, value=16.80
    Iteration 31: Root visits=31, value=17.36
    Iteration 32: Root visits=32, value=17.92
    Iteration 33: Root visits=33, value=18.48
    Iteration 34: Root visits=34, value=19.04
    Iteration 35: Root visits=35, value=19.60
    Iteration 36: Root visits=36, value=20.16
    Iteration 37: Root visits=37, value=20.72
    Iteration 38: Root visits=38, value=21.28
    Iteration 39: Root visits=39, value=21.84
    Iteration 40: Root visits=40, value=22.40
    Iteration 41: Root visits=41, value=22.96
    Iteration 42: Root visits=42, value=23.52
    Iteration 43: Root visits=43, value=24.08
    Iteration 44: Root visits=44, value=24.64
    Iteration 45: Root visits=45, value=25.20
    Iteration 46: Root visits=46, value=25.76
    Iteration 47: Root visits=47, value=26.32
    Iteration 48: Root visits=48, value=26.88
    Iteration 49: Root visits=49, value=27.44
    Iteration 50: Root visits=50, value=28.00
    [DEBUG] 根节点属性: visits=50, value=27.999999999999975
    树结构已保存至 tot_structure_20250410_133415.pdf
    可视化结果已保存至 equation_visualization.png
    ========================================
    最优推导路径:
    1. Depth 1:
    1. 方程2 => x = 2y - 6
    2. 代入方程1得 4y - 12 + 3y = 16
    3. 解得 y = 4.00 => x = 2.00
    
    ========================================
    最终答案:x=2.00,y=4.00
    ========================================

    完整代码实现请见我的github仓库:naidezhujimo/Tree-shaped-chain-of-thought

    评论
    添加红包

    请填写红包祝福语或标题

    红包个数最小为10个

    红包金额最低5元

    当前余额3.43前往充值 >
    需支付:10.00
    成就一亿技术人!
    领取后你会自动成为博主和红包主的粉丝 规则
    hope_wisdom
    发出的红包
    实付
    使用余额支付
    点击重新获取
    扫码支付
    钱包余额 0

    抵扣说明:

    1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
    2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

    余额充值