Step 1: Understanding the Paper
Paper: "Attention Is All You Need"
Key Points
- Transformer architecture:
  - Built entirely on attention mechanisms, with no recurrent (RNN) or convolutional (CNN) layers.
  - Consists of two parts: an encoder and a decoder.
- Self-attention:
  - Core formula: \[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
  - Queries (Q), keys (K), and values (V) are computed from the input via learned linear projections (a minimal numerical sketch follows this list).
- Multi-head attention:
  - Several attention heads are computed in parallel to increase the model's representational capacity.
- Positional encoding:
  - Since the model itself has no notion of order, positional encodings are added to preserve the sequence order information.
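To make the scaled dot-product formula above concrete, here is a minimal sketch that evaluates it directly on random tensors; the shapes (batch 2, sequence length 5, d_k = 64) are illustrative assumptions, not values from the paper:

```python
import math
import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 2 sequences, length 5, key dimension d_k = 64 (assumed values)
Q = torch.randn(2, 5, 64)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 64)

# softmax(Q K^T / sqrt(d_k)) V, exactly as in the formula above
scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # (2, 5, 5) attention scores
weights = F.softmax(scores, dim=-1)                        # each row sums to 1
output = weights @ V                                       # (2, 5, 64) attended values
print(output.shape)  # torch.Size([2, 5, 64])
```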
Step 2: Reproducing the Code
Setup
- Install the PyTorch library:

  ```bash
  pip install torch
  ```

- Import the required modules:

  ```python
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import math
  ```
Implementing Self-Attention
```python
class SelfAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5  # scaling factor 1/sqrt(d_k)
        # One linear projection jointly produces the queries (Q), keys (K), and values (V)
        self.to_qkv = nn.Linear(dim, dim_head * heads * 3, bias=False)
        self.to_out = nn.Linear(dim_head * heads, dim)

    def forward(self, x):
        b, n, _ = x.shape  # batch size, sequence length, model dimension
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # project, then split into Q, K, V
        q, k, v = map(lambda t: t.reshape(b, n, self.heads, self.dim_head).transpose(1, 2), qkv)
        scores = (q @ k.transpose(-2, -1)) * self.scale  # scaled attention scores
        attn = scores.softmax(dim=-1)  # attention weights
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)  # combine the heads
        return self.to_out(out)  # final output projection
```
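As a quick sanity check, the block below runs the layer on a random batch and confirms that the output keeps the input shape; the model dimension and batch size are illustrative assumptions:

```python
# Illustrative usage: model dimension 128, 8 heads of size 64 (assumed values)
attn = SelfAttention(dim=128, heads=8, dim_head=64)
x = torch.randn(2, 10, 128)   # (batch, sequence length, dim)
out = attn(x)
print(out.shape)              # torch.Size([2, 10, 128]), same shape as the input
```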
Implementing Multi-Head Attention
```python
class MultiHeadAttention(nn.Module):
    # Note: this performs the same multi-head computation as SelfAttention above;
    # it is kept as a separate class so the encoder layer below can use it by name.
    def __init__(self, dim, heads=8, dim_head=64):
        super(MultiHeadAttention, self).__init__()
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        # Joint linear projection for queries, keys, and values
        self.to_qkv = nn.Linear(dim, dim_head * heads * 3, bias=False)
        self.to_out = nn.Linear(dim_head * heads, dim)

    def forward(self, x):
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.reshape(b, n, self.heads, self.dim_head).transpose(1, 2), qkv)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
```
Implementing Positional Encoding
```python
class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Precompute the sinusoidal table: even indices use sin, odd indices use cos
        self.encoding = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)  # (1, max_len, dim) for broadcasting

    def forward(self, x):
        # Add the encodings for the first x.size(1) positions; .to() keeps them on the input's device
        return x + self.encoding[:, :x.size(1), :].to(x.device)
```
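A quick check, with illustrative sizes, that the module only changes values and not shapes:

```python
# Illustrative check: positional encoding preserves the input shape (assumed sizes)
pe = PositionalEncoding(dim=128)
x = torch.zeros(1, 5, 128)
print(pe(x).shape)       # torch.Size([1, 5, 128])
print(pe(x)[0, 0, :4])   # position 0 alternates sin(0)=0 and cos(0)=1: tensor([0., 1., 0., 1.])
```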
Implementing the Transformer Encoder
```python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim, heads, dim_ff, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(dim, heads)
        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, dim)
        )
        self.layernorm1 = nn.LayerNorm(dim)
        self.layernorm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Note: mask is accepted for interface symmetry but is not applied by this
        # simplified attention implementation.
        attn_out = self.self_attn(x)
        x = x + self.dropout(attn_out)   # residual connection
        x = self.layernorm1(x)           # post-norm, as in the original paper
        ff_out = self.feed_forward(x)
        x = x + self.dropout(ff_out)     # residual connection
        x = self.layernorm2(x)
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, dim, heads, dim_ff, num_layers, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(dim, heads, dim_ff, dropout) for _ in range(num_layers)]
        )
        self.layernorm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.layernorm(x)
```
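The following quick check stacks a small encoder and verifies the shapes; the hyperparameters are illustrative assumptions, not values from the paper:

```python
# Illustrative usage: 2 layers, model dim 128, 4 heads, feed-forward dim 256 (assumed values)
encoder = TransformerEncoder(dim=128, heads=4, dim_ff=256, num_layers=2)
x = torch.randn(2, 10, 128)
print(encoder(x).shape)   # torch.Size([2, 10, 128])
```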
Troubleshooting and Summary
Troubleshooting
Common issues encountered while reproducing the paper include:
- Misreading the paper: complex formulas and algorithms are easy to misunderstand. Rereading the paper and consulting other implementations helps.
- Dimension mismatches: shape mismatches are a frequent problem in the matrix computations. Use a debugger or print tensor shapes to track them down (see the sketch after this list).
- Parameter initialization: pay attention to how parameters are initialized so that the model converges properly.
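As mentioned above, a lightweight way to catch dimension mismatches is to print or assert tensor shapes at each step. This is a minimal sketch of that habit; the shapes are illustrative assumptions:

```python
# Illustrative shape check inside an attention-style computation (assumed sizes)
q = torch.randn(2, 8, 10, 64)   # (batch, heads, seq_len, dim_head)
k = torch.randn(2, 8, 10, 64)
scores = q @ k.transpose(-2, -1)
print("scores:", scores.shape)                                # expect (2, 8, 10, 10)
assert scores.shape == (2, 8, 10, 10), "unexpected attention score shape"
```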
Summary
- Reading and understanding the paper is key: a solid grasp of its core ideas and implementation details makes the model much easier to reproduce.
- Consult existing implementations: open-source code is a great aid for understanding and implementing complex models.
- Debug step by step: working through the code incrementally and checking intermediate results is the fastest way to find and fix problems.
Example: Using the Full Model
```python
class TransformerModel(nn.Module):
    def __init__(self, input_dim, num_classes, num_layers, dim_ff, heads, dropout=0.1):
        super(TransformerModel, self).__init__()
        # A linear layer stands in for the embedding, since the inputs here are continuous vectors
        self.embedding = nn.Linear(input_dim, input_dim)
        self.pos_encoding = PositionalEncoding(input_dim)
        self.encoder = TransformerEncoder(input_dim, heads, dim_ff, num_layers, dropout)
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x, mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        x = self.encoder(x, mask)
        x = self.fc(x.mean(dim=1))  # mean-pool over all time steps, then classify
        return x

# Example usage
batch_size = 64
sequence_length = 10
input_dim = 128
num_classes = 10
num_layers = 6
dim_ff = 512
heads = 8
dropout = 0.1

input_data = torch.randn(batch_size, sequence_length, input_dim)
model = TransformerModel(input_dim, num_classes, num_layers, dim_ff, heads, dropout)
output = model(input_data)
print(output.shape)  # Output shape should be (batch_size, num_classes)
```
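Going one step further, a single hypothetical training step could look like the sketch below; the loss, optimizer, learning rate, and random targets are illustrative assumptions rather than part of the original setup:

```python
# Hypothetical training step on random targets (illustrative only)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

targets = torch.randint(0, num_classes, (batch_size,))  # fake labels for demonstration
logits = model(input_data)
loss = criterion(logits, targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```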