GPT-1
GPT-1 论文
传统 NLP 需要大量人工标注数据,且词嵌入技术(Word2Vec)仅学习词级表示,无法捕捉句子之间逻辑关系。且对于各个任务(机器翻译、语言建模等)需独立设计模型且迁移复杂。
GPT-1的思想是先通过在无标签的数据上学习一个生成式的语言模型,然后再根据特定热任务进行微调。(自回归语言建模)
无监督预训练
基于语言模型进行训练,给定一个无标签的序列
U
=
{
u
1
,
u
2
,
…
…
,
u
n
}
\mathcal{U}=\{u_1,u_2,……,u_n\}
U={u1,u2,……,un},语言模型的目标是最大化这个似然值:
L
1
(
U
)
=
∑
i
log
P
(
u
i
∣
u
i
−
k
,
…
…
,
u
i
−
1
;
Θ
)
(
1
)
L_1(\mathcal{U})=\sum_i\log P(u_i|u_{i-k},……,u_{i-1};\Theta) \qquad \qquad (1)
L1(U)=i∑logP(ui∣ui−k,……,ui−1;Θ)(1)
其中 k 是滑动窗口大小,P 是条件概率,
Θ
\Theta
Θ 是模型参数。
在 GPT-1 中,使用了 12 个 Transformer 块作为解码器,每个 Transformer 块是一个掩码多头自注意力,通过全连接得到输出的概率分布。(Decoder-only —— Next Token Prediction)
h
0
=
U
W
e
+
W
p
(
2
)
h
l
=
transformer block
(
h
l
−
1
)
∀
i
∈
[
1
,
n
]
(
3
)
P
(
u
)
=
s
o
f
t
m
a
x
(
h
n
W
e
T
)
(
4
)
\begin{aligned} h_{0} & =UW_e+W_p &\qquad \qquad (2)\\ h_{l} & =\text{transformer block}(h_{l-1})\forall i\in[1,n] &\qquad \qquad (3)\\ P(u) & =\mathrm{softmax}(h_nW_e^T) &\qquad \qquad (4) \end{aligned}
h0hlP(u)=UWe+Wp=transformer block(hl−1)∀i∈[1,n]=softmax(hnWeT)(2)(3)(4)
其中
U
=
(
u
−
k
,
…
…
,
u
−
1
)
U=(u_{-k},……,u_{-1})
U=(u−k,……,u−1) 是当前时间片的上下文 token,n 是层数,
W
e
W_e
We 是词嵌入矩阵,
W
p
W_p
Wp 是位置嵌入矩阵。
有监督微调
对于有标签的数据集
C
\mathcal{C}
C,每个实例有 m 个输入 token:
{
x
1
,
…
…
,
x
m
}
\{x^1,……,x^m\}
{x1,……,xm}。首先将这些 token 输入到训练好的预训练模型中,得到特征向量
h
l
m
h_l^m
hlm,通过一个全连接层得到预测结果 y:
P
(
y
∣
x
1
,
…
…
,
x
m
)
=
s
o
f
t
m
a
x
(
h
l
m
W
y
)
(
5
)
P(y|x^1,……,x^m)=softmax(h_l^mW_y) \qquad \qquad (5)
P(y∣x1,……,xm)=softmax(hlmWy)(5)
其中
W
y
W_y
Wy 是全连接层的参数,有监督的目标是最大化式(5):
L
2
(
C
)
=
∑
x
,
y
log
P
(
y
∣
x
1
,
…
…
,
x
m
)
(
6
)
L_2(\mathcal{C})=\sum_{x,y}\log P(y|x^1,……,x^m) \qquad \qquad (6)
L2(C)=x,y∑logP(y∣x1,……,xm)(6)
在微调过程中将语言建模作为辅助目标有助于学习,原因是:1)改善监督模型的泛化能力;2)加速收敛。由此采用:(仅训练
W
y
W_y
Wy 和 token 的嵌入)
L
3
(
C
)
=
L
2
(
C
)
+
λ
L
1
(
C
)
L_3(\mathcal{C})=L_2(\mathcal{C})+\lambda L_1(\mathcal{C})
L3(C)=L2(C)+λL1(C)
输入变换
- 分类任务:将起始和终止 token 加入到原始序列两端,输入 transformer 中得到特征向量,最后经过一个全连接得到预测的概率分布;
- 自然语言推理:将前提(premise)和假设(hypothesis)通过分隔符(Delimiter)隔开,两端加上起始和终止 token。再依次通过 transformer 和全连接得到预测结果;
- 语义相似度:输入的两个句子,正向和反向各拼接一次,然后分别输入给 transformer,得到的特征向量拼接后再送给全连接得到预测结果;
- 问答和常识推理:将 n 个选项的问题抽象化为 n 个二分类问题,即每个选项分别和内容进行拼接,然后各送入 transformer 和全连接中,最后选择置信度最高的作为预测结果。
采用 BooksCorpus 数据集,包含 7,000 本没有发布的书籍。1)数据集拥有更长的上下文依赖关系,使得模型能学得更长期的依赖关系;2)这些书籍因为没有发布,所以很难在下游数据集上见到,更能验证模型的泛化能力。
GPT-2
验证了通过海量数据和大量参数训练出来的词向量模型有迁移到其它类别任务中而不需要额外的训练。
难点:现有多任务训练方法想要通过大数定律逼近实际分布,所需要的 (dataset, objective) 样本量非常大,人工成本高昂。
学习目标是使用无监督的预训练模型做有监督的任务,相比 GPT-1 使用了更多的网络参数和更大的数据集。
因为文本数据的时序性,一个输出序列可以被表示为一系列条件概率的乘积:
p
(
x
)
=
∏
i
=
1
n
p
(
s
n
∣
s
1
,
.
.
.
,
s
n
−
1
)
p(x)=\prod_{i=1}^np(s_n|s_1,...,s_{n-1})
p(x)=i=1∏np(sn∣s1,...,sn−1)
上式也可表示为:
i
n
p
u
t
=
{
s
1
,
s
2
,
…
…
,
s
n
−
k
−
1
}
o
u
t
p
u
t
=
{
s
n
−
k
,
…
…
,
s
k
}
\begin{aligned} input = \{s_1,s_2,……,s_{n-k-1}\} \qquad output=\{s_{n-k},……,s_k\} \end{aligned}
input={s1,s2,……,sn−k−1}output={sn−k,……,sk}
因此语言模型可以建模为:
p
(
o
u
t
p
u
t
∣
i
n
p
u
t
)
p(output|input)
p(output∣input);对于一个有监督的任务,建模为:
p
(
o
u
t
p
u
t
∣
i
n
p
u
t
,
t
a
s
k
)
p(output|input,task)
p(output∣input,task)。
同时,在 decaNLP 中,他们提出的 MQNA 模型可以将机器翻译,自然语言推理,语义分析,关系提取等10类任务统一建模为一个分类任务,而无需再为每一个子任务单独设计一个模型。
基于上面的思想,作者认为,当一个语言模型的容量足够大时,它就足以覆盖所有的有监督任务,也就是说所有的有监督学习都是无监督语言模型的一个子集。
训练 trick:(减少预训练过程中各层之间的方差变化,使梯度更加稳定)
- 后置层归一化( post-norm )改为前置层归一化( pre-norm );
- 在模型最后一个自注意力层之后,额外增加一个层归一化;
- 调整参数的初始化方式,按残差层个数进行缩放,缩放比例为 1 : n \sqrt{n} n ;
- 输入序列的最大长度从 512 扩充到 1024;
GPT-3
GPT-3 论文
在 GPT-2 的基础上,对 Transformer 各层使用了交替稠密且局部带状稀疏自注意力(类似 Spare Transformer)。
传统 self-attention(dense-attention):每个 token 之间两两计算 attention,复杂度为 O ( n 2 ) O(n^2) O(n2) ;Spare Attention:每个 token 只与其他 token 的一个子集计算 attention,复杂度为 O ( n ∗ log n ) ) O(n*\log n)) O(n∗logn)) 。
GPT-3的体量非常庞大,因此在下游任务中进行 fine-tune 的成本很高。为了解决这个问题,GPT-3 使用了 In-Context Learning 的方式,在不进行梯度更新或 fine-tune 的情况下,直接在上下文中进行学习
InstructGPT
InstructGPT 论文
采用 GPT-3 网络结构,通过指示学习构建训练样本来训练一个反应预测内容效果的奖励模型(RM),最后通过这个奖励模型的打分来指导强化学习模型的训练。
人工评估发现,尽管参数少了100倍,仅包含1.3B参数的InstructGPT模型,仍然优于175B参数的GPT-3模型。同时,基于人类反馈进行微调的InstructGPT模型,其生成结果的真实性更高,有害输出也相应更少;在公开的 NLP 数据集上,它的性能也没有显著的下降。单纯的扩大模型、堆算力效果会有局限。在一些模型做不到的地方,适当加入一些人工标注,效果更好也更划算。
语言模型通过预测下一个词的方式进行训练,其目标函数是最大化给定语言序列的条件概率,而不是==“有帮助且安全地遵循用户的指示”==,因此两者之间并未对齐(the language modeling objective is misaligned)。InstructGPT提出了语言模型的三个目标:
helpful
——帮助用户解决问题honest
——不能伪造信息或误导用户harmless
——不会令人反感,也不会对他人或社会有害
- 根据采集的SFT数据集对GPT-3进行有监督的微调(Supervised FineTune,SFT);(40 个 labeler)
- 收集人工标注的对比数据,训练奖励模型(Reword Model,RM);
- 使用RM作为强化学习的优化目标,利用PPO算法微调SFT模型。
奖励模型(RM)
RM 是将 SFT 训练后的模型的最后的嵌入层去掉得到的模型,输入是 prompt 和 response,输出是奖励值。
对每个 prompt,随机生成 K 个输出,向每个 labeler 成对展示输出结果,用户从中选择更好的结果。训练时,InstructGPT 将每个 prompt 的
C
K
2
C_K^2
CK2 个相应对作为一个 batch。(比传统按样本作为 batch 的方式不容易拟合,每个 prompt 有且只有一次被输入到模型中)。
损失函数如下式,目的是最大化 labeler 更喜欢的响应和不喜欢的之间差值:
l
o
s
s
(
θ
)
=
−
1
(
K
2
)
E
(
x
,
y
w
,
y
l
)
∼
D
[
log
(
σ
(
r
θ
(
x
,
y
w
)
−
r
θ
(
x
,
y
l
)
)
)
]
\mathrm{loss}\left(\theta\right)=-\frac{1}{\binom{K}{2}}E_{(x,y_{w},y_{l})\thicksim D}\left[\log\left(\sigma\left(r_{\theta}\left(x,y_{w}\right)-r_{\theta}\left(x,y_{l}\right)\right)\right)\right]
loss(θ)=−(2K)1E(x,yw,yl)∼D[log(σ(rθ(x,yw)−rθ(x,yl)))]
PPO
- 随着模型更新,强化学习产生的数据和训练模型的数据差异会越来越大,作者在损失函数中加入正则项 β log ( π ϕ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) \beta\log\left(\pi_{\phi}^{\mathrm{RL}}(y\mid x)/\pi^{\mathrm{SFT}}(y\mid x)\right) βlog(πϕRL(y∣x)/πSFT(y∣x)) 确保 PPO 的输出和 SFT 的输出差距不会很大。
- 只用 PPO 模型训练,会导致模型在通用 NLP 任务上性能下降,作者在训练目标中加入了通用语言模型目标
γ
E
x
∼
D
p
r
e
t
r
a
i
n
[
log
(
π
ϕ
R
L
(
x
)
)
]
\gamma E_{x\sim D_{\mathrm{pretrain}}}\left[\log(\pi_{\phi}^{\mathrm{RL}}(x))\right]
γEx∼Dpretrain[log(πϕRL(x))]
o b j e c t i v e ( ϕ ) = E ( x , y ) ∼ D π ϕ R L [ r θ ( x , y ) − β log ( π ϕ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) ] + γ E x ∼ D p r e t r a i n [ log ( π ϕ R L ( x ) ) ] \begin{aligned} \mathrm{objective}\left(\phi\right)= & E_{(x,y)\sim D_{\pi_{\phi}^{\mathrm{RL}}}}\left[r_{\theta}(x,y)-\beta\log\left(\pi_{\phi}^{\mathrm{RL}}(y\mid x)/\pi^{\mathrm{SFT}}(y\mid x)\right)\right]+ \\ & \gamma E_{x\sim D_{\mathrm{pretrain}}}\left[\log(\pi_{\phi}^{\mathrm{RL}}(x))\right] \end{aligned} objective(ϕ)=E(x,y)∼DπϕRL[rθ(x,y)−βlog(πϕRL(y∣x)/πSFT(y∣x))]+γEx∼Dpretrain[log(πϕRL(x))]
尽管通过损失函数的方式去降低模型在通用 NLP 任务上的性能降低,但没有彻底解决该问题。