Transformer的输入部分进行了解,主要是文本嵌入层的代码分析和位置编码。
1.文本嵌入层的代码分析
#定义Embeddings类来实现文本嵌入层,这里s说明代表两个一模一样的嵌入层,他们共享参数.
class Embeddings(nn.Module):
#"""类的初始化函数,有两个参数. d _model:指词嵌入的维度, vocab:指词表的大小. """
def __init__(self, d_model, vocab):
#接着就是使用super的方式指明继承nn.Module的初始化函数,我们自己实现的所有层都会这样去
super(Embeddings,self).__init__()
#调用nn中预定义层Embeddings,获得一个词嵌入对象self.lut
self.lut = nn.Embedding(vocab, d_model)
#最后将d_model传入类中
self.d_model = d_model
"""可以将其理解为该层的前向传播逻辑,所有层中都会有此函数当传给该类的实例化对象参数时,自动调用该类函数
参数x︰因为Embedding层是首层,所以代表输入给模型的文本通过词汇映射后的张量"""
def forward(self, x):
#将x传给self. lut并与根号下self.d_model相乘作