在 Transformer 的每一层(无论是 编码器 还是 解码器),都会使用 层归一化(Layer Normalization, LN)+ 残差连接(Residual Connection),以加速训练、稳定梯度流动,并防止模型退化。
1. 为什么需要层归一化 + 残差连接?
Transformer 结构深,包含多个编码器/解码器层。如果直接堆叠层,可能会导致:
- 梯度消失/爆炸(深层网络的常见问题)
- 训练不稳定
- 学习速度慢
使用:
- 残差连接:确保梯度能够回传到前面层,缓解梯度消失问题。
- 层归一化:稳定数据分布,提高训练速度,使训练更加稳定。
2. 公式表示
假设输入 X 经过某个子层(Multi-Head Attention 或 FFN)后,输出为 SubLayer(X),则:
- 残差连接(Residual Connection):
,确保信息可以回传,缓解梯度消失。
- 层归一化(Layer Normalization, LN):对
进行归一化,防止内部数据分布不稳定。
3. 代码实现
(1)LayerNorm 实现
在 PyTorch 中,nn.LayerNorm
用于执行 层归一化:
import torch
import torch.nn as nn
class LayerNormalization(nn.Module):
def __init__(self, d_model, eps=1e-6):
super(LayerNormalization, self).__init__()
self.gamma = nn.Parameter(torch.ones(d_model)) # 可学习参数,缩放因子
self.beta = nn.Parameter(torch.zeros(d_model)) # 可学习参数,偏移因子
self.eps = eps
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True) # 计算均值
std = x.std(dim=-1, keepdim=True) # 计算标准差
return self.gamma * (x - mean) / (std + self.eps) + self.beta # 归一化
(2)残差连接 + 层归一化
class ResidualConnection(nn.Module):
def __init__(self, d_model, dropout=0.1):
super(ResidualConnection, self).__init__()
self.norm = nn.LayerNorm(d_model) # 层归一化
self.dropout = nn.Dropout(dropout) # Dropout 防止过拟合
def forward(self, x, sublayer):
return x + self.dropout(self.norm(sublayer(x))) # 残差 + 层归一化
sublayer(x)
:代表 Transformer 的 注意力机制 或 前馈神经网络(FFN)。self.norm(sublayer(x))
:先做层归一化。self.dropout(...)
:防止过拟合。x + ...
:残差连接,缓解梯度消失问题。
4. 残差连接 + 层归一化 在 Transformer 中的位置
Transformer 的 编码器和解码器 层中,每一层都包含:
- 自注意力 + 残差 + 层归一化
- 前馈神经网络(FFN)+ 残差 + 层归一化
代码示意:
class TransformerLayer(nn.Module):
def __init__(self, d_model, num_heads, hidden_dim, dropout=0.1):
super(TransformerLayer, self).__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads) # 多头注意力
self.feed_forward = FeedForwardNetwork(d_model, hidden_dim, dropout) # 前馈神经网络
self.residual1 = ResidualConnection(d_model, dropout) # 第一处残差 + LN
self.residual2 = ResidualConnection(d_model, dropout) # 第二处残差 + LN
def forward(self, x, mask):
x = self.residual1(x, lambda x: self.self_attention(x, mask)) # 残差 + LN (注意力层)
x = self.residual2(x, self.feed_forward) # 残差 + LN (FFN 层)
return x
- 第一层:自注意力 + 残差 + LN
- 第二层:FFN + 残差 + LN