本文来源公众号“机器学习算法那些事”,仅用于学术分享,侵权删,干货满满。
原文链接:逐模块解析transformer结构
transformer是一种编解码(encoder-decoer)结构,用于自然语言处理、计算机视觉等领域,编解码结构是当前大模型必包含的部分。
编解码结构图:
image-20240221221206633
transformer模块编码输入得到特征,然后解码得到输出。
transformer论文的一张非常经典的图:
结合transformer论文和代码,模块主要包括了:
-
词嵌入模块(input embedding)
-
位置编码模块(Positional Encoding)
-
多头注意力机制模块(Multi-Head Attention)
-
层归一化模块(LayNorm)
-
残差模块
-
前馈神经网络模块(FFN)
-
交叉多头注意力机制模块(Cross Multi-Head Attention)
-
掩膜多头注意力机制模块(Masked Multi-Head Attention)
接下来一一介绍上述几个模块。
1. 词嵌入模块
词嵌入模块调用nn.Embedding,其主要作用是将每个单词表示成一个向量,方便下一步计算和处理。
class TokenEmbedding(nn.Embedding):
"""
Token Embedding using torch.nn
they will dense representation of word using weighted matrix
"""
def __init__(self, vocab_size, d_model):
"""
class for token embedding that included positional information
:param vocab_size: 字典中词的个数
:param d_model: 嵌入维度
"""
super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)