Transformer代码阅读——Transformer部分

原作者视频链接:
【研1基本功 (真的很简单)召唤Transformer】手写“变压器”or“变形金刚”_哔哩哔哩_bilibiliicon-default.png?t=N7T8https://www.bilibili.com/video/BV1oK421Y7Vh/?spm_id_from=333.788&vd_source=01e171598915d67de063d93cfd6421e6

1、模型整体框架:

2、代码部分:

class Transformer(nn.Module):
    # 首先是两个pad,就是对输入的pad和decoder pad的一个标识符的一个记录
    # 然后告诉大家encoder vocabulary size和decoder vocabulary size 分辨是多大
    # max_len最大长度,d_model的大小,头的大小
    # 前向传播隐藏层的大小,总层数,dropout,最后还有device
    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, max_len, d_model, n_heads, ffn_hidden, n_layers, drop_prob, device):
        super(Transformer,self).__init__()

        self.encoder = Encoder(enc_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, drop_prob, device)
        self.decoder = Decoder(dec_voc_size, max_len, d_model, ffn_hidden, n_heads, n_layers, drop_prob, device)

        # 生成两个padding的index标识符
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        len_q, len_k = q.size(1), k.size(1)

        # (Batch, Time, len_q, len_k)  第三维和第四维是QK相乘之后得到的2*2矩阵,所以后面两个就是矩阵的一个维度
        q = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)   # 本来q的维度是batch和len_q两维,现在为了统一格式,因此需要增加两个维度到四维
        q = q.repeat(1, 1, 1, len_k)   # 需要把len_k补全(因为每一个q都有一个对应的k)

        k = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)
        k = k.repeat(1, 1, len_q, 1)

        # 生成Q,K之后,需要进行暗位取余的操作(全一出一,只要有零则出零)
        mask = q & k
        return mask

    def make_casual_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor).to(self.device)
        return mask

    def forward (self, src, trg):
        # 构建mask
        # 首先构建encoder当中自己的padding mask
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # 然后是decoder自己的因果mask
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) * self.make_casual_mask(trg, trg)
        # 交叉注意力机制的mask, q来自query(target),k来自encoder(source),
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        enc = self.encoder(src, src_mask)
        output = self.decoder(trg, enc, trg_mask, src_trg_mask)
        return output

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值