Mixture‑of‑Recursions让模型对简单 token 少想、对复杂 token 多想
全文核心
MoR = 同一权重循环复用(省参数)+ token 级动态深度(省算力)+ 精打细算 KV(省显存)。模型学会“难题多想、简单题少想”。
1 为什么经典 Transformer 会浪费?
经典 Transformer 从 2017 年诞生至今架构几乎没变:同样的 24 层(或 32、48 层)串行堆叠,每个 token 必须完整穿过所有层。这就像一条 24 站的装配线,不论零件多简单,都强制走完整流程,浪费主要体现在三方面:
| 浪费点 | 详细原因 | 直观后果 |
|---|---|---|
| 算力冗余 | 复杂度固定 ≈ 层数 × 序列长度。简单 token 在后期层几乎不再增益,却仍消耗 FLOPs。 | A100 上 2048 token 单次推理≈ 0.8 s;如能提前退出理想可降到 <0.4 s。 |
| 显存暴涨 | 每层都需存 Key‑Value (KV) 对。d=1280、L=2048 时单层 KV≈16 MB,24 层≈384 MB。 | 对话长一点就 OOM,即便 40 GB A100 也会爆显存。 |
| 延迟不均 | 所有 token 同速 → 必须等最慢的长句跑完 24 层。 | 响应时间抖动,GPU 利用率低。 |
小结:固定深度 = 多算、全存、慢响应。MoR 要做的就是“让每个 token 拿到刚刚够用的计算预算”。
2 MoR 的核心思路
MoR 由 递归块 Recursion Block、Router 和 选择性 KV 缓存 组成,形成“按需深度 × 参数共享 × 显存精简”闭环。
2.1 递归块 Recursion Block
| 关键词 | 设计逻辑 | 直接收益 |
|---|---|---|
| 多层打包 | 把 4–6 层合成函数 fθ,一次定义反复用 | 权重只存一份,参数 ↓4–6× |
| 循环调用 | 每个 token 最多跑 Dmax 圈: h→fθ(h)→… | 复杂词能深入,简单词早退 |
| Middle‑Cycle 共享 | 只共享中间几圈,首尾层保持独立 | 验证困惑度最低、收敛快 |
数学一眼看懂 :
h_i(0) = 输入嵌入
h_i(d) = fθ( h_i(d‑1) )
d = 1 … Dmax
2.2 Router — 给 token 发“深度配额”
Router 是一层或两层 MLP,参数量<0.1 %。它读首圈隐藏状态,输出概率向量 p(d)。
| 路由模式 | 决策时机 | 典型场景 | 稳定训练秘笈 |
|---|---|---|---|
| Token‑choice | 开局一次性给出 d | 在线对话、低延迟 | Balancing Loss + 温度退火 |
| Expert‑choice | 每圈重新挑 top‑k % 难词 | 离线大批量推理 | 路由辅助损 + Gumbel‑Softmax |
常见坑:Router 过热→全部浅层;解决:温度逐步降 & 熵正则。
2.3 选择性 KV 缓存
| 技巧 | 峰值显存节省 | 典型场景 | 性能影响 |
|---|---|---|---|
| 递归级缓存 | 理论 (Nr+1)/(2Nr) | 普通推理 | 无精度损 |
| Recursive Sharing | ≈50 % | 64k+ 长上下文 | PPL 升 0.1,可忽略 |
实现关键:已经“毕业”的 token 不再占用 KV 内存;在 Flash‑Attention 里先看一张“活跃名单”,名单外的 token 直接跳过、完全不算。
3 递归块:参数极致复用
- 参数对比:以 360 M baseline(24×15 M/层)为例,改为 4 层/块 ×4 圈,唯一权重≈ 4×15 M=60 M,参数直接砍 >70 %。
- 梯度截断:最长路径 Dmax 圈;后向传播时只回传到 d≤token 实际深度,梯度爆炸/消失问题明显减弱。
- 多尺度共享:共享同一组权重迫使 fθ 同时适应浅语法与深语义两种特征,实测 perplexity 比完全独立层还低约 0.3。
4 Router:思考预算分配
- 输出维度:Dmax 通常设 3‑4;更大深度收益递减。
- 平衡损:
L_balance = (mean_depth − target)^2,把平均深度压到设定预算,例如 1.6 圈。 - 辅助路由损:(Expert‑choice 专用)soft label 让路由提前“猜”下一圈是否仍被选中,提高稳定性。
- 示例温度计划:训练前 10 % step τ=0.75 → 中期线性降到 0.5 → 收尾固定。
小贴士:若训练中观察到深度塌缩(全部 d=1),先提高 τ 或增大 Balancing Loss 权重再继续。
5 KV 缓存:显存精打细算
- 递归级缓存公式:假设 batch 内 token 平均递归圈数 r̄,显存≈(r̄/Dmax)×原始 100 %。若平均只跑 1.5 圈,24 层模型显存即降到 37 %。
- KV Sharing 细节:首圈 KV 在显存中维持,后续圈对同一序列复用查询。需保证块内投影矩阵 weight tying,否则维度不一致。
- Prefill 优势:长上下文生成阶段,prefill 占用高峰由 O(layer)→O(active_layer)。MoR‑3 模型在 1M token 上下文可省约 14 GB 显存。
6 优缺点速览
| 优势 | 说明 |
|---|---|
| 真正三效合一 | 同时省参数、算力、显存 |
| 长上下文友好 | 缓存减半,窗口可达百万级 |
| 边缘端可用 | 8 GB NPU 跑百兆模型,延迟‑30 % |
| 迁移成本低 | 在原权重上继续训 3‑5 B token 即可 |
| 限制 | 对策 |
|---|---|
| 小模型收益有限 | 建议 Dmax ≤3 |
| Router 不稳 | 温度 & Balancing Loss 调参 |
| KV 逻辑需改内核 | 参考官方 Flash‑Attn 分支 |
7 未来可探索
- 层内稀疏 × MoR:MoE 或稀疏注意力塞进递归块,双重稀疏。
- 连续深度预算:Router 输出实值 budget,实现可微动态计算。
- 4‑bit KV + MoR:极端压显存,让手机端跑 Llama‑3。
- 多模态 MoR‑ViT:图像已验证,视频/语音尚在路上。
参考文献
Mixture-of-Recursions: Parameter Sharing for Efficient Token-level Adaptive Computation
Making Transformers More Efficient with MoR
DeepMind’s Mixture‑of‑Recursions could power smaller LLMs
69

被折叠的 条评论
为什么被折叠?



