LLM Learning (Datawhale_Happy-LLM) Notes 5: Building a Transformer
Summary of the Core Components for Building a Transformer
1. Basic functional modules
- Self-attention: computes dependencies within a sequence from the Q, K, V matrices (a minimal numeric sketch of this formula follows this list):
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
- Multi-head attention: splits the features into several subspaces and computes attention over them in parallel, increasing the model's expressive power
- Positional encoding: injects position information into the sequence via sine and cosine functions, compensating for the Transformer's lack of any built-in sense of order
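As a quick illustration of the attention formula above, here is a minimal sketch of scaled dot-product attention on random tensors. The batch size, sequence length, and head dimension are arbitrary choices for the demo, not values from the notes.

```python
import math

import torch
import torch.nn.functional as F

# Toy shapes for illustration only: batch=2, seq_len=5, d_k=8
Q = torch.randn(2, 5, 8)
K = torch.randn(2, 5, 8)
V = torch.randn(2, 5, 8)

scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # (2, 5, 5)
weights = F.softmax(scores, dim=-1)                        # each row sums to 1
out = weights @ V                                          # (2, 5, 8)
print(out.shape)  # torch.Size([2, 5, 8])
```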
2. Core network layers
- Layer normalization (LayerNorm): normalizes each sample across all of its feature dimensions:
$$\text{LayerNorm}(x) = \alpha \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
- Feed-forward network (MLP): a Linear + ReLU + Linear stack that adds non-linear expressive power
3. Encoder and decoder architecture
- Encoder layer: composed of LayerNorm + MultiHeadAttention + residual connection, followed by LayerNorm + MLP + residual connection
- Decoder layer: adds a masked multi-head attention sub-layer on top of the encoder layer so that the model cannot see future tokens when predicting
- Stacking: multiple encoder / decoder layers are stacked to form a deep network
4. Input and output handling
- Embedding layer: maps discrete tokens to continuous vectors, which are summed with the positional encoding before entering the network
- Output layer: a linear layer maps the features back to the vocabulary space to produce a probability distribution
5. Key techniques
- Residual connections: mitigate vanishing gradients when training deep networks
- Dropout: randomly drops activations to prevent overfitting
- Masking: hides future positions in the decoder so that predictions remain causal (see the short example right after this list)
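To make the masking point concrete, here is a minimal sketch of the upper-triangular mask that the `MultiHeadAttention` implementation below builds; the sequence length of 4 is an arbitrary choice for display.

```python
import torch

seq_len = 4  # arbitrary length for illustration
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
# Added to the attention scores before softmax, the -inf entries drive the
# weights on future positions to zero.
```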
Full Transformer Architecture Flow

```
Input sequence ──→ Embedding layer + positional encoding ──→ Encoder stack ──→ Encoder output
                                                                                     │
                                                                                     ↓
                                                                      Decoder stack (with causal mask)
                                                                                     │
                                                                                     ↓
                                                                      Linear layer + Softmax ──→ Output sequence
```
Multi-Head Self-Attention Module
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelArgs:
    '''Hyperparameters shared by all of the modules below.'''
    def __init__(self):
        self.n_embed = 256                            # model / embedding dimension
        self.n_heads = 4                              # number of attention heads
        self.head_dim = self.n_embed // self.n_heads  # dimension of each head
        self.dropout = 0.1
        self.max_seq_len = 512                        # longest sequence the causal mask covers
        self.n_layers = 6                             # number of encoder / decoder layers
        self.vocab_size = None                        # must be set before building the full model
        self.block_size = None                        # must be set before building the full model
```
```python
class MultiHeadAttention(nn.Module):
    '''Multi-head attention; with is_causal=True a triangular mask hides future positions.'''
    def __init__(self, args: ModelArgs, is_causal=False):
        super().__init__()
        assert args.n_embed % args.n_heads == 0
        # Placeholder for model parallelism; with a single device every head is local.
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.head_dim = args.n_embed // args.n_heads
        self.is_causal = is_causal
        # Q, K, V and output projections.
        self.wq = nn.Linear(args.n_embed, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.n_embed, args.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.n_embed, args.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.n_embed, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        if is_causal:
            # Upper-triangular matrix of -inf; added to the scores it masks out future tokens.
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # This simplified implementation assumes q, k and v share the same sequence length.
        bsz, seqlen, _ = q.shape
        # Project, then split the embedding dimension into (n_heads, head_dim).
        xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        # Move the head dimension forward: (bsz, n_heads, seqlen, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        # Scaled dot-product attention scores: (bsz, n_heads, seqlen, seqlen).
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        if self.is_causal:
            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)
        # Merge the heads back together and apply the output projection.
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
```
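A quick shape check of the module above, assuming the imports and classes just defined have been run; the batch size and sequence length are arbitrary, and `vocab_size` / `block_size` can stay unset because attention does not use them.

```python
args = ModelArgs()
attn = MultiHeadAttention(args, is_causal=True)

x = torch.randn(2, 10, args.n_embed)  # (batch, seq_len, n_embed), arbitrary sizes
y = attn(x, x, x)                     # self-attention: Q, K, V all come from x
print(y.shape)                        # torch.Size([2, 10, 256])
```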
Feed-Forward Network
```python
class MLP(nn.Module):
    '''Feed-forward network, implemented as a two-layer
    multi-layer perceptron (MLP).
    '''
    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Linear -> ReLU -> Linear, followed by dropout.
        return self.dropout(self.w2(F.relu(self.w1(x))))
```
LayerNorm Layer
The math behind the normalization:
- First compute the sample mean (where $Z_j^i$ is the value of sample $i$ in dimension $j$, and $m$ is the mini-batch size):
$$\mu_j = \frac{1}{m}\sum_{i=1}^{m} Z_j^{i}$$
- Then compute the sample variance:
$$\sigma^2 = \frac{1}{m}\sum_{i=1}^{m}\left(Z_j^i - \mu_j\right)^2$$
- Finally, subtract the mean from each value and divide by the standard deviation, which maps the mini-batch distribution to a standard normal distribution:
$$\widetilde{Z}_j = \frac{Z_j - \mu_j}{\sqrt{\sigma^2 + \epsilon}}$$
(the tiny constant $\epsilon$ keeps the denominator from being zero)
```python
class LayerNorm(nn.Module):
    '''A simple Layer Norm implementation based on the formulas above.'''
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learnable scale (alpha)
        self.b_2 = nn.Parameter(torch.zeros(features))  # learnable shift (beta)
        self.eps = eps

    def forward(self, x):
        # Normalize each sample over its last (feature) dimension;
        # note that eps is added to the std rather than inside the square root.
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
```
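A small sanity check on the class above, assuming it has been defined; the input shape is arbitrary. After normalization, each position should have feature mean close to 0 and standard deviation close to 1.

```python
x = torch.randn(2, 10, 256)     # arbitrary (batch, seq_len, features)
ln = LayerNorm(256)
y = ln(x)
print(y.mean(-1).abs().max())   # ~0: features of each position are centered
print(y.std(-1).mean())         # ~1: and rescaled to unit standard deviation
```

Unlike `torch.nn.LayerNorm`, this version uses the unbiased standard deviation and adds `eps` to the standard deviation rather than to the variance, so the two can differ by a small amount.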
Encoder
```python
class EncoderLayer(nn.Module):
    '''A single Encoder layer: (LayerNorm -> self-attention -> residual)
    followed by (LayerNorm -> MLP -> residual).'''
    def __init__(self, args):
        super().__init__()
        self.attention_norm = LayerNorm(args.n_embed)
        self.attention = MultiHeadAttention(args, is_causal=False)  # no mask in the encoder
        self.fnn_norm = LayerNorm(args.n_embed)
        self.feed_forward = MLP(dim=args.n_embed, hidden_dim=args.n_embed * 4, dropout=args.dropout)

    def forward(self, x):
        # Self-attention sub-layer (the residual is taken after the norm, as in the tutorial code).
        x = self.attention_norm(x)
        h = x + self.attention.forward(x, x, x)
        # Feed-forward sub-layer with its own norm and residual connection.
        out = h + self.feed_forward.forward(self.fnn_norm(h))
        return out
```
```python
class Encoder(nn.Module):
    '''Encoder block: a stack of n_layers EncoderLayer modules plus a final LayerNorm.'''
    def __init__(self, args):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(args) for _ in range(args.n_layers)])
        self.norm = LayerNorm(args.n_embed)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
```
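A quick shape check of the encoder stack, assuming the classes above are defined; the input here stands in for already-embedded tokens with arbitrary batch size and length.

```python
args = ModelArgs()
encoder = Encoder(args)

x = torch.randn(2, 10, args.n_embed)  # already-embedded input, arbitrary sizes
enc_out = encoder(x)
print(enc_out.shape)                  # torch.Size([2, 10, 256])
```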
Decoder
```python
class DecoderLayer(nn.Module):
    '''A single Decoder layer: masked self-attention, cross-attention over the
    encoder output, then an MLP, each with its own LayerNorm and residual.'''
    def __init__(self, args):
        super().__init__()
        self.attention_norm_1 = LayerNorm(args.n_embed)
        self.mask_attention = MultiHeadAttention(args, is_causal=True)   # masked self-attention
        self.attention_norm_2 = LayerNorm(args.n_embed)
        self.attention = MultiHeadAttention(args, is_causal=False)       # cross-attention
        self.ffn_norm = LayerNorm(args.n_embed)
        self.feed_forward = MLP(dim=args.n_embed, hidden_dim=args.n_embed * 4, dropout=args.dropout)

    def forward(self, x, enc_out):
        # Masked self-attention over the decoder input.
        x = self.attention_norm_1(x)
        x = x + self.mask_attention.forward(x, x, x)
        # Cross-attention: queries from the decoder, keys/values from the encoder output.
        x = self.attention_norm_2(x)
        h = x + self.attention.forward(x, enc_out, enc_out)
        # Feed-forward sub-layer.
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out
```
```python
class Decoder(nn.Module):
    '''Decoder block: a stack of n_layers DecoderLayer modules plus a final LayerNorm.'''
    def __init__(self, args):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layers)])
        self.norm = LayerNorm(args.n_embed)

    def forward(self, x, enc_out):
        for layer in self.layers:
            x = layer(x, enc_out)
        return self.norm(x)
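```

A matching shape check for the decoder, assuming the classes above are defined. The second tensor is a random stand-in for the encoder output; both sequence lengths are kept equal because the simplified attention above assumes matching lengths.

```python
args = ModelArgs()
decoder = Decoder(args)

tgt = torch.randn(2, 10, args.n_embed)      # decoder input (already embedded)
enc_out = torch.randn(2, 10, args.n_embed)  # stand-in for the encoder output
out = decoder(tgt, enc_out)
print(out.shape)                            # torch.Size([2, 10, 256])
```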
Building a Transformer
```python
class PositionalEncoding(nn.Module):
    '''Sinusoidal positional encoding, added to the token embeddings.'''
    def __init__(self, args):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=args.dropout)
        # Precompute the (block_size, n_embed) table of sin/cos values.
        pe = torch.zeros(args.block_size, args.n_embed)
        position = torch.arange(0, args.block_size).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, args.n_embed, 2) * -(math.log(10000.0) / args.n_embed)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)                          # (1, block_size, n_embed)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Add the positional encoding for the first x.size(1) positions, then dropout.
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
```
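A short usage sketch of the positional encoding, assuming the classes above are defined; `block_size` is set to an arbitrary value here because the module needs it to build its table.

```python
args = ModelArgs()
args.block_size = 512          # needed to size the positional-encoding table
pe = PositionalEncoding(args)
pe.eval()                      # disable dropout so the raw sinusoids are visible

x = torch.zeros(1, 10, args.n_embed)  # arbitrary embedded input
y = pe(x)
print(y.shape)                 # torch.Size([1, 10, 256])
print(y[0, :3, :4])            # first few sinusoidal values
```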
A Complete Transformer
```python
class Transformer(nn.Module):
    """The full model."""
    def __init__(self, args):
        super().__init__()
        # The vocabulary size and block size must be set on args first.
        assert args.vocab_size is not None
        assert args.block_size is not None
        self.args = args
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(args.vocab_size, args.n_embed),  # token embedding
            wpe = PositionalEncoding(args),                     # positional encoding
            drop = nn.Dropout(args.dropout),
            encoder = Encoder(args),
            decoder = Decoder(args),
        ))
        self.lm_head = nn.Linear(args.n_embed, args.vocab_size, bias=False)
        self.apply(self._init_weights)
        print(f'number of parameters: {self.get_num_params()/1e6:.2f}M')

    def get_num_params(self, non_embedding=False):
        """Count the number of parameters in the model."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            # The sinusoidal positional encoding is a parameter-free buffer,
            # so only the token embedding is subtracted here.
            n_params -= self.transformer.wte.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.args.block_size, f'Cannot forward a sequence of length {t}; block size is only {self.args.block_size}'
        print(f'idx: {idx.size()}')
        # Token embedding: (batch, seq_len) -> (batch, seq_len, n_embed).
        tok_emb = self.transformer.wte(idx)
        print(f"tok_emb: {tok_emb.size()}")
        # Add positional encodings, then dropout.
        pos_emb = self.transformer.wpe(tok_emb)
        x = self.transformer.drop(pos_emb)
        print(f'x after wpe: {x.size()}')
        # Run the encoder, then the decoder conditioned on the encoder output.
        enc_out = self.transformer.encoder(x)
        print(f'enc_out: {enc_out.size()}')
        x = self.transformer.decoder(x, enc_out)
        print(f'x after decoder: {x.size()}')
        if targets is not None:
            # Training: compute logits over the full sequence and the cross-entropy loss.
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1)
        else:
            # Inference: only the logits at the last position are needed.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss
```
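Finally, a minimal end-to-end smoke test of the full model, assuming everything above has been run. The vocabulary size, block size, batch size, and sequence length are arbitrary toy values chosen only for the demo.

```python
args = ModelArgs()
args.vocab_size = 1000   # toy vocabulary size
args.block_size = 512

model = Transformer(args)
idx = torch.randint(0, args.vocab_size, (2, 10))      # random token ids
targets = torch.randint(0, args.vocab_size, (2, 10))  # random labels

logits, loss = model(idx, targets)
print(logits.shape, loss.item())  # torch.Size([2, 10, 1000]) and a scalar loss
```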