基于PyTorch的Transformer各模块手动实现

本博客主要介绍Transformer模型的相关组件:Scaled Dot-Product Attention是一种注意力机制,用于计算输入序列各位置之间的相关性权重;Multi-Head Attention扩展了缩放点积注意力机制;Position-wise Feed-Forward Networks是前馈神经网络层,对每个位置的表示进行非线性变换。此外还介绍了编码层和解码层。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Scaled Dot-Product Attention(缩放点积注意力):
Scaled Dot-Product Attention 是 Transformer 模型中的一种注意力机制,用于计算输入序列中不同位置之间的相关性权重。

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        d_k: Query/key feature size; scores are divided by ``sqrt(d_k)``.
        n_heads: Number of heads. Stored as ``self.n_head`` but not used in
            the computation itself.
    """

    def __init__(self, d_k, n_heads):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k
        self.n_head = n_heads

    def forward(self, Q, K, V, attn_mask=None):
        """Compute the attention-weighted context.

        Args:
            Q: ``[batch_size, n_heads, len_q, d_k]``.
            K: ``[batch_size, n_heads, len_k, d_k]``.
            V: ``[batch_size, n_heads, len_v(=len_k), d_v]``.
            attn_mask: Optional boolean tensor broadcastable to
                ``[batch_size, n_heads, len_q, len_k]``. ``True`` marks
                query/key pairs that must not be attended to. ``None``
                (the default) disables masking.

        Returns:
            Context tensor ``[batch_size, n_heads, len_q, d_v]``.
        """
        # scores: [batch_size, n_heads, len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        if attn_mask is not None:
            # Out-of-place so the caller's tensors are not modified. A large
            # negative score drives the masked weights to ~0 after softmax.
            # If every key of a query is masked, its weights become uniform.
            scores = scores.masked_fill(attn_mask, -1e9)

        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]

Multi-Head Attention(多头注意力)是 Transformer 模型中的另一个重要组件,它扩展了标准的缩放点积注意力机制,以捕捉不同的注意力信息。

class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and post-LayerNorm.

    Projects queries, keys and values into ``n_heads`` subspaces and runs
    scaled dot-product attention per head. The heads are then concatenated
    and projected back to ``d_model``. Finally the query input is added as a
    residual and the sum is layer-normalised.

    Args:
        device: Stored as ``self.device`` for backward compatibility. All
            parameters (including the LayerNorm) follow ``module.to(...)``.
        n_heads: Number of attention heads.
        d_model: Model (input/output) feature size.
        d_k: Per-head query/key size.
        d_v: Per-head value size.
    """

    def __init__(self, device, n_heads, d_model, d_k, d_v):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.device = device
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        # Registered once (instead of being rebuilt on every forward call) so
        # its affine weight/bias are learnable and saved in the state_dict.
        # It also moves together with the rest of the model on `.to(device)`.
        self.layer_norm = nn.LayerNorm(d_model)
        # Parameter-free, so creating it once adds no state_dict entries.
        self.attention = ScaledDotProductAttention(d_k, n_heads)

    def forward(self, input_Q, input_K, input_V):
        """Attend from ``input_Q`` over ``input_K``/``input_V``.

        Args:
            input_Q: ``[batch_size, len_q, d_model]``; also used as the residual.
            input_K: ``[batch_size, len_k, d_model]``.
            input_V: ``[batch_size, len_v(=len_k), d_model]``.

        Returns:
            ``[batch_size, len_q, d_model]``.
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, H*W) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [B, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [B, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)  # [B, n_heads, len_k, d_v]

        context = self.attention(Q, K, V)  # [B, n_heads, len_q, d_v]
        # Merge the heads back together: [B, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v)
        output = self.fc(context)  # [B, len_q, d_model]
        return self.layer_norm(output + residual)

Position-wise Feed-Forward Networks(逐位置前馈网络):
Position-wise Feed-Forward Networks 是 Transformer 模型中的一个前馈神经网络层,用于对每个位置的表示进行非线性变换。

class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block with a residual connection and post-LayerNorm.

    Applies ``Linear(d_model, d_ff) -> ELU -> Linear(d_ff, d_model)`` over the
    last dimension, so each position is transformed independently. The input
    is then added back and the sum is layer-normalised.

    Args:
        device: Stored as ``self.device`` for backward compatibility. The
            LayerNorm is a registered submodule and follows ``module.to(...)``.
        d_model: Input/output feature size.
        d_ff: Hidden feature size.
    """

    def __init__(self, device, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.d_model = d_model
        self.device = device
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ELU(),
            nn.Linear(d_ff, d_model, bias=False),
        )
        # Registered once (instead of being rebuilt on every forward call) so
        # its affine weight/bias are learnable and saved in the state_dict.
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        """Transform ``inputs`` of shape ``[batch_size, seq_len, d_model]``.

        Returns:
            ``[batch_size, seq_len, d_model]``.
        """
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)

编码层(EncoderLayer)与编码器(Encoder):

class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention, then a position-wise FFN.

    Each of the two sub-layers applies its own residual connection and
    LayerNorm internally. No attention mask is used, so every position attends
    to every other position (including padding positions, if the caller's
    batches contain any).
    """

    def __init__(self, device, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(device, n_heads, d_model, d_k, d_v)
        self.pos_ffn = PoswiseFeedForwardNet(device, d_model, d_ff)

    def forward(self, enc_inputs):
        """Encode ``enc_inputs``.

        Args:
            enc_inputs: ``[batch_size, src_len, d_model]``.

        Returns:
            ``[batch_size, src_len, d_model]``.
        """
        # Self-attention: one sequence supplies the queries, keys and values.
        attended = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
        return self.pos_ffn(attended)


class Encoder(nn.Module):
    """A stack of ``n_layers`` encoder layers applied one after another."""

    def __init__(self, device, n_layers, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(device, n_heads, d_model, d_k, d_v, d_ff)
            for _ in range(n_layers)
        )

    def forward(self, enc_inputs):
        """Run ``enc_inputs`` (``[batch_size, src_len, d_model]``) through every layer.

        With zero layers the input is returned unchanged.
        """
        hidden = enc_inputs
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return hidden

解码层(DecoderLayer)与解码器(Decoder):

class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, encoder-decoder attention, then a FFN.

    Each sub-layer applies its own residual connection and LayerNorm
    internally.

    NOTE(review): the self-attention step uses no causal (look-ahead) mask,
    so every target position can attend to later target positions. That is
    harmless for non-autoregressive use, but it leaks future tokens if this
    layer is trained with teacher forcing for step-by-step generation.
    Confirm the intended use.
    """

    def __init__(self, device, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention(device, n_heads, d_model, d_k, d_v)
        self.enc_dec_attn = MultiHeadAttention(device, n_heads, d_model, d_k, d_v)
        self.pos_ffn = PoswiseFeedForwardNet(device, d_model, d_ff)

    def forward(self, dec_inputs, enc_outputs):
        """Decode ``dec_inputs`` while attending to ``enc_outputs``.

        Args:
            dec_inputs: ``[batch_size, tgt_len, d_model]``.
            enc_outputs: ``[batch_size, src_len, d_model]``.

        Returns:
            ``[batch_size, tgt_len, d_model]``.
        """
        # Self-attention over the target sequence.
        self_attended = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs)
        # Cross-attention: queries come from the decoder, keys/values from the encoder.
        cross_attended = self.enc_dec_attn(self_attended, enc_outputs, enc_outputs)
        return self.pos_ffn(cross_attended)




class Decoder(nn.Module):
    """A stack of ``n_layers`` decoder layers sharing the same encoder output."""

    def __init__(self, device, n_layers, n_heads, d_model, d_k, d_v, d_ff):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(device, n_heads, d_model, d_k, d_v, d_ff)
            for _ in range(n_layers)
        )
        # Projection head d_model -> 1. NOTE(review): it is registered, so its
        # parameters appear in parameters() and the state_dict, but forward()
        # never applies it. Callers presumably call `self.chu` themselves when
        # they need one value per position; confirm this is intended.
        self.chu = nn.Sequential(nn.Linear(d_model, 1))

    def forward(self, dec_inputs, enc_outputs):
        """Run ``dec_inputs`` (``[batch_size, tgt_len, d_model]``) through every layer.

        Every layer receives the same ``enc_outputs``. With zero layers the
        decoder input is returned unchanged.
        """
        hidden = dec_inputs
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_outputs)
        return hidden

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值