Hand-Writing the Code in Python: A Complete Transformer Implementation
1. Positional encoding
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        # Precompute the sinusoidal table of shape (max_len, d_model); assumes d_model is even
        encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)  # even dimensions use sin
        encoding[:, 1::2] = torch.cos(position * div_term)  # odd dimensions use cos
        # Register as a buffer so it moves with the module across devices but is not trained
        self.register_buffer('encoding', encoding.unsqueeze(0))

    def forward(self, x):
        # Return the encodings for the first x.size(1) positions, shape (1, seq_len, d_model);
        # the caller adds them to the token embeddings
        return self.encoding[:, :x.size(1)]
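Because forward only returns the encoding slice, the caller is responsible for adding it to the token embeddings. A minimal usage sketch, where the embedding layer and all sizes are illustrative assumptions rather than part of the original code:

# Usage sketch; vocabulary size, batch size, and sequence length are assumptions for illustration
embedding = nn.Embedding(10000, 512)        # hypothetical vocabulary of 10000 tokens, d_model=512
pos_enc = PositionalEncoding(d_model=512)
tokens = torch.randint(0, 10000, (2, 20))   # (batch=2, seq_len=20)
x = embedding(tokens)                       # (2, 20, 512)
x = x + pos_enc(x)                          # broadcast-add positional information
print(x.shape)                              # torch.Size([2, 20, 512])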
2. Encoder
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(EncoderLayer, self).__init__()
        # Multi-head self-attention; nn.MultiheadAttention defaults to inputs of shape (seq_len, batch, d_model)
        self.multihead_attention = nn.MultiheadAttention(d_model, nhead)
        # Position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None):
        # Self-attention sub-layer with residual connection and post-layer-norm
        attn_output, _ = self.multihead_attention(src, src, src, attn_mask=src_mask)
        src = src + self.dropout(attn_output)
        src = self.layer_norm1(src)
        # Feed-forward sub-layer with residual connection and post-layer-norm
        ff_output = self.feed_forward(src)
        src = src + self.dropout(ff_output)
        src = self.layer_norm2(src)
        return src
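A full encoder stacks several of these layers and feeds them tensors in the (seq_len, batch, d_model) layout that nn.MultiheadAttention expects with its default batch_first=False. A minimal stacking sketch; the layer count and tensor sizes are assumptions for illustration:

# Stacking sketch; number of layers and sizes are illustrative assumptions
layers = nn.ModuleList([EncoderLayer(d_model=512, nhead=8, dim_feedforward=2048) for _ in range(6)])
src = torch.randn(20, 2, 512)   # (seq_len=20, batch=2, d_model=512)
for layer in layers:
    src = layer(src)
print(src.shape)                # torch.Size([20, 2, 512])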
3. Decoder
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(DecoderLayer, self).__init__()
        # Masked self-attention over the target sequence
        self.multihead_attention1 = nn.MultiheadAttention(d_model, nhead)
        # Cross-attention over the encoder output (memory)
        self.multihead_attention2 = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # Masked self-attention sub-layer with residual connection and post-layer-norm
        attn_output1, _ = self.multihead_attention1(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = tgt + self.dropout(attn_output1)
        tgt = self.layer_norm1(tgt)
        # Cross-attention sub-layer: queries from the decoder, keys/values from the encoder output
        attn_output2, _ = self.multihead_attention2(tgt, memory, memory, attn_mask=memory_mask)
        tgt = tgt + self.dropout(attn_output2)
        tgt = self.layer_norm2(tgt)
        # Feed-forward sub-layer with residual connection and post-layer-norm
        ff_output = self.feed_forward(tgt)
        tgt = tgt + self.dropout(ff_output)
        tgt = self.layer_norm3(tgt)
        return tgt
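For autoregressive decoding, the self-attention needs a causal mask so each target position can only attend to itself and earlier positions. A minimal sketch, assuming a hypothetical generate_causal_mask helper and illustrative tensor sizes not taken from the original code:

# Decoding sketch; the mask helper and sizes are assumptions for illustration
def generate_causal_mask(size):
    # Upper-triangular -inf mask: position i cannot attend to positions j > i
    return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)

decoder_layer = DecoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
tgt = torch.randn(15, 2, 512)      # target: (tgt_len=15, batch=2, d_model=512)
memory = torch.randn(20, 2, 512)   # encoder output: (src_len=20, batch=2, d_model=512)
out = decoder_layer(tgt, memory, tgt_mask=generate_causal_mask(15))
print(out.shape)                   # torch.Size([15, 2, 512])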