Transformer的PositionEncoding源代码解读

trax是Google开源的深度学习库,基于Tensorflow和Jax,专注于简洁的端到端模型实现。文章介绍了Transformer模型中的PositionEncoding层,它是如何引入序列位置信息的,并展示了trax中PositionEncoding的实现代码,包括初始化和前向传播过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. trax库了解一下

详细了解戳 trax

trax库是google开源的一个深度学习代码库,基于tensorflow, jax实现了主流的深度学习模型。市场上有这么多开源的深度学习模型实现库,为什么还要搞个trax呢?它的特点就是聚焦端到端的深度学习模型,主打的就是实现简洁,好理解。当然也可以用它来实现自己的模型,CPU、TPU、GPU都能支持,选择trax的源代码, 当然也是因为google对模型的实现正确性更有保证(这一点很重要,对于初学者,看到错误的实现不一定能识别出来)。

tranformer源代码地址:transformer

2. PositionEncodingTransformer模型中的位置

这里上经典的tranformer架构图看一下,如下图,这个PositionEncoding在Encoder与Decoder的输入端都有使用,从图上看是PositionEncoding的输出与输入相加后做为Encoder-block与Decoder-block的输入, 实际的实现又是怎么做的呢?继续往下看。。。

3. PositionEncoding层的原理

位置编码是为了把序列的位置信息考虑进模型,在tranformer中,作者说因为模型不包含递归(如RNN)和卷积,所以加入PositionEncoding以利用序列的位置信息。

看一下原论文:Attention is all you need, 给出了以下两个公式:

 其中pos表示序列的位置,i代表维度(指嵌入维度,tranformer的嵌入维度是512,即dmodel是512,i的取值范围是0<= 2i < 2i + 1 <= 512, 这也就符合论文中说的函数波长从2π 到 10000 · 2π)。

从这个公式可以看出第0和1个嵌入维度位置函数周期是相同的,一个使用sin, 一个使用cos; 类推,第2,3个嵌入维度周期也是相同,波长按指数级扩大一些。。。

位置信息处理都有哪些方法呢?看一下new bing的回答,可以参考一下:

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值