论文链接:https://arxiv.org/abs/2404.19737
论文读书笔记:Better & Faster Large Language Models via Multi-token Prediction
作者:Fabian Gloeckle, Badr Youbi Idrissi, Baptiste Rozière, David Lopez-Paz, Gabriel Synnaeve
主题:通过多-token 预测训练语言模型,以提高样本效率、生成质量和推理速度。
1. 摘要 (Abstract)
- 背景:传统大型语言模型(如 GPT、Llama)采用“下一个 token”预测作为无监督训练目标。
- 问题:单 token 预测容易陷入局部模式,忽略长距离依赖,因此需要大量数据才能达到较高的流畅性。
- 方法:在每个训练位置同时预测接下来的 n n n 个 token。模型采用一个共享的 Transformer 主干和 n n n 个独立输出头进行并行预测。
- 实验结果:
- 在代码生成任务(如 HumanEval、MBPP)上显著提升,例如 13B 模型在 HumanEval 上解决问题数提升 12%(pass@1)。
- 利用额外预测头进行自我推测解码(self-speculative decoding),推理速度可提升至 3 倍。
- 贡献:
- 提出一种零额外训练时间或内存开销的多-token 预测架构。
- 实验证明,该方法在大规模模型上效果显著,尤其在代码生成任务中。
- 利用额外输出头实现推理加速。
2. 引言 (Introduction)
- 动机:
- 人类能够以较少数据学习语言,而当前 LLMs 需要海量数据。
- 单 token 预测(教师强制)容易依赖局部统计,忽略“难”决策和长距离依赖。
- 观点:通过同时预测多个未来 token,模型能够捕捉更长上下文信息及语义结构,从而改善生成质量与推理效率。
- 图示说明:论文 Figure 1 展示了如何利用 4 个独立输出头进行 4-token 预测;推理时仅使用第一个输出头,其他头可用于加速推理。
3. 方法 (Method)
3.1 标准语言模型训练目标
传统语言模型目标函数为最小化交叉熵损失,其中历史上下文用
x
1
:
t
x_{1:t}
x1:t 表示,下一个 token 为
x
t
+
1
x_{t+1}
xt+1。其目标为:
L
1
=
−
∑
t
log
P
θ
(
x
t
+
1
∣
x
1
:
t
)
L_1 = -\sum_t \log P_\theta(x_{t+1} \mid x_{1:t})
L1=−t∑logPθ(xt+1∣x1:t)
3.2 多-token 预测目标
- 目标扩展:在每个位置同时预测接下来
n
n
n 个 token,目标函数为:
L n = − ∑ t log P θ ( x t + n : t + 1 ∣ x 1 : t ) L_n = -\sum_t \log P_\theta(x_{t+n:t+1} \mid x_{1:t}) Ln=−t∑logPθ(xt+n:t+1∣x1:t) - 模型架构:
- 采用 Transformer 主干 f s f_s fs 将上下文 x 1 : t x_{1:t} x1:t 映射到隐藏表示 z 1 : t z_{1:t} z1:t。
- 设置
n
n
n 个独立的输出头
f
h
i
f_{h_i}
fhi(其中
i
=
1
,
…
,
n
i=1,\dots,n
i=1,…,n),分别预测未来第
i
i
i 个 token,最后通过共享的 unembedding 矩阵
f
u
f_u
fu 输出概率分布:
P θ ( x t + i ∣ x 1 : t ) = softmax ( f u ( f h i ( f s ( x 1 : t ) ) ) ) P_\theta(x_{t+i} \mid x_{1:t}) = \operatorname{softmax}(f_u(f_{h_i}(f_s(x_{1:t})))) Pθ(xt+i∣x1:t)=softmax(fu(fhi(fs(x1:t)))) - 特别地, i = 1 i=1 i=1 对应于传统的下一个 token 预测。
3.3 内存高效实现
- 问题描述:
大规模语言模型中,词表大小 V V V 远大于隐藏层维度 d d d,若直接生成形状为 ( n , V ) (n, V) (n,V) 的 logits 会占用大量 GPU 内存。 - 解决方案:
- 在经过共享主干 f s f_s fs 后,对每个输出头 f h i f_{h_i} fhi 顺序执行前向与反向传播,计算完一个 head 后即释放其 logits 内存,只保留 d d d 维梯度。
- 这样峰值内存从 O ( n V + d ) O(nV + d) O(nV+d) 降为 O ( V + d ) O(V + d) O(V+d),且不增加额外运行时间。
- 图示说明:论文 Figure 2 展示了这一顺序计算流程及其内存节省效果。
3.4 推理过程
- 标准推理:
生成时仅使用第一个(下一个 token)预测头,执行自回归生成。 - 自我推测解码:
- 利用额外输出头,采用 blockwise parallel decoding(参见 Stern et al., 2018)进行推理加速。
- 在代码和自然语言生成任务中,能实现 2.7× 至 3× 的推理速度提升。
4. 实验 (Experiments)
论文通过多个实验验证多-token 预测的有效性,主要包括以下方面:
4.1 模型规模与性能提升
- 设置:
在代码数据上训练从 300M 到 13B 参数的模型。 - 评估指标:
在 MBPP 与 HumanEval 基准上使用 pass@1、pass@10、pass@100 进行评估。 - 结果:
- 对于较小模型,多-token 预测与传统方法性能相近甚至略逊;
- 随着模型规模增大,多-token 模型逐步超过基线,表现出显著提升(例如 pass@1 提升数个百分点)。
4.2 推理加速实验
- 方法:
使用自我推测解码,在不同 batch size 下测试推理速度。 - 结果:
- 在代码任务上,4-token 预测模型可获得 3.0× 的推理速度提升;
- 对于字节级模型,在某些设置下可达 6.4× 加速效果。
4.3 字节级训练与全局模式学习
- 实验设计:
使用字节级 token 化,在 314B 字节数据上训练 7B 模型(约 116B tokens)。 - 对比:
单字节预测与多字节(如 8-byte)预测。 - 结果:
- 8-byte 模型在 MBPP 上 pass@1 提升 67%,在 HumanEval 上提升 20%;
- 自我推测解码在字节级任务中同样实现显著加速。
4.4 最优 n n n 的选择
- 消融实验:
测试不同预测窗口 n n n(取值 1, 2, 4, 6, 8)的效果。 - 结论:
- 对于代码任务,通常 n = 4 n=4 n=4 表现最佳,但在部分基准(如 APPS/Intro)上 n = 6 n=6 n=6 略占优势;
- 对于字节级模型,最优预测窗口通常为 8。
4.5 多轮训练(Epochs)
- 观察:
在同一数据上进行多轮训练时,多-token 预测依然保持优势(如 MBPP 上 pass@1 提升 2.4%),说明该方法在充分利用数据方面具有持续优势。
4.6 微调实验
- 任务:
在 CodeContests 数据集上进行微调。 - 设置:
- 比较预训练后继续使用多-token 目标微调与转为传统下一个 token 预测的效果。
- 结果:
- 两种微调方式均优于纯下一个 token 预训练的模型,且传统微调略有优势,这符合预训练—微调的常见范式。
4.7 自然语言任务
- 数据:
在 200B tokens 的自然语言数据上训练 7B 模型,分别采用 next-token、2-token、4-token 预测目标。 - 评估:
- 在标准多项选择题和负对数似然指标上表现相近;
- 在生成任务(如摘要、数学题)上,多-token 模型(尤其是 n = 2 n=2 n=2 模型)表现出一定优势。
5. 合成数据消融实验 (Ablations on Synthetic Data)
5.1 归纳能力 (Induction Capability)
- 实验设计:
- 在儿童故事数据集上,对人物名称进行处理,将名称拆分为两个 token;
- 第一个 token 依赖上下文语义,第二个 token 则作为归纳任务进行预测。
- 结果:
- 对于较小模型(30M 参数及以下),采用 2-token 预测显著提升了归纳能力;
- 当模型参数达到 100M 以上时,优势趋于平稳。
5.2 算法推理 (Algorithmic Reasoning)
- 任务描述:
设计基于多项式算术的任务,在有限域 F 7 [ X ] / ( X 5 ) \mathbb{F}_7[X]/(X^5) F7[X]/(X5) 下进行一元取反、加法、乘法和复合操作。难度通过操作数 m m m(训练时 m ≤ 5 m\leq5 m≤5,测试时可超出此范围)进行调控。 - 结果:
- 多-token 预测显著提升了模型在算法推理任务上的准确率,尤其在出域泛化上,比单纯增大模型参数效果更明显。
6. 为什么多-token 预测有效? (Why Does It Work?)
6.1 关键决策点的强化 (Lookahead Reinforces Choice Points)
- 观点:
- 生成过程中并非所有 token 预测都同等重要,某些 token 是“关键选择点”,决定了后续文本的整体方向。
- 多-token 预测通过同时预测连续多个 token,为那些与后续高度相关且难以预测的 token 赋予了更高的隐式权重,从而促使模型在关键位置做出更准确的决策。
- 直观例子:
论文中给出一个例子,说明某个难预测的转折(如 “5 → A”)及其后续 token 均被赋予更高权重,从而推动模型关注这一“关键”转折。
6.2 信息论论证 (Information-Theoretic Argument)
- 基本思路:
单 token 预测目标关注 H ( X ) H(X) H(X)(下一个 token 的熵),而 2-token 预测关注 H ( X ) + H ( Y ) H(X) + H(Y) H(X)+H(Y)。 - 分解:
- 单 token 情形:
H ( X ) = H ( X ∣ Y ) + I ( X ; Y ) H(X) = H(X\mid Y) + I(X;Y) H(X)=H(X∣Y)+I(X;Y) - 2-token 情形:
H ( X ) + H ( Y ) = H ( X ∣ Y ) + 2 I ( X ; Y ) + H ( Y ∣ X ) H(X) + H(Y) = H(X\mid Y) + 2I(X;Y) + H(Y\mid X) H(X)+H(Y)=H(X∣Y)+2I(X;Y)+H(Y∣X)
- 单 token 情形:
- 分析:
忽略 H ( Y ∣ X ) H(Y\mid X) H(Y∣X) 后,2-token 预测加大了互信息 I ( X ; Y ) I(X;Y) I(X;Y) 的比重,促使模型在训练时更多关注 token 之间的重要依赖关系,从而提高生成连贯性和全局一致性。
7. 相关工作 (Related Work)
- 多任务/多目标预测:
与 Caruana (1997) 提出的多任务学习思想类似,不同任务共享主干网络,而输出头各自独立。 - 其他语言模型训练目标:
Qi et al. (2020) 的多-token 预测方法以及 Tay et al. (2022) 的 span corruption 目标等。 - 推理加速:
类似 Stern et al. (2018) 提出的 blockwise parallel decoding,本方法通过额外输出头实现自我推测解码,从而加速生成。
8. 结论 (Conclusion)
- 核心观点:
- 多-token 预测通过同时预测未来多个 token,有效缩减了教师强制训练与自回归生成之间的分布差异。
- 该方法不仅提高了模型生成质量(尤其在代码生成任务中),同时利用额外预测头实现了推理加速,最高可达 3 倍以上。
- 未来方向:
- 自动选择最优预测窗口 n n n;
- 调整词汇表大小以适应多-token 预测;
- 探索在嵌入空间中操作的辅助预测损失。
9. 总结
多-token 预测为大型语言模型提供了一种简单而高效的改进策略:
- 提升样本利用率:通过同时预测多个 token,模型能够捕捉长距离依赖和全局结构。
- 强化关键决策:为重要选择点赋予更高的隐式权重,从而改善生成连贯性。
- 加速推理:利用额外预测头实现自我推测解码,大幅缩短生成时间。
这种方法在代码生成和算法推理任务中均显示出优异性能,同时为自然语言生成提供了新的改进方向,对未来模型预训练和高效推理具有重要意义。