【源码解读】Transformer的MultiHeadAttention部分代码解读

本文详细解析Transformer模型中的多头注意力机制实现原理与PyTorch代码实现过程,包括关键步骤如权重矩阵相乘、维度变换及张量操作等。

1 说明

首先,先给出Transformer的MultiHeadAttention部分的pytorch版本的代码,然后再对于此部分的细节进行解析

2 源码

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0#剖析点1
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        # 纬度
        # shape:query=key=value--->:[batch_size,max_legnth,embedding_dim=512]
        
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)#剖析点2
        nbatches = query.size(0)
        
        #第一步:将q,k,v分别与Wq,Wk,Wv矩阵进行相乘
        #shape:Wq=Wk=Wv----->[512,512]
        #第二步:将获得的Q、K、V在第三个纬度上进行切分
        #shape:[batch_size,max_length,8,64]
        #第三部:填充到第一个纬度
        #shape:[batch_size,8,max_length,64]
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]#剖析点3
        
        #进入到attention之后纬度不变,shape:[batch_size,8,max_length,64]
        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)
        
        # 将纬度进行还原
        # 交换纬度:[batch_size,max_length,8,64]
        # 纬度还原:[batch_size,max_length,512]
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)#剖析点4
        
        # 最后与WO大矩阵相乘 shape:[512,512]
        return self.linears[-1](x)

3 源码剖析

3.1 剖析点1:assert d_model % h == 0

assert断言机制
Python assert(断言)用于判断一个表达式,在表达式条件为 false 的时候触发异常。
语法:

assert expression

等价于(这种方式比较好理解)

if not expression:
    raise AssertionError(arguments)

assert 后面也可以紧跟参数:

assert expression [, arguments]
等价于
if not expression:
    raise AssertionError(arguments)

eg:

assert True#没有任何输出 程序继续向下执行
assert False
#输出
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input</
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值