学习总结
transformer解码器的mindspore实现
学习心得
解码器 (Decoder)
解码器将编码器输出的上下文序列转换为目标序列的预测结果(y帽),该输出将在模型训练中与真实目标输出Y进行比较,计算损失。
不同于编码器,每个Decoder层中包含两层多头注意力机制,并在最后多出一个线性层,输出对目标序列的预测结果。
- 第一层:计算目标序列的注意力分数的掩码多头自注意力;
- 第二层:用于计算上下文序列与目标序列对应关系,其中Decoder掩码多头注意力的输出作为query,Encoder的输出(上下文序列)作为key和value;
带掩码的多头注意力
在处理目标序列的输入时,t时刻的模型只能“观察”直到t-1时刻的所有词元,后续的词语不应该一并输入Decoder中。
为了保证在t时刻,只有t-1个词元作为输入参与多头注意力分数的计算,我们需要在第一个多头注意力中额外增加一个时间掩码,使目标序列中的词随时间发展逐个被暴露出来。
该注意力掩码可通过三角矩阵实现,对角线以上的词元表示为不参与注意力计算的词元,标记为1。
经验分享
def get_attn_subsequent_mask(seq_q, seq_k):
"""生成时间掩码,使decoder在第t时刻只能看到序列的前t-1个元素
Args:
seq_q (Tensor): query序列,shape = [batch size, len_q]
seq_k (Tensor): key序列,shape = [batch size, len_k]
"""
batch_size, len_q = seq_q.shape
batch_size, len_k = seq_k.shape
ones = ops.ones((batch_size, len_q, len_k), mindspore.float32)
subsequent_mask = mnp.triu(ones, k=1)
return subsequent_mask
q = k = ops.ones((1, 4), mstype.float32)
mask = get_attn_subsequent_mask(q, k)
print(mask)
Decoder Layer
# 首先实现Decoder中的一个层
def __init__(self, d_model, n_heads, d_ff, dropout_p=0.):
super().__init__()
d_k = d_model // n_heads
if d_k * n_heads != d_model:
raise ValueError(f"The `d_model` {d_model} can not be divisible by `num_heads` {n_heads}.")
self.dec_self_attn = MultiHeadAttention(d_model, d_k, n_heads, dropout_p)
self.dec_enc_attn = MultiHeadAttention(d_model, d_k, n_heads, dropout_p)
self.pos_ffn = PoswiseFeedForward(d_ff, d_model, dropout_p)
self.add_norm1 = AddNorm(d_model, dropout_p)
self.add_norm2 = AddNorm(d_model, dropout_p)
self.add_norm3 = AddNorm(d_model, dropout_p)
def construct(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
"""
dec_inputs: [batch_size, trg_len, d_model]
enc_outputs: [batch_size, src_len, d_model]
dec_self_attn_mask: [batch_size, trg_len, trg_len]
dec_enc_attn_mask: [batch_size, trg_len, src_len]
"""
residual = dec_inputs
dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
dec_outputs = self.add_norm1(dec_outputs, residual)
residual = dec_outputs
dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
dec_outputs = self.add_norm2(dec_outputs, residual)
residual = dec_outputs
dec_outputs = self.pos_ffn(dec_outputs)
dec_outputs = self.add_norm3(dec_outputs, residual)
return dec_outputs, dec_self_attn, dec_enc_attn
x = y = ops.ones((1, 2, 4), mstype.float32)
mask1 = mask2 = Tensor([False]).broadcast_to((1, 2, 2))
decoder_layer = DecoderLayer(4, 1, 16)
output, attn1, attn2 = decoder_layer(x, y, mask1, mask2)
print(output.shape, attn1.shape, attn2.shape)
# 将上面实现的DecoderLayer堆叠n_layer次,添加word embedding与positional encoding,以及最后的线性层。
# 输出的dec_outputs为对目标序列的预测。
class Decoder(nn.Cell):
def __init__(self, trg_vocab_size, d_model, n_heads, d_ff, n_layers, dropout_p=0.):
super().__init__()
self.trg_emb = nn.Embedding(trg_vocab_size, d_model)
self.pos_emb = PositionalEncoding(d_model, dropout_p)
self.layers = nn.CellList([DecoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)])
self.projection = nn.Dense(d_model, trg_vocab_size)
self.scaling_factor = ops.Sqrt()(Tensor(d_model, mstype.float32))
def construct(self, dec_inputs, enc_inputs, enc_outputs, src_pad_idx, trg_pad_idx):
"""
dec_inputs: [batch_size, trg_len]
enc_inputs: [batch_size, src_len]
enc_outputs: [batch_size, src_len, d_model]
"""
dec_outputs = self.trg_emb(dec_inputs.astype(mstype.int32))
dec_outputs = self.pos_emb(dec_outputs * self.scaling_factor)
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, trg_pad_idx)
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs, dec_inputs)
dec_self_attn_mask = ops.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, src_pad_idx)
dec_self_attns, dec_enc_attns = [], []
for layer in self.layers:
dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
dec_self_attns.append(dec_self_attn)
dec_enc_attns.append(dec_enc_attn)
dec_outputs = self.projection(dec_outputs)
return dec_outputs, dec_self_attns, dec_enc_attns
课程反馈
学会使用mindspore实现transformer的decoder部分
使用MindSpore昇思的体验和反馈
使用mindspore实现decoder加深了对decoder的理解
未来展望
继续学习将decoder和encoder结合起来