在开始前请确保您有一定的LLM基础和强化学习基础😊
如果您没有RL基础我推荐David Sliver的讲座(前三集即可)RL Course by David Silver - Lecture 1: Introduction to Reinforcement Learning - YouTube
如果您还不了解CoT,请看我的这篇文章:从理论到实践:思维链(CoT)提示-优快云博客
如果您还不了解多路径生成与自洽性,请看我的这篇文章:从理论到实践:CoT的多路径生成与自洽性-优快云博客
叠甲:我对文章中提到的所有算法的数学解析只是片面的,深入研究会在不久的将来发布(也许吧😔),敬请期待😊(欢迎各位大佬指出错误😊)
数学视角
树形结构的测度论定义
ToT的核心是树结构 ,其中:
- 节点集
:每个节点表示一个中间思考状态
,
为状态空间
- 边集
:边表示状态转移动作
,
为动作空间
状态空间:
- 将状态空间
视为可测空间
,其中
是状态集合的
-代数
- 每个节点
对应一个状态测度
,表示该状态的概率分布
- 转移核:边
对应转移概率
,满足:
路径空间的积分表示:
- 路径
的概率可表示为转移核的乘积:
- 目标函数
的期望值为:
其中积分在路径空间 上进行
随机过程与马尔可夫性
1.马尔可夫链建模
- ToT的搜索过程可建模为非其次马尔可夫链,状态转移满足:
- Chapman-Kolmogorov方程:多步转移概率满足:
2.收敛性分析
- 若树满足不可约性(所有状态互通)和非周期性,则存在唯一平稳分布
,使得:
- Perron-Frobenius定理:转移矩阵的最大特征值为1,对应平稳分布
搜索算法
ToT的搜索过程可形式化为优化问题:
其中 为所有可能路径,
为路径效用函数
1.启发式搜索(A*算法)
- 定义评估函数
,其中:
:从根节点到
的实际代价
:启发式函数,估计
到目标状态的代价
- 可纳性条件:若
(真实最小代价),则 A* 算法保证找到最优解
2.蒙特卡洛树搜索(MCTS)
- 节点选择基于Upper Confidence Bound(UCB):
其中 为累积奖励,
为访问次数,
为探索参数
概率推理与贝叶斯方法
ToT中路径选择通常涉及不确定性,可通过概率模型量化:
1.贝叶斯更新:
- 节点
的先验概率
通过似然函数
更新为后验概率:
其中 为新观测数据
2.信息熵决策:
- 选择最大化信息增益的路径:
其中 为状态空间的熵,
为条件熵
组合优化与剪枝策略
1.动态规划
- 利用最优子结构性质,递归定义节点价值:
其中 为即时奖励,
为折扣因子
2.Alpha-Beta剪枝
- 通过上下界
和
剪除无效子树,最理想状态下将复杂度从
降至
收敛性与复杂度分析
1.收敛性:若搜索空间有限且满足完备性条件(所有解均可达),则ToT在有限步内收敛
2.复杂性:
- 时间:
(无剪枝),
(启发式剪枝)
- 空间:
(深度优先),
(广度优先)
扩展模型:随机过程与博弈论
1.马尔可夫决策过程(MDP):
- 定义五元组
,其中
为状态转移概率
- 最优策略
满足Bellman方程:
2.多人树搜索的微分博弈模型:
- 定义玩家集合
,每个玩家策略
- 状态方程:
- 玩家
的目标泛函:
代码实现
ToT树结构
class ToTNode:
def __init__(self, state, parent=None, depth=0):
self.state = state # 当前状态字符串
self.parent = parent # 父节点引用
self.children = [] # 子节点列表
self.value = 0.0 # 累计评估值
self.visits = 0 # 访问次数
self.depth = depth # 当前深度
self.is_terminal = False # 是否为终止节点
def ucb_score(self, exploration=1.414):
"""计算UCB选择分数"""
if self.visits == 0:
return float('inf')
return (self.value / self.visits) + exploration * math.sqrt(math.log(self.parent.visits) / self.visits)
def to_dict(self):
"""序列化节点信息"""
return {
"state": self.state[:50]+"..." if len(self.state)>50 else self.state,
"depth": self.depth,
"value": round(self.value,2),
"visits": self.visits,
"is_terminal": self.is_terminal
}
扩展树节点
def expand_node(node, beam_width=2):
"""扩展单个节点"""
candidates = []
for method in ["algebra", "matrix", "default"]:
# 生成提示
prompt = generate_tot_prompt(
node.state.split("<solution>")[-1], # 仅传递推导步骤部分
method,
node.depth+1
)
# 模型生成
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=400,
temperature=max(0.5, 1.0 - 0.05*node.depth), # 深度越大生成越稳定
top_p=0.9,
repetition_penalty=1.3, # 增强重复惩罚
num_beams=7, # 增加束搜索宽度
no_repeat_ngram_size=3, # 添加N-gram重复惩罚
early_stopping=True,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
new_state = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 创建子节点
child = ToTNode(
state=new_state,
parent=node,
depth=node.depth + 1
)
child.is_terminal = is_terminal_state(new_state) is not None
candidates.append(child)
# 评估并选择最优分支
evaluated = [(c, evaluate_node(c)) for c in candidates]
evaluated.sort(key=lambda x: x[1], reverse=True)
node.children = [c for c, _ in evaluated[:beam_width]]
return node
蒙特卡洛树搜索
def monte_carlo_tree_search(root, iterations=100):
"""蒙特卡洛树搜索主循环"""
for i in range(iterations):
# 选择阶段
node = root
while node.children:
node = max(node.children, key=lambda x: x.ucb_score())
# 扩展阶段
if not node.is_terminal:
expand_node(node)
# 模拟阶段(此处简化使用评估值代替随机模拟)
reward = evaluate_node(node)
# 回溯更新
while node is not None:
node.visits += 1
node.value += reward / (node.depth + 1) # 深度归一化
node = node.parent
# 调试输出
print(f"Iteration {i+1}: Root visits={root.visits}, value={root.value:.2f}")
return root
结果分析
测试用例: <final_answer>x=2,y=4</final_a...
提取结果: x=2,y=4 | 预期: True | 状态: 成功
------------------------------------------------------------
测试用例: Final_Answer: X=2.5, Y=3...
提取结果: x=2.5,y=3 | 预期: True | 状态: 成功
------------------------------------------------------------
测试用例: 答案:x=3, y=2.8...
提取结果: x=3,y=2.8 | 预期: True | 状态: 成功
------------------------------------------------------------
测试用例: x=值,y=值...
提取结果: None | 预期: False | 状态: 成功
------------------------------------------------------------
测试用例: x=, y=4...
提取结果: None | 预期: False | 状态: 成功
------------------------------------------------------------
测试用例: 无效内容...
提取结果: None | 预期: False | 状态: 成功
------------------------------------------------------------
============================================================
处理问题:解方程组:
方程1: 2x + 3y = 16
方程2: x - 2y = -6
[DEBUG] 首次扩展后的子节点:
子节点1:{'state': "请严格遵循以下要求:\n1. 必须使用具体数值,禁止使用'值'等占位符\n2. 小数保留两位,例如:x=...", 'depth': 1, 'value': 0.0, 'visits': 0, 'is_terminal': True}
子节点2:{'state': "请严格遵循以下要求:\n1. 必须使用具体数值,禁止使用'值'等占位符\n2. 小数保留两位,例如:x=...", 'depth': 1, 'value': 0.0, 'visits': 0, 'is_terminal': True}
子节点3:{'state': "请严格遵循以下要求:\n1. 必须使用具体数值,禁止使用'值'等占位符\n2. 小数保留两位,例如:x=...", 'depth': 1, 'value': 0.0, 'visits': 0, 'is_terminal': True}
Iteration 1: Root visits=1, value=0.56
Iteration 2: Root visits=2, value=1.12
Iteration 3: Root visits=3, value=1.68
Iteration 4: Root visits=4, value=2.24
Iteration 5: Root visits=5, value=2.80
Iteration 6: Root visits=6, value=3.36
Iteration 7: Root visits=7, value=3.92
Iteration 8: Root visits=8, value=4.48
Iteration 9: Root visits=9, value=5.04
Iteration 10: Root visits=10, value=5.60
Iteration 11: Root visits=11, value=6.16
Iteration 12: Root visits=12, value=6.72
Iteration 13: Root visits=13, value=7.28
Iteration 14: Root visits=14, value=7.84
Iteration 15: Root visits=15, value=8.40
Iteration 16: Root visits=16, value=8.96
Iteration 17: Root visits=17, value=9.52
Iteration 18: Root visits=18, value=10.08
Iteration 19: Root visits=19, value=10.64
Iteration 20: Root visits=20, value=11.20
Iteration 21: Root visits=21, value=11.76
Iteration 22: Root visits=22, value=12.32
Iteration 23: Root visits=23, value=12.88
Iteration 24: Root visits=24, value=13.44
Iteration 25: Root visits=25, value=14.00
Iteration 26: Root visits=26, value=14.56
Iteration 27: Root visits=27, value=15.12
Iteration 28: Root visits=28, value=15.68
Iteration 29: Root visits=29, value=16.24
Iteration 30: Root visits=30, value=16.80
Iteration 31: Root visits=31, value=17.36
Iteration 32: Root visits=32, value=17.92
Iteration 33: Root visits=33, value=18.48
Iteration 34: Root visits=34, value=19.04
Iteration 35: Root visits=35, value=19.60
Iteration 36: Root visits=36, value=20.16
Iteration 37: Root visits=37, value=20.72
Iteration 38: Root visits=38, value=21.28
Iteration 39: Root visits=39, value=21.84
Iteration 40: Root visits=40, value=22.40
Iteration 41: Root visits=41, value=22.96
Iteration 42: Root visits=42, value=23.52
Iteration 43: Root visits=43, value=24.08
Iteration 44: Root visits=44, value=24.64
Iteration 45: Root visits=45, value=25.20
Iteration 46: Root visits=46, value=25.76
Iteration 47: Root visits=47, value=26.32
Iteration 48: Root visits=48, value=26.88
Iteration 49: Root visits=49, value=27.44
Iteration 50: Root visits=50, value=28.00
[DEBUG] 根节点属性: visits=50, value=27.999999999999975
树结构已保存至 tot_structure_20250410_133415.pdf
可视化结果已保存至 equation_visualization.png
========================================
最优推导路径:
1. Depth 1:
1. 方程2 => x = 2y - 6
2. 代入方程1得 4y - 12 + 3y = 16
3. 解得 y = 4.00 => x = 2.00
========================================
最终答案:x=2.00,y=4.00
========================================
完整代码实现请见我的github仓库:naidezhujimo/Tree-shaped-chain-of-thought