- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目录
理论知识
Transformer是可以用于Seq2Seq任务的一种模型,和Seq2Seq不冲突。
模型结构
结构分解
黑盒
以机器翻译任务为例
两大模块
在Transformer内部,可以分成Encoder编码器和和Decoder解码器两部分,这也是Seq2Seq的标准结构。
块级结构
继续拆解,可以发现模型的由许多的编码器块和解码器块组成并且每个解码器都可以获取到最后一层编码器的输出以及上一层解码器的输出(第一个当然是例外的)。
编码器的组成
继续拆解,一个编码器是由一个自注意力块和一个前馈网络组成。
解码器的组成
而解码器,是在编码器的结构中间又插入了一个Encoder-Decoder Attention层。
模型实现
通过前面自顶向下的拆解,已经基本掌握了模型的总体结构。接下来自底向上的复现Transformer模型。
多头自注意力块
class MultiHeadAttention(nn.Module):
"""多头注意力模块"""
def __init__(self, dims, n_heads):
"""
dims: 每个词向量维度
n_heads: 注意力头数
"""
super().__init__()
self.dims = dims
self.n_heads = n_heads
# 维度必需整除注意力头数
assert dims % n_heads == 0
# 定义Q矩阵
self.w_Q = nn.Linear(dims, dims)
# 定义K矩阵
self.w_K = nn.Linear(dims, dims)
# 定义V矩阵
self.w_V = nn.Linear(dims, dims)
self.fc = nn.Linear(dims, dims)
# 缩放
self.scale = torch.sqrt(torch.FloatTensor([dims//n_heads])).to(device)
def forward(self, query, key, value, mask=None):
batch_size = query.shape[0]
# 例如: [32, 1024, 300] 计算10头注意力
Q = self.w_Q(query)
K = self.w_K(key)
V = self.w_V(value)
# [32, 1024, 300] -> [32, 1024, 10, 30] 把向量重新分组
Q = Q.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
K = K.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
V = V.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).