让我用一个简单的例子来解释 dec_inp 在 Transformer 中的作用:
假设我们在做时间序列预测:
# 输入序列(batch_x): [10, 20, 30, 40, 50]
# 目标序列: [60, 70, 80]
# 1. 训练阶段
dec_inp = [50, 60, 70] # 右移一位的目标序列,以最后一个输入值开始
# 具体过程:
step1 = decoder(enc_output, dec_inp[0:1]) # 输入50,预测60
step2 = decoder(enc_output, dec_inp[0:2]) # 输入[50,60],预测70
step3 = decoder(enc_output, dec_inp[0:3]) # 输入[50,60,70],预测80
# 2. 预测阶段
dec_inp = [50] # 开始只有最后一个输入值
step1 = decoder(enc_output, dec_inp) # 预测得到60
dec_inp = [50, 60] # 将预测值加入
step2 = decoder(enc_output, dec_inp) # 预测得到70
dec_inp = [50, 60, 70] # 继续加入预测值
step3 = decoder(enc_output, dec_inp) # 预测得到80
关键点:
dec_inp 在训练时是已知的(teacher forcing),包含了目标序列右移一位的值
在预测时是逐步生成的,每次将新预测的值加入到 dec_inp 中
这种机制让模型能学会利用之前的预测来做出下一步预测
这就像是:
**训练时:**老师告诉你"看到50,应该说60;看到50和60,应该说70"
**预测时:**你看到50,说出60;然后基于50和60,说出70
这就是为什么叫"teacher forcing",因为训练时老师(ground truth)在强制指导你。
# 假设参数设置:
seq_len = 4 # 输入序列长度
label_len = 2 # 重叠长度
pred_len = 3 # 预测长度
# 具体数据示例:
batch_x = [10, 20, 30, 40] # 输入序列
batch_y = [30, 40, 50, 60, 70] # 目标序列(包含label_len和pred_len部分)
↑ ↑ ↑ ↑ ↑
| | | | |
重叠部分 预测部分
(label_len=2) (pred_len=3)
# dec_inp的构建过程:
label_part = [30, 40] # 从batch_y取label_len长度
zeros = [0, 0, 0] # pred_len个零
dec_inp = [30, 40, 0, 0, 0] # 拼接得到最终的decoder输入
关键点:
label_len 是输入和目标序列的重叠长度,用于帮助模型过渡
dec_inp 最终长度 = label_len + pred_len
dec_inp 前面部分用真实值,后面用0填充
这就像是:给模型一个"起跑助力",让它先看到一部分真实值,再开始预测未来值。
# 参数设置
seq_len = 4 # 输入序列长度
label_len = 2 # 重叠长度
pred_len = 3 # 预测长度
# 数据示例
batch_x = [10, 20, 30, 40] # 输入序列
batch_y = [30, 40, 50, 60, 70] # 目标序列
# 1. 构建初始dec_inp
dec_inp初始 = [30, 40, 0, 0, 0] # label_len部分 + pred_len个0
↑ ↑ ↑ ↑ ↑
来自batch_y 填充0
# 2. 预测过程
# Step 1: 编码器处理输入序列
enc_output = encoder(batch_x) # [10,20,30,40] -> 编码结果
# Step 2: 解码器预测
# 训练时(使用teacher forcing):
step1 = decoder(enc_output, [30,40]) # 预测50
step2 = decoder(enc_output, [30,40,50]) # 预测60
step3 = decoder(enc_output, [30,40,50,60]) # 预测70
# 预测时(自回归方式):
pred1 = decoder(enc_output, [30,40]) # 预测得到50
dec_inp变为 = [30,40,50,0,0]
pred2 = decoder(enc_output, [30,40,50]) # 预测得到60
dec_inp变为 = [30,40,50,60,0]
pred3 = decoder(enc_output, [30,40,50,60]) # 预测得到70
dec_inp变为 = [30,40,50,60,70]
最终预测结果 = [50,60,70]