1. trax库了解一下
详细了解戳 trax
trax库是google开源的一个深度学习代码库,基于tensorflow, jax实现了主流的深度学习模型。市场上有这么多开源的深度学习模型实现库,为什么还要搞个trax呢?它的特点就是聚焦端到端的深度学习模型,主打的就是实现简洁,好理解。当然也可以用它来实现自己的模型,CPU、TPU、GPU都能支持,选择trax的源代码, 当然也是因为google对模型的实现正确性更有保证(这一点很重要,对于初学者,看到错误的实现不一定能识别出来)。
tranformer源代码地址:transformer
2. PositionEncoding在Transformer模型中的位置
这里上经典的tranformer架构图看一下,如下图,这个PositionEncoding在Encoder与Decoder的输入端都有使用,从图上看是PositionEncoding的输出与输入相加后做为Encoder-block与Decoder-block的输入, 实际的实现又是怎么做的呢?继续往下看。。。
3. PositionEncoding层的原理
位置编码是为了把序列的位置信息考虑进模型,在tranformer中,作者说因为模型不包含递归(如RNN)和卷积,所以加入PositionEncoding以利用序列的位置信息。
看一下原论文:Attention is all you need, 给出了以下两个公式:
其中pos表示序列的位置,i代表维度(指嵌入维度,tranformer的嵌入维度是512,即dmodel是512,i的取值范围是0<= 2i < 2i + 1 <= 512, 这也就符合论文中说的函数波长从2π 到 10000 · 2π)。
从这个公式可以看出第0和1个嵌入维度位置函数周期是相同的,一个使用sin, 一个使用cos; 类推,第2,3个嵌入维度周期也是相同,波长按指数级扩大一些。。。
位置信息处理都有哪些方法呢?看一下new bing的回答,可以参考一下: