位置编码 文本序列

4ed0dafa943f43bb817b4fc84e506fe8.jpg

 import torch

import math

 

def get_positional_encoding(max_len, d_model):

    pe = torch.zeros(max_len, d_model)

    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)

    pe[:, 1::2] = torch.cos(position * div_term)

    return pe.unsqueeze(0) # (1, max_len, d_model)

 

# 示例

max_len = 50

d_model = 512

positional_encoding = get_positional_encoding(max_len, d_model)

print(positional_encoding.shape) # 输出: torch.Size([1, 50, 512])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值