【论文笔记】LLaDA——突破传统自回归的文本生成新范式

原文链接:https://arxiv.org/abs/2502.09992

在这里插入图片描述

原文摘要:自回归模型(ARMs)被广泛认为是大型语言模型(LLMs)的基石。我们通过引入LLaDA来挑战这一观点,LLaDA是一种在预训练和监督微调(SFT)范式下从头开始训练的扩散模型。LLaDA通过前向数据掩码过程和反向过程对分布进行建模,采用标准Transformer架构预测掩码标记。通过优化似然界限,它为概率推理提供了一种严谨的生成方法。在广泛的基准测试中,LLaDA表现出强大的可扩展性,超越了我们自建的ARM基线。值得注意的是,LLaDA 8B在上下文学习中与LLaMA3 8B等强大的LLMs表现相当,并且在SFT之后的多轮对话等案例研究中表现出优秀的指令遵循能力。此外,LLaDA解决了反向诅咒问题,在反向诗歌补全任务中超越了GPT-4o。我们的研究结果证实了扩散模型作为ARMs的一个可行的替代方案,挑战了上述关键LLM能力与ARMs固有相关的假设。

What is now proved was once only imagined. —William Blake
“现世所有明证,皆曾浮于臆想”

一、核心要点

LLaDA (Large Language Diffusion with mAsking) 模型是一种基于扩散模型 (Diffusion Model) 的大型语言模型,其核心思想借鉴了计算机视觉领域的扩散模型,通过逐步去除掩码来生成文本,而非传统自回归模型逐个生成词元的方式。

LLaDA 采用以下方式生成文本:

  • 前向扩散过程:将原始文本逐步 mask,直到所有 tokens 都被 mask,变成一段随机噪声。
  • 反向去噪过程:训练一个 mask predictor 来预测被 mask 的 tokens,逐步恢复原始文本。

这种方式与自回归模型的逐个 token 预测有本质的区别。自回归模型是顺序生成,而 LLaDA 是并行生成。此外,LLaDA 可以同时考虑上下文信息,而自回归模型通常只能利用单向的上下文信息

概念补充:

  • 自回归模型(ARMs):一种生成模型,通过预测序列中的下一个元素来生成整个序列。例如,给定 “The weather is”,自回归模型会预测下一个词是 “nice”,然后继续预测下一个词,直到生成完整的句子。
  • 扩散模型(Diffusion Models):一类生成模型,通过逐步向数据中添加噪声,然后再学习如何逆向去除噪声来生成数据。在图像生成中,扩散模型可以从一张完全是噪声的图片开始,逐步去除噪声,最终生成清晰的图像。
  • 费希尔一致性(Fisher consistency):,指当数据量趋于无穷大时,模型能够收敛到真实分布。
  • 逆转诅咒(Reversal curse):指一个模型在训练时学习到的句子是“A是B”,它不会自动泛化到相反的方向“B是A”。

二、研究背景

大语言模型完全属于生成建模的框架,旨在通过优化模型分布 p θ ( ⋅ ) p_\theta(\cdot) pθ() 来逼近数据分布 p data ( ⋅ ) p_{\text{data}}(\cdot) pdata(),可以通过最大似然估计,或者等价地通过最小化两个分布之间的KL散度来实现:
生成模型原理:
max ⁡ θ E p data ( x ) log ⁡ p θ ( x ) ⇔ min ⁡ θ KL ( p data ( x ) ∥ p θ ( x ) ) \max_\theta \mathbb{E}_{p_{\text{data}}(x)} \log p_\theta(x) \Leftrightarrow \min_\theta \text{KL}(p_{\text{data}}(x) \| p_\theta(x)) θmaxEpdata(x)logpθ(x)θminKL(pdata(x)pθ(x))

主要的方法依赖于自回归建模(ARM)(通常被称为下一个token预测范式)来定义模型分布:
自回归公式:
p θ ( x ) = p θ ( x 1 ) ∏ i = 2 L p θ ( x i ∣ x 1 , … , x i − 1 ) p_\theta(x) = p_\theta(x^1) \prod_{i=2}^L p_\theta(x^i \mid x^1, \ldots, x^{i-1}) pθ(x)=pθ(x1)i=2Lpθ(xix1,,xi1)

其中 x x x 是长度为 L L L 的序列, x i x^i xi 是第 i i i 个token。

文中观点:

  • 可扩展性主要是Transformers、模型和数据规模以及由生成原理引导的Fisher一致性之间相互作用的结果,而非ARM的独特成果
  • 指令遵循和上下文学习的能力似乎是所有结构一致的语言任务上合适的条件生成模型的内在属性,而非ARM的专属优势
  • 自回归特性固有的局限性
    • 逐个词元顺序生成会产生高计算成本
    • 从左到右的建模限制了其在反向推理任务中的有效性

——引入LLaDA(Large Language Diffusion with mAsking,大型语言掩码扩散模型),以研究LLM所展现的能力是否可以从自回归公式之外的生成建模原理中产生,从而解决前面提出的基本问题。

三、主要贡献

  • 可扩展性:LLaDA有效地扩展到10²³ FLOPs的计算预算,在六个任务(如MMLU和GSM8K)上,其结果与在相同数据上训练的自建ARM基线相当。
  • 上下文学习:值得注意的是,LLaDA 8B在几乎所有15个标准的零/少样本学习任务上都超过了LLaMA2 7B,同时与LLaMA3 8B表现相当。
  • 指令遵循:LLaDA在SFT后显著增强了遵循指令的能力,如在多轮对话等案例研究中所示。
  • 反向推理:LLaDA有效地打破了反转诅咒,在前向和反向任务中表现出一致的性能。值得注意的是,它在反向诗歌补全任务中优于GPT-4o。

四、自回归建模范式介绍

自回归建模(Autoregressive Modeling,简称ARM)是一种生成模型,它通过逐步预测序列中的下一个元素来生成整个序列。这种范式在自然语言处理、时间序列分析、音乐生成等领域中非常流行。

4.1 基本概念

在自回归建模中,模型的目标是学习一个序列数据的概率分布 p ( x ) p(x) p(x),其中 x x x 是一个序列。序列可以是文本、时间序列数据、音频信号等。自回归模型通过将序列分解为条件概率的乘积来近似这个分布:

p ( x ) = p ( x 1 ) ∏ i = 2 L p ( x i ∣ x 1 , … , x i − 1 ) p(x) = p(x^1) \prod_{i=2}^L p(x^i \mid x^1, \ldots, x^{i-1}) p(x)=p(x1)i=2Lp(xix1,,xi1)

这里, x i x^i xi 表示序列中的第 i i i 个元素, L L L 是序列的长度。

4.2 工作流程

  1. 初始化:模型首先学习序列的第一个元素 x 1 x^1 x1 的概率分布 p ( x 1 ) p(x^1) p(x1)
  2. 递归预测:一旦模型确定了序列的第一个元素,它就会使用这个元素来预测第二个元素 x 2 x^2 x2 的条件概率分布 p ( x 2 ∣ x 1 ) p(x^2 \mid x^1) p(x2x1)。这个过程会递归地进行,每一步都依赖于之前预测的元素。
  3. 序列生成:在生成新的序列时,模型从第一个元素开始,逐步预测下一个元素,直到生成完整的序列。

4.3 优势与局限

优势:

  • 灵活性:自回归模型可以处理任意长度的序列,并且可以很容易地扩展到新的数据。
  • 解释性:由于模型是逐步生成序列的,每一步的预测都可以被解释,这有助于理解模型的行为。
  • 适用性:自回归模型适用于多种类型的序列数据,包括文本、时间序列、音频等。

局限:

  • 计算效率:由于每一步的预测都依赖于之前所有的预测,这可能导致计算效率较低,尤其是在处理长序列时。
  • 长期依赖:在长序列中,模型可能会遇到长期依赖问题,即模型难以捕捉序列中远距离元素之间的关系。

4.4 常见模型

  • RNN(循环神经网络):特别是LSTM和GRU,它们能够处理序列数据,捕捉时间序列中的依赖关系。
  • Transformer:虽然不是传统意义上的自回归模型,但通过自回归解码器(如GPT系列)也可以实现序列生成。
  • ARIMA:在时间序列分析中,ARIMA模型是一种常见的自回归模型。

4.5 ARM vs LLaDA

特性传统自回归模型LLaDA 扩散模型
生成方式逐个词元生成并行掩码预测
上下文利用因果掩码限制双向注意力机制
解码过程顺序解码过程迭代去噪生成
逆向推理能力易受“逆转诅咒”影响优秀逆向推理能力

自回归模型之所以会出现 reversal curse,是因为它是单向建模的,只能学习到 A -> B 的条件概率,而无法学习到 B -> A 的条件概率。

LLaDA 能够更好地解决 reversal curse,是因为它是双向建模的。在训练过程中,LLaDA 需要预测被 mask 的 tokens,这使得模型能够同时学习到 A -> B 和 B -> A 的关系。

五、LLaDA模型方法介绍

5.1 整体架构

在这里插入图片描述

(a) 预训练。LLaDA在文本上进行训练,文本中的所有词元以相同比例 t ~ U[0, 1] 独立随机掩码。
(b) SFT。仅响应词元可能被掩码。
© 采样。LLaDA模拟一个从t=1(完全掩码)到t=0(未掩码)的扩散过程,在每个步骤中同时预测所有掩码,并采用灵活的重掩码策略。

5.2 概率公式

与 ARM 不同,LLaDA 通过一个前向过程和一个反向过程定义了一个模型分布 p θ ( x 0 ) p_\theta(x_0) pθ(x0)

  • 前向过程逐渐在 x 0 x_0 x0 中独立地掩盖 token,直到在 t = 1 t = 1 t=1 时序列完全被掩盖。对于 t ∈ ( 0 , 1 ) t \in (0, 1) t(0,1),序列 x t x_t xt 被部分掩盖,每个 token 被掩盖的概率为 t t t 或以概率 1 − t 1 - t 1t 保持未掩盖。
  • 反向过程通过迭代预测被掩盖的 token(当 t t t 从 1 移动到 0 时)来恢复数据分布。

LLaDA 的核心是一个掩码预测器,一个参数模型 p θ ( ⋅ ∣ x t ) p_\theta(\cdot | x_t) pθ(xt),它将 x t x_t xt 作为输入并同时预测所有被掩盖的 token(表示为 M)。它使用仅在被掩盖的 token 上计算的交叉熵损失进行训练:

L ( θ ) ≜ − E t , x 0 , x t [ 1 t ∑ i = 1 L 1 [ x t i = M ] log ⁡ p θ ( x 0 i ∣ x t ) ] , \mathcal{L}(\theta) \triangleq -\mathbb{E}_{t, x_0, x_t} \left[ \frac{1}{t} \sum_{i=1}^L \mathbf{1}[x_t^i = M] \log p_\theta(x_0^i | x_t) \right], L(θ)Et,x0,xt[t1i=1L1[xti=M]logpθ(x0ixt)],

其中 x 0 x_0 x0 从训练数据中采样, t t t 从 [0, 1] 中均匀采样, x t x_t xt 从前向过程中采样。指示函数 1 [ ⋅ ] \mathbf{1}[\cdot] 1[] 确保损失仅对被掩盖的 token 计算。

训练完成后,将 t = 0 t = 0 t=0 时的边际分布定义为模型分布 p θ ( x 0 ) p_\theta(x_0) pθ(x0)。模型损失是模型分布负对数似然的上界:

− E p data ( x 0 ) [ log ⁡ p θ ( x 0 ) ] ≤ L ( θ ) -\mathbb{E}_{p_{\text{data}}(x_0)} [\log p_\theta(x_0)] \leq L(\theta) Epdata(x0)[logpθ(x0)]L(θ)

5.3 训练策略

5.3.1 训练阶段

LLaDA 模型的训练过程的两个主要阶段:预训练(Pre-training)和监督微调(SFT)

预训练阶段

LLaDA 模型的核心组件是一个标准的 Transformer 架构(作为掩码预测器,在反向过程中用于预测被掩码的词元)。与传统的自回归 Transformer (如 GPT 系列) 不同,LLaDA 在预训练阶段不使用因果掩码 (causal masking),而是采用双向注意力机制,允许模型在处理每个词元时都能看到整个输入序列的上下文信息 。
在 LLaDA 的 8B 参数版本中,为了在保持参数量的同时,平衡由于不支持 KV 缓存(KV caching)带来的计算开销,模型采用了原生的多头注意力机制(vanilla multi-head attention)并适当减小了前馈网络(FFN)的维度

训练细节:

  • 在包含2.3万亿(T)标记的数据集上进行预训练
  • 数据来自在线语料库,通过手动设计的规则和基于LLM的方法过滤低质量内容。
  • 预训练过程使用固定序列长度4096个标记,总计算成本为13万H800 GPU小时,与相同规模和数据集大小的ARMs相似。
  • 为了增强LLaDA处理可变长度数据的能力,将1%的预训练数据设置为随机长度,该长度从范围 [1, 4096] 中均匀采样。
  • 对于训练序列 x 0 x_0 x0,随机采样 t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1],独立地以相同的概率 t t t 掩蔽每个标记以获得 x t x_t xt,并通过蒙特卡洛方法估计上述概率公式,以进行随机梯度下降训练。

监督微调阶段

  • 使用配对数据( p 0 , r 0 p_0, r_0 p0,r0)增强LLaDA遵循指令的能力,其中 p 0 p_0 p0 是提示, r 0 r_0 r0 表示响应。
  • 需要对条件分布 p θ ( r 0 ∣ p 0 ) p_\theta(r_0 | p_0) pθ(r0p0) 建模,而不是预训练中的 p θ ( x 0 ) p_\theta(x_0) pθ(x0)

实现方法与预训练类似。保持提示不变,并独立地掩蔽响应中的标记,就像对 x 0 x_0 x0 一样。然后,将提示和掩蔽的响应 r t r_t rt 输入预训练的掩码预测器以计算SFT的损失,其中 L ′ L' L 表示动态长度。

− E t , p 0 , r 0 , r t [ 1 t ∑ i = 1 L ′ 1 [ r t i = M ] log ⁡ p θ ( r 0 i ∣ p 0 , r t ) ] -\mathbb{E}_{t, p_0, r_0, r_t} \left[ \frac{1}{t} \sum_{i=1}^{L'} \mathbb{1}[r_t^i = M] \log p_\theta(r_0^i | p_0, r_t) \right] Et,p0,r0,rt t1i=1L1[rti=M]logpθ(r0ip0,rt)

这种方法与预训练完全兼容

  • p 0 p_0 p0 r 0 r_0 r0 的连接可以被视为干净的预训练数据 x 0 x_0 x0
  • p 0 p_0 p0 r t r_t rt 的连接则作为掩码版本 x t x_t xt
  • 该过程与预训练相同,唯一的区别是所有掩码标记恰好出现在 r 0 r_0 r0 部分。

5.3.2 优化器选择

LLaDA 模型在训练过程中采用了 AdamW 优化器

AdamW 是 Adam 优化器的一个变种,其主要改进在于将权重衰减与梯度更新解耦,从而更有效地进行正则化并提升模型性能 。Adam 优化器本身结合了动量和自适应学习率的优点,能够为每个参数维护一个一阶矩估计和二阶矩估计,并根据这些估计来调整学习率。这种机制使得 Adam 及其变种在处理稀疏梯度或噪声较多的数据集时表现良好,非常适合大规模语言模型的训练。

5.3.3 学习率调度策略

预训练阶段:采用Warmup-Stable-Decay策略, 包括三个阶段:

  • “热身”(Warmup)阶段,学习率从一个很小的值(甚至为零)逐渐增加到预设的峰值学习率。有助于在训练初期稳定优化过程,避免因初始梯度较大而导致模型参数震荡或不收敛。
  • “稳定”(Stable)阶段,在此阶段学习率保持在峰值学习率不变,模型参数进行充分更新。
  • “衰减”(Decay)阶段,学习率根据某种策略(如线性衰减、指数衰减或按步衰减)逐渐减小,以便在训练后期对模型参数进行更精细的调整,帮助模型收敛到更优的局部最优点。

SFT阶段:采用与预训练阶段类似的方案,训练3个周期。

LLaDA 模型学习率调度策略总结

训练阶段学习率调度策略关键参数/阶段描述
预训练Warmup-Stable-DecayWarmup: 0 到 4 × 10⁻⁴ (2000 次迭代)
Stable: 4 × 10⁻⁴ (1.2T tokens), 后降至 1 × 10⁻⁴ (0.8T tokens)
Decay: 1 × 10⁻⁴ 线性降低到 1 × 10⁻⁵ (0.3T tokens)
监督微调 (SFT)与预训练阶段类似前50次迭代: 0 到 2.5 × 10⁻⁵ ==> 保持恒定 ==> 最后10%次迭代:2.5 × 10⁻⁵ 线性降至 2.5 × 10⁻⁶

5.4 推理

推理步骤:

  1. 从采样开始,给定提示 p 0 p_0 p0,通过从模型分布 p θ ( r 0 ∣ p 0 ) p_\theta(r_0 | p_0) pθ(r0p0) 中采样来离散化反向过程,从完全掩蔽的响应开始
  2. 在从时间 t ∈ ( 0 , 1 ] t \in (0, 1] t(0,1] s ∈ [ 0 , t ) s \in [0, t) s[0,t) 的中间步骤中,将 p 0 p_0 p0 r t r_t rt 输入掩码预测器,并同时预测所有掩码标记。
  3. 随后,将预测标记中的 s t \frac{s}{t} ts 个词元remask,以获得 r s r_s rs,确保反向过程的转换与正向过程一致,以实现准确采样。

采样步数是一个超参数,为LLaDA提供了一个效率与样本质量之间的权衡。
生成长度也视为超参数,指定采样过程开始时完全掩蔽句子的长度。(由于预训练和SFT都使用了可变长度的数据集,因此最终结果对这个长度超参数不敏感)

5.5 不同的重掩码策略

三种remask策略:

  • Random remasking(随机 remask):在每个推理步骤中,随机选择一部分 tokens 进行 mask。这种策略简单直接,但可能不够高效。
  • Low-confidence remasking(低置信度 remask):在每个推理步骤中,选择模型预测置信度最低的 tokens 进行 mask。这种策略可以集中精力预测模型不确定的 tokens,提高生成质量。
  • Semi-autoregressive remasking(半自回归 remask)
    • 将文本序列分割成多个块:将需要生成的文本序列分割成多个较短的块,每个块的长度可以根据实际情况进行调整。
    • 块内自回归生成:对于每个块,使用LLaDA模型的反向过程(例如随机重新遮蔽或低置信度重新遮蔽)进行采样,生成该块内的文本。
    • 块间非自回归生成:将生成的块依次连接起来,形成一个完整的文本序列。由于每个块都是独立生成的,因此这种生成方式是非自回归的。

在这里插入图片描述

Semi-Autoregressive Remasking的优势

  • 每个块都是独立生成的,可以并行处理。
  • 块内自回归生成可以充分利用上下文信息,生成的文本更加流畅和连贯
  • 可以应用于各种文本生成任务,例如问答、对话和诗歌创作等。

Semi-Autoregressive Remasking的局限性

  • 块间是非自回归生成的,可能会出现衔接不自然的情况。
  • 块长的选择会影响生成效率和文本质量,需要根据具体任务调整。

在这里插入图片描述

5.6 LLaDA扩散模型框架*

扩散模型框架是 LLaDA 区别于传统自回归语言模型的核心创新。该框架包含两个主要过程:前向掩码过程(forward masking process)和反向去噪恢复过程(reverse denoising process)

在这里插入图片描述

  • 前向过程中,原始文本序列中的一部分词元会被随机地或以特定策略掩码掉,形成一个部分被掩盖的序列。掩码的策略是动态的,即在每个训练步骤中,掩码率 t t t从 0 到 1 之间随机采样的,然后每个词元独立地以概率 t t t 被掩码。
  • 反向过程中,模型(即掩码预测器)的任务是根据这个部分掩码的序列,逐步预测并恢复所有被掩码的词元。这个过程通过多次迭代进行,模型在每一步都会预测所有掩码位置的词元,然后根据预测结果和一定的策略(例如,低置信度重掩码)更新掩码状态,直到所有词元都被恢复或达到预设的迭代次数。这种迭代优化的方式使得 LLaDA 能够以非自回归的方式生成文本,从而在理论上具有更好的并行性和处理双向上下文的能力。

LLaDA采用动态掩码率(random masking ratio):即在每个训练步骤中,掩码词元的比例是随机采样的,这与 BERT 等模型使用固定掩码率的策略不同,这对其作为生成模型以及实现上下文学习的能力至关重要。

六、实验结果

6.1 LLaDA在语言任务上的可扩展性

在这里插入图片描述

研究LLaDA在下游任务上与自构建的ARM基线的可扩展性比较。图表包含六个子图,分别对应MMLU (5-shot), ARC-C (0-shot), CMMLU (5-shot), PIQA (0-shot), GSM8K (4-shot), HumanEval (0-shot)。每个子图的X轴是FLOPs (从10²⁰到10²³),Y轴是对应任务的得分。

如图所示,LLaDA表现出令人印象深刻的可扩展性,其总体趋势与ARM高度竞争。在MMLU和GSM8K等任务中,LLaDA表现出更强的可扩展性。即使在PIQA等性能落后的任务上,LLaDA在更大规模时也缩小了与ARM的差距。

由此可以分析:

  • 在逆转诗歌等任务上LLaDA的优越表现,表明双向建模方案的有效性;
  • 在 PIQA 等任务表现较弱,暗示物理常识推理仍更依赖顺序信息。

6.2 基准测试结果

  • 全面评估LLaDA 8B的上下文学习和指令遵循能力,任务的选择和评估协议与现有研究一致,涵盖了通用任务、数学、代码和中文领域的15个流行基准。
  • 在2.3T词元上预训练后,LLaDA 8B表现出卓越的性能,在几乎所有任务上都超过了LLaMA2 7B,并且总体上与LLaMA3 8B具有竞争力。
  • SFT提高了LLaDA在大多数下游任务上的性能。一些指标(如MMLU)有所下降,推测可能是由于SFT数据质量欠佳所致。
  • 没有进行RL对齐,结果略逊于LLaMA3 8B Instruct,但即使仅通过SFT,LLaDA也展现出令人印象深刻的指令遵循能力。

6.3 反向推理和分析

  • 构建了一个包含496对著名中国古诗句的数据集。给定一首诗中的一句,模型任务是生成下一句(正向)或前一句(反向),无需额外微调。
  • LLaDA有效解决了逆转诅咒,在正向和反向任务中均表现出一致的零样本性能。
  • 相比之下,Qwen 2.5和GPT-40在这两者之间都表现出显著的差距。正向生成的结果证实了这两个ARM都很强大,受益于远大于LLaDA的数据集和计算资源。

在这里插入图片描述

6.4 示例展示

Prompt: Explain what artificial intelligence is.

在这里插入图片描述

七、小结

In the middle of difffculty lies opportunity. —Albert Einstein
“绝境渊底,自有天光。”

LLaDA在可扩展性、上下文学习和指令跟随方面表现出强大的能力,其性能可与现有的强大LLMs相媲美。此外,LLaDA还具有双向建模和增强鲁棒性的独特优势,有效解决了现有LLMs的一些固有限制。尽管前景广阔,但扩散模型的潜力尚未完全发挥。由于计算限制,LLaDA与ARM模型的直接比较受到限制,且LLaDA在推理过程中对超参数较为敏感,尚未与强化学习进行对齐。未来的研究需要进一步扩展LLaDA的规模,探索其在多模态数据处理中的能力,并研究后训练对LLaDA性能的影响。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值