学习总结
在前文实现好的encoder和decoder的基础上实现transformer
学习心得
将实现的Encoder与Decoder组合起来。
class Transformer(nn.Cell):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def construct(self, enc_inputs, dec_inputs, src_pad_idx, trg_pad_idx):
"""
enc_inputs: [batch_size, src_len]
dec_inputs: [batch_size, trg_len]
"""
enc_outputs, enc_self_attns = self.encoder(enc_inputs, src_pad_idx)
dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs, src_pad_idx, trg_pad_idx)
dec_logits = dec_outputs.view((-1, dec_outputs.shape[-1]))
return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns
经验分享
将所有前文的代码组合起来,就能实现一个完整的transformer了
课程反馈
学会使用mindspore实现整个transformer
使用MindSpore昇思的体验和反馈
整体学习感觉十分顺畅,老师讲的也很明白
未来展望
继续学习其他模块mindspore实现的内容