1 说明
首先,先给出Transformer的MultiHeadAttention部分的pytorch版本的代码,然后再对于此部分的细节进行解析
2 源码
class MultiHeadedAttention(nn.Module):
def __init__(self, h, d_model, dropout=0.1):
"Take in model size and number of heads."
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0#剖析点1
# We assume d_v always equals d_k
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
# 纬度
# shape:query=key=value--->:[batch_size,max_legnth,embedding_dim=512]
if mask is not None:
# Same mask applied to all h heads.
mask = mask.unsqueeze(1)#剖析点2
nbatches = query.size(0)
#第一步:将q,k,v分别与Wq,Wk,Wv矩阵进行相乘
#shape:Wq=Wk=Wv----->[512,512]
#第二步:将获得的Q、K、V在第三个纬度上进行切分
#shape:[batch_size,max_length,8,64]
#第三部:填充到第一个纬度
#shape:[batch_size,8,max_length,64]
query, key, value = \
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linears, (query, key, value))]#剖析点3
#进入到attention之后纬度不变,shape:[batch_size,8,max_length,64]
x, self.attn = attention(query, key, value, mask=mask,
dropout=self.dropout)
# 将纬度进行还原
# 交换纬度:[batch_size,max_length,8,64]
# 纬度还原:[batch_size,max_length,512]
x = x.transpose(1, 2).contiguous() \
.view(nbatches, -1, self.h * self.d_k)#剖析点4
# 最后与WO大矩阵相乘 shape:[512,512]
return self.linears[-1](x)
3 源码剖析
3.1 剖析点1:assert d_model % h == 0
assert断言机制
Python assert(断言)用于判断一个表达式,在表达式条件为 false 的时候触发异常。
语法:
assert expression
等价于(这种方式比较好理解)
if not expression:
raise AssertionError(arguments)
assert 后面也可以紧跟参数:
assert expression [, arguments]
等价于
if not expression:
raise AssertionError(arguments)
eg:
assert True#没有任何输出 程序继续向下执行
assert False
#输出
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input</

本文详细解析Transformer模型中的多头注意力机制实现原理与PyTorch代码实现过程,包括关键步骤如权重矩阵相乘、维度变换及张量操作等。
最低0.47元/天 解锁文章
2万+

被折叠的 条评论
为什么被折叠?



