【论文解读】谷歌的MTP方法《Better & Faster Large Language Models via Multi-token Prediction》

论文链接:https://arxiv.org/abs/2404.19737


论文读书笔记:Better & Faster Large Language Models via Multi-token Prediction

作者:Fabian Gloeckle, Badr Youbi Idrissi, Baptiste Rozière, David Lopez-Paz, Gabriel Synnaeve
主题:通过多-token 预测训练语言模型,以提高样本效率、生成质量和推理速度。


1. 摘要 (Abstract)

  • 背景:传统大型语言模型(如 GPT、Llama)采用“下一个 token”预测作为无监督训练目标。
  • 问题:单 token 预测容易陷入局部模式,忽略长距离依赖,因此需要大量数据才能达到较高的流畅性。
  • 方法:在每个训练位置同时预测接下来的 n n n 个 token。模型采用一个共享的 Transformer 主干和 n n n 个独立输出头进行并行预测。
  • 实验结果
    • 在代码生成任务(如 HumanEval、MBPP)上显著提升,例如 13B 模型在 HumanEval 上解决问题数提升 12%(pass@1)。
    • 利用额外预测头进行自我推测解码(self-speculative decoding),推理速度可提升至 3 倍。
  • 贡献
    1. 提出一种零额外训练时间或内存开销的多-token 预测架构。
    2. 实验证明,该方法在大规模模型上效果显著,尤其在代码生成任务中。
    3. 利用额外输出头实现推理加速。

2. 引言 (Introduction)

  • 动机
    • 人类能够以较少数据学习语言,而当前 LLMs 需要海量数据。
    • 单 token 预测(教师强制)容易依赖局部统计,忽略“难”决策和长距离依赖。
  • 观点:通过同时预测多个未来 token,模型能够捕捉更长上下文信息及语义结构,从而改善生成质量与推理效率。
  • 图示说明:论文 Figure 1 展示了如何利用 4 个独立输出头进行 4-token 预测;推理时仅使用第一个输出头,其他头可用于加速推理。

3. 方法 (Method)

3.1 标准语言模型训练目标

传统语言模型目标函数为最小化交叉熵损失,其中历史上下文用 x 1 : t x_{1:t} x1:t 表示,下一个 token 为 x t + 1 x_{t+1} xt+1。其目标为:
L 1 = − ∑ t log ⁡ P θ ( x t + 1 ∣ x 1 : t ) L_1 = -\sum_t \log P_\theta(x_{t+1} \mid x_{1:t}) L1=tlogPθ(xt+1x1:t)

3.2 多-token 预测目标

在这里插入图片描述

  • 目标扩展:在每个位置同时预测接下来 n n n 个 token,目标函数为:
    L n = − ∑ t log ⁡ P θ ( x t + n : t + 1 ∣ x 1 : t ) L_n = -\sum_t \log P_\theta(x_{t+n:t+1} \mid x_{1:t}) Ln=tlogPθ(xt+n:t+1x1:t)
  • 模型架构
    • 采用 Transformer 主干 f s f_s fs 将上下文 x 1 : t x_{1:t} x1:t 映射到隐藏表示 z 1 : t z_{1:t} z1:t
    • 设置 n n n 个独立的输出头 f h i f_{h_i} fhi(其中 i = 1 , … , n i=1,\dots,n i=1,,n),分别预测未来第 i i i 个 token,最后通过共享的 unembedding 矩阵 f u f_u fu 输出概率分布:
      P θ ( x t + i ∣ x 1 : t ) = softmax ⁡ ( f u ( f h i ( f s ( x 1 : t ) ) ) ) P_\theta(x_{t+i} \mid x_{1:t}) = \operatorname{softmax}(f_u(f_{h_i}(f_s(x_{1:t})))) Pθ(xt+ix1:t)=softmax(fu(fhi(fs(x1:t))))
    • 特别地, i = 1 i=1 i=1 对应于传统的下一个 token 预测。

3.3 内存高效实现

  • 问题描述
    大规模语言模型中,词表大小 V V V 远大于隐藏层维度 d d d,若直接生成形状为 ( n , V ) (n, V) (n,V) 的 logits 会占用大量 GPU 内存。
  • 解决方案
    • 在经过共享主干 f s f_s fs 后,对每个输出头 f h i f_{h_i} fhi 顺序执行前向与反向传播,计算完一个 head 后即释放其 logits 内存,只保留 d d d 维梯度。
    • 这样峰值内存从 O ( n V + d ) O(nV + d) O(nV+d) 降为 O ( V + d ) O(V + d) O(V+d),且不增加额外运行时间。
  • 图示说明:论文 Figure 2 展示了这一顺序计算流程及其内存节省效果。

3.4 推理过程

  • 标准推理
    生成时仅使用第一个(下一个 token)预测头,执行自回归生成。
  • 自我推测解码
    • 利用额外输出头,采用 blockwise parallel decoding(参见 Stern et al., 2018)进行推理加速。
    • 在代码和自然语言生成任务中,能实现 2.7× 至 3× 的推理速度提升。

4. 实验 (Experiments)

论文通过多个实验验证多-token 预测的有效性,主要包括以下方面:

4.1 模型规模与性能提升

  • 设置
    在代码数据上训练从 300M 到 13B 参数的模型。
  • 评估指标
    在 MBPP 与 HumanEval 基准上使用 pass@1、pass@10、pass@100 进行评估。
  • 结果
    • 对于较小模型,多-token 预测与传统方法性能相近甚至略逊;
    • 随着模型规模增大,多-token 模型逐步超过基线,表现出显著提升(例如 pass@1 提升数个百分点)。

4.2 推理加速实验

  • 方法
    使用自我推测解码,在不同 batch size 下测试推理速度。
  • 结果
    • 在代码任务上,4-token 预测模型可获得 3.0× 的推理速度提升;
    • 对于字节级模型,在某些设置下可达 6.4× 加速效果。

4.3 字节级训练与全局模式学习

  • 实验设计
    使用字节级 token 化,在 314B 字节数据上训练 7B 模型(约 116B tokens)。
  • 对比
    单字节预测与多字节(如 8-byte)预测。
  • 结果
    • 8-byte 模型在 MBPP 上 pass@1 提升 67%,在 HumanEval 上提升 20%;
    • 自我推测解码在字节级任务中同样实现显著加速。

4.4 最优 n n n 的选择

  • 消融实验
    测试不同预测窗口 n n n(取值 1, 2, 4, 6, 8)的效果。
  • 结论
    • 对于代码任务,通常 n = 4 n=4 n=4 表现最佳,但在部分基准(如 APPS/Intro)上 n = 6 n=6 n=6 略占优势;
    • 对于字节级模型,最优预测窗口通常为 8。

4.5 多轮训练(Epochs)

  • 观察
    在同一数据上进行多轮训练时,多-token 预测依然保持优势(如 MBPP 上 pass@1 提升 2.4%),说明该方法在充分利用数据方面具有持续优势。

4.6 微调实验

  • 任务
    在 CodeContests 数据集上进行微调。
  • 设置
    • 比较预训练后继续使用多-token 目标微调与转为传统下一个 token 预测的效果。
  • 结果
    • 两种微调方式均优于纯下一个 token 预训练的模型,且传统微调略有优势,这符合预训练—微调的常见范式。

4.7 自然语言任务

  • 数据
    在 200B tokens 的自然语言数据上训练 7B 模型,分别采用 next-token、2-token、4-token 预测目标。
  • 评估
    • 在标准多项选择题和负对数似然指标上表现相近;
    • 在生成任务(如摘要、数学题)上,多-token 模型(尤其是 n = 2 n=2 n=2 模型)表现出一定优势。

5. 合成数据消融实验 (Ablations on Synthetic Data)

5.1 归纳能力 (Induction Capability)

  • 实验设计
    • 在儿童故事数据集上,对人物名称进行处理,将名称拆分为两个 token;
    • 第一个 token 依赖上下文语义,第二个 token 则作为归纳任务进行预测。
  • 结果
    • 对于较小模型(30M 参数及以下),采用 2-token 预测显著提升了归纳能力;
    • 当模型参数达到 100M 以上时,优势趋于平稳。

5.2 算法推理 (Algorithmic Reasoning)

  • 任务描述
    设计基于多项式算术的任务,在有限域 F 7 [ X ] / ( X 5 ) \mathbb{F}_7[X]/(X^5) F7[X]/(X5) 下进行一元取反、加法、乘法和复合操作。难度通过操作数 m m m(训练时 m ≤ 5 m\leq5 m5,测试时可超出此范围)进行调控。
  • 结果
    • 多-token 预测显著提升了模型在算法推理任务上的准确率,尤其在出域泛化上,比单纯增大模型参数效果更明显。

6. 为什么多-token 预测有效? (Why Does It Work?)

6.1 关键决策点的强化 (Lookahead Reinforces Choice Points)

  • 观点
    • 生成过程中并非所有 token 预测都同等重要,某些 token 是“关键选择点”,决定了后续文本的整体方向。
    • 多-token 预测通过同时预测连续多个 token,为那些与后续高度相关且难以预测的 token 赋予了更高的隐式权重,从而促使模型在关键位置做出更准确的决策。
  • 直观例子
    论文中给出一个例子,说明某个难预测的转折(如 “5 → A”)及其后续 token 均被赋予更高权重,从而推动模型关注这一“关键”转折。

6.2 信息论论证 (Information-Theoretic Argument)

  • 基本思路
    单 token 预测目标关注 H ( X ) H(X) H(X)(下一个 token 的熵),而 2-token 预测关注 H ( X ) + H ( Y ) H(X) + H(Y) H(X)+H(Y)
  • 分解
    • 单 token 情形:
      H ( X ) = H ( X ∣ Y ) + I ( X ; Y ) H(X) = H(X\mid Y) + I(X;Y) H(X)=H(XY)+I(X;Y)
    • 2-token 情形:
      H ( X ) + H ( Y ) = H ( X ∣ Y ) + 2 I ( X ; Y ) + H ( Y ∣ X ) H(X) + H(Y) = H(X\mid Y) + 2I(X;Y) + H(Y\mid X) H(X)+H(Y)=H(XY)+2I(X;Y)+H(YX)
  • 分析
    忽略 H ( Y ∣ X ) H(Y\mid X) H(YX) 后,2-token 预测加大了互信息 I ( X ; Y ) I(X;Y) I(X;Y) 的比重,促使模型在训练时更多关注 token 之间的重要依赖关系,从而提高生成连贯性和全局一致性。

7. 相关工作 (Related Work)

  • 多任务/多目标预测
    与 Caruana (1997) 提出的多任务学习思想类似,不同任务共享主干网络,而输出头各自独立。
  • 其他语言模型训练目标
    Qi et al. (2020) 的多-token 预测方法以及 Tay et al. (2022) 的 span corruption 目标等。
  • 推理加速
    类似 Stern et al. (2018) 提出的 blockwise parallel decoding,本方法通过额外输出头实现自我推测解码,从而加速生成。

8. 结论 (Conclusion)

  • 核心观点
    • 多-token 预测通过同时预测未来多个 token,有效缩减了教师强制训练与自回归生成之间的分布差异。
    • 该方法不仅提高了模型生成质量(尤其在代码生成任务中),同时利用额外预测头实现了推理加速,最高可达 3 倍以上。
  • 未来方向
    • 自动选择最优预测窗口 n n n
    • 调整词汇表大小以适应多-token 预测;
    • 探索在嵌入空间中操作的辅助预测损失。

9. 总结

多-token 预测为大型语言模型提供了一种简单而高效的改进策略:

  • 提升样本利用率:通过同时预测多个 token,模型能够捕捉长距离依赖和全局结构。
  • 强化关键决策:为重要选择点赋予更高的隐式权重,从而改善生成连贯性。
  • 加速推理:利用额外预测头实现自我推测解码,大幅缩短生成时间。

这种方法在代码生成和算法推理任务中均显示出优异性能,同时为自然语言生成提供了新的改进方向,对未来模型预训练和高效推理具有重要意义。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值