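Scaled Dot-Product Attention(缩放点积注意力):
Scaled Dot-Product Attention 是 Transformer 模型中的基础注意力机制:先用查询(Q)与键(K)的点积除以 $\sqrt{d_k}$ 得到各位置间的相关性分数,经 softmax 归一化为注意力权重后,再对值(V)加权求和,即 $\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$。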
class ScaledDotProductAttention(nn.Module):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V`` for batched multi-head inputs.

    Args:
        d_k: Query/key feature size, used for the ``1 / sqrt(d_k)`` scaling.
        n_heads: Stored as ``self.n_head``; not used in the computation itself.
    """

    def __init__(self, d_k, n_heads):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k
        self.n_head = n_heads

    def forward(self, Q, K, V, attn_mask=None):
        """Apply scaled dot-product attention.

        Args:
            Q: [batch_size, n_heads, len_q, d_k]
            K: [batch_size, n_heads, len_k, d_k]
            V: [batch_size, n_heads, len_v(=len_k), d_v]
            attn_mask: Optional boolean tensor broadcastable to
                [batch_size, n_heads, len_q, len_k]. Entries that are ``True``
                are excluded from attention. ``None`` (the default) attends to
                every key, which is the previous behaviour.

        Returns:
            context: [batch_size, n_heads, len_q, d_v]
        """
        # scores: [batch_size, n_heads, len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        if attn_mask is not None:
            # A large finite negative (not -inf) keeps a fully masked row
            # finite: softmax then gives a uniform row instead of NaNs.
            scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
Multi-Head Attention(多头注意力)是 Transformer 模型中的另一个重要组件。它将 Q、K、V 分别线性投影到多个子空间(即多个“头”),在每个头上并行计算缩放点积注意力,再把各头的输出拼接并投影回 d_model 维,从而捕捉不同子空间中的注意力信息。本实现还在输出上加入了残差连接和层归一化。
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and layer norm.

    Queries, keys and values are linearly projected into ``n_heads`` heads, and
    scaled dot-product attention is applied per head. The heads are then
    concatenated and projected back to ``d_model``. The result is added to
    ``input_Q`` and layer-normalized over the last dimension.

    Args:
        device: Stored as ``self.device``. The normalization no longer needs
            it, but the attribute is kept for callers that read it.
        n_heads: Number of attention heads.
        d_model: Feature size of the inputs and of the output.
        d_k: Per-head size of the query/key projections.
        d_v: Per-head size of the value projection.
    """

    def __init__(self, device, n_heads, d_model, d_k, d_v):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.device = device
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        # Parameter-free, so registering it adds nothing to the state_dict;
        # built once here instead of on every forward call.
        self.attention = ScaledDotProductAttention(d_k, n_heads)

    def forward(self, input_Q, input_K, input_V):
        """Attend from ``input_Q`` over ``input_K`` / ``input_V``.

        Args:
            input_Q: [batch_size, len_q, d_model]
            input_K: [batch_size, len_k, d_model]
            input_V: [batch_size, len_v(=len_k), d_model]

        Returns:
            [batch_size, len_q, d_model]: ``LayerNorm(fc(attention) + input_Q)``.
            The normalization has no learnable scale/shift.
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, H*W) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [B, H, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [B, H, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)  # [B, H, len_k, d_v]
        context = self.attention(Q, K, V)  # [batch_size, n_heads, len_q, d_v]
        # Merge heads back: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v)
        output = self.fc(context)  # [batch_size, len_q, d_model]
        # Same result as the former per-call nn.LayerNorm(d_model), whose fresh
        # weight=1 / bias=0 were never trained. This version does not allocate
        # a module and move it to self.device on every call.
        return nn.functional.layer_norm(output + residual, (self.d_model,))
Position-wise Feed-Forward Networks(逐位置前馈网络):
Position-wise Feed-Forward Networks 是 Transformer 模型中的前馈神经网络层,它对序列中每个位置的表示独立地施加同一个两层全连接变换(本实现中间使用 ELU 激活),再接残差连接和层归一化。
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network with a residual connection and layer norm.

    Applies ``Linear(d_model, d_ff) -> ELU -> Linear(d_ff, d_model)`` over the
    last dimension (i.e. the same transform at every position). The input is
    then added back and the sum is layer-normalized over the last dimension.

    Args:
        device: Stored as ``self.device``. The normalization no longer needs
            it, but the attribute is kept for callers that read it.
        d_model: Input/output feature size.
        d_ff: Hidden size of the inner layer.
    """

    def __init__(self, device, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.d_model = d_model
        self.device = device
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ELU(),
            nn.Linear(d_ff, d_model, bias=False),
        )

    def forward(self, inputs):
        """Transform ``inputs`` of shape [batch_size, seq_len, d_model].

        Returns:
            [batch_size, seq_len, d_model]: ``LayerNorm(fc(inputs) + inputs)``.
            The normalization has no learnable scale/shift.
        """
        residual = inputs
        output = self.fc(inputs)
        # Same result as the former per-call nn.LayerNorm(d_model), whose fresh
        # weight=1 / bias=0 were never trained. This version does not allocate
        # a module and move it to self.device on every call.
        return nn.functional.layer_norm(output + residual, (self.d_model,))
编码层(EncoderLayer)与编码器(Encoder):
class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention, then a position-wise FFN.

    No attention mask is passed to the self-attention.
    """

    def __init__(self, device, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        # Creation order (attention first, then FFN) is kept so that weight
        # initialization consumes the RNG in the same order.
        self.enc_self_attn = MultiHeadAttention(device, n_heads, d_model, d_k, d_v)
        self.pos_ffn = PoswiseFeedForwardNet(device, d_model, d_ff)

    def forward(self, enc_inputs):
        """Encode ``enc_inputs`` ([batch_size, src_len, d_model]).

        Returns:
            [batch_size, src_len, d_model]
        """
        # Self-attention: the same sequence serves as queries, keys and values.
        query = key = value = enc_inputs
        attended = self.enc_self_attn(query, key, value)
        return self.pos_ffn(attended)
class Encoder(nn.Module):
    """A stack of ``n_layers`` EncoderLayer modules applied one after another."""

    def __init__(self, device, n_layers, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(EncoderLayer(device, n_heads, d_model, d_k, d_v, d_ff))

    def forward(self, enc_inputs):
        """Run ``enc_inputs`` ([batch_size, src_len, d_model]) through every layer in order."""
        hidden = enc_inputs
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)  # shape stays [batch_size, src_len, d_model]
        return hidden
解码层(DecoderLayer)与解码器(Decoder):
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder-decoder attention, then a position-wise FFN.

    NOTE: no masks are passed to either attention. In particular, the
    self-attention is not causal: each target position can attend to every
    target position.
    """

    def __init__(self, device, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        # Keep this creation order so parameter initialization draws from the
        # RNG in the same sequence.
        self.dec_self_attn = MultiHeadAttention(device, n_heads, d_model, d_k, d_v)
        self.enc_dec_attn = MultiHeadAttention(device, n_heads, d_model, d_k, d_v)
        self.pos_ffn = PoswiseFeedForwardNet(device, d_model, d_ff)

    def forward(self, dec_inputs, enc_outputs):
        """Decode one step of the stack.

        Args:
            dec_inputs: [batch_size, tgt_len, d_model]
            enc_outputs: [batch_size, src_len, d_model]

        Returns:
            [batch_size, tgt_len, d_model]
        """
        # Self-attention over the decoder sequence.
        hidden = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs)
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        hidden = self.enc_dec_attn(hidden, enc_outputs, enc_outputs)
        return self.pos_ffn(hidden)
class Decoder(nn.Module):
    """A stack of ``n_layers`` DecoderLayer modules, each attending to ``enc_outputs``.

    ``chu`` (a ``Linear(d_model, 1)`` head) is constructed, so its parameters
    are part of the state_dict. It is NOT applied in ``forward``.
    """

    def __init__(self, device, n_layers, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(device, n_heads, d_model, d_k, d_v, d_ff)
            for _ in range(n_layers)
        )
        self.chu = nn.Sequential(nn.Linear(d_model, 1))

    def forward(self, dec_inputs, enc_outputs):
        """Run ``dec_inputs`` ([batch_size, tgt_len, d_model]) through every decoder layer.

        Returns:
            [batch_size, tgt_len, d_model]; ``chu`` is not applied.
        """
        hidden = dec_inputs
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_outputs)
        return hidden