TR2 - Transformer模型的复现

这篇文章详细介绍了Transformer模型的理论知识,包括其结构分解、编码器和解码器的组成,以及多头自注意力块和前馈网络的实现。作者通过实例展示了如何构建Transformer模型并分享了从CV模型转向Transformer的心得体会。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >



理论知识

Transformer是可以用于Seq2Seq任务的一种模型,和Seq2Seq不冲突。

模型结构

模型整体结构

结构分解

黑盒

以机器翻译任务为例
黑盒

两大模块

在Transformer内部,可以分成Encoder编码器和和Decoder解码器两部分,这也是Seq2Seq的标准结构。
两大模块

块级结构

继续拆解,可以发现模型的由许多的编码器块和解码器块组成并且每个解码器都可以获取到最后一层编码器的输出以及上一层解码器的输出(第一个当然是例外的)。
块组成

编码器的组成

继续拆解,一个编码器是由一个自注意力块和一个前馈网络组成。
编码器的组成

解码器的组成

而解码器,是在编码器的结构中间又插入了一个Encoder-Decoder Attention层。
解码器的组成

模型实现

通过前面自顶向下的拆解,已经基本掌握了模型的总体结构。接下来自底向上的复现Transformer模型。

多头自注意力块

class MultiHeadAttention(nn.Module):
    """多头注意力模块"""
    def __init__(self, dims, n_heads):
        """
        dims: 每个词向量维度
        n_heads: 注意力头数
        """
        super().__init__()

        self.dims = dims
        self.n_heads = n_heads

        # 维度必需整除注意力头数
        assert dims % n_heads == 0
        # 定义Q矩阵
        self.w_Q = nn.Linear(dims, dims)
        # 定义K矩阵
        self.w_K = nn.Linear(dims, dims)
        # 定义V矩阵
        self.w_V = nn.Linear(dims, dims)

        self.fc = nn.Linear(dims, dims)
        # 缩放
        self.scale = torch.sqrt(torch.FloatTensor([dims//n_heads])).to(device)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        # 例如: [32, 1024, 300] 计算10头注意力
        Q = self.w_Q(query)
        K = self.w_K(key)
        V = self.w_V(value)

        # [32, 1024, 300] -> [32, 1024, 10, 30] 把向量重新分组
        Q = Q.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值