一分钟快速EASY理解transformer的生成过程

让我用一个简单的例子来解释 dec_inp 在 Transformer 中的作用:
假设我们在做时间序列预测:

# 输入序列(batch_x): [10, 20, 30, 40, 50]
# 目标序列: [60, 70, 80]

# 1. 训练阶段
dec_inp = [50, 60, 70]  # 右移一位的目标序列,以最后一个输入值开始

# 具体过程:
step1 = decoder(enc_output, dec_inp[0:1])     # 输入50,预测60
step2 = decoder(enc_output, dec_inp[0:2])     # 输入[50,60],预测70
step3 = decoder(enc_output, dec_inp[0:3])     # 输入[50,60,70],预测80

# 2. 预测阶段
dec_inp = [50]  # 开始只有最后一个输入值
step1 = decoder(enc_output, dec_inp)          # 预测得到60
dec_inp = [50, 60]                            # 将预测值加入
step2 = decoder(enc_output, dec_inp)          # 预测得到70
dec_inp = [50, 60, 70]                        # 继续加入预测值
step3 = decoder(enc_output, dec_inp)          # 预测得到80

关键点:
dec_inp 在训练时是已知的(teacher forcing),包含了目标序列右移一位的值
在预测时是逐步生成的,每次将新预测的值加入到 dec_inp 中
这种机制让模型能学会利用之前的预测来做出下一步预测
这就像是:

**训练时:**老师告诉你"看到50,应该说60;看到50和60,应该说70"
**预测时:**你看到50,说出60;然后基于50和60,说出70
这就是为什么叫"teacher forcing",因为训练时老师(ground truth)在强制指导你。

在这里插入图片描述

# 假设参数设置:
seq_len = 4      # 输入序列长度
label_len = 2    # 重叠长度
pred_len = 3     # 预测长度

# 具体数据示例:
batch_x = [10, 20, 30, 40]        # 输入序列
batch_y = [30, 40, 50, 60, 70]    # 目标序列(包含label_len和pred_len部分)
                ↑  ↑   ↑  ↑  ↑
                |  |   |  |  |
                重叠部分 预测部分
                (label_len=2) (pred_len=3)

# dec_inp的构建过程:
label_part = [30, 40]     # 从batch_y取label_len长度
zeros = [0, 0, 0]         # pred_len个零
dec_inp = [30, 40, 0, 0, 0]  # 拼接得到最终的decoder输入

关键点:
label_len 是输入和目标序列的重叠长度,用于帮助模型过渡
dec_inp 最终长度 = label_len + pred_len
dec_inp 前面部分用真实值,后面用0填充
这就像是:给模型一个"起跑助力",让它先看到一部分真实值,再开始预测未来值。

# 参数设置
seq_len = 4      # 输入序列长度
label_len = 2    # 重叠长度
pred_len = 3     # 预测长度

# 数据示例
batch_x = [10, 20, 30, 40]        # 输入序列
batch_y = [30, 40, 50, 60, 70]    # 目标序列

# 1. 构建初始dec_inp
dec_inp初始 = [30, 40, 0, 0, 0]   # label_len部分 + pred_len个0
            ↑   ↑  ↑  ↑  ↑
            来自batch_y  填充0

# 2. 预测过程
# Step 1: 编码器处理输入序列
enc_output = encoder(batch_x)  # [10,20,30,40] -> 编码结果

# Step 2: 解码器预测
# 训练时(使用teacher forcing):
step1 = decoder(enc_output, [30,40])     # 预测50
step2 = decoder(enc_output, [30,40,50])  # 预测60
step3 = decoder(enc_output, [30,40,50,60]) # 预测70

# 预测时(自回归方式):
pred1 = decoder(enc_output, [30,40])     # 预测得到50
dec_inp变为 = [30,40,50,0,0]

pred2 = decoder(enc_output, [30,40,50])  # 预测得到60
dec_inp变为 = [30,40,50,60,0]

pred3 = decoder(enc_output, [30,40,50,60]) # 预测得到70
dec_inp变为 = [30,40,50,60,70]

最终预测结果 = [50,60,70]

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值