我们假设一个极简的翻译任务,并设定一套小巧的参数,方便观察。
场景设定
任务: 将一个最多4个词的句子,翻译成一个最多5个词的句子。
批量大小 (BATCH_SIZE): 2
(一次处理2个句子)
模型维度 (D_MODEL): 12
(通常是512,我们用12方便看清)
头的数量 (NHEAD): 3
(3个注意力头)
每个头的维度 (D_HEAD): D_MODEL / NHEAD = 12 / 3 = 4
词汇表大小 (VOCAB_SIZE): 100
第一阶段:数据在 Encoder 中的流动
Encoder 的任务是接收源序列,输出一份“记忆”.
1. 初始输入
我们输入一个批次的、已经转换成数字ID的源句子。
# 假设源句子长度为4 (已填充)
SRC_SEQ_LEN = 4
# 创建一个假的源数据张量
src_tokens = torch.randint(1, 100, (BATCH_SIZE, SRC_SEQ_LEN))
# 打印形状
print(f"初始输入形状: {src_tokens.shape}")
初始输入形状: torch.Size([2, 4])
[2, 4]
代表:一个批次包含 2 个句子,每个句子由 4 个词元ID组成。
2. 词嵌入与位置编码
模型的第一步是将这些ID转换成向量。
# 假设 embedding 是一个 nn.Embedding(VOCAB_SIZE, D_MODEL) 层
embedding_layer = nn.Embedding(100, 12)
src_vectors = embedding_layer(src_tokens) # 位置编码会加上去,但不改变形状
# 打印形状
print(f"嵌入后的形状: {src_vectors.shape}")
嵌入后的形状: torch.Size([2, 4, 12])
[2, 4, 12]
代表:批次中的 2 个句子,每个句子 4 个词,现在每个词都变成了一个 12 维的向量(D_MODEL=12
)。
3. 进入多头自注意力
追踪 src_vectors
进入一个 MultiHeadAttention
模块后的变化。
# --- 在 MultiHeadAttention 内部 ---
# 1. 线性变换得到 Q, K, V (形状不变)
# q = q_linear(src_vectors) -> torch.Size([2, 4, 12])
# 2. 为多头机制重塑形状 (view)
# 我们将 12 维的向量看作是 3 个头,每个头 4 维 (12 = 3 * 4)
reshaped_q = q.view(2, 4, 3, 4) # (BATCH, SEQ_LEN, NHEAD, D_HEAD)
print(f" view后的形状: {reshaped_q.shape}")
# 3. 转置以方便并行计算 (transpose)
# 我们把 NHEAD 维度换到前面
transposed_q = reshaped_q.transpose(1, 2) # (BATCH, NHEAD, SEQ_LEN, D_HEAD)
print(f"transpose后的形状: {transposed_q.shape}")
view后的形状: torch.Size([2, 4, 3, 4]) transpose后的形状: torch.Size([2, 3, 4, 4])
view
操作是理解多头的关键。它没有改变数据,只是改变了我们“看待”数据的方式。[..., 12]
被看作了 [..., 3, 4]
。
transpose
操作是为了把所有属于同一个“头”的数据聚合在一起,方便进行矩阵乘法。现在的维度含义是:2 个样本,每个样本有 3 个头,每个头处理一个 4 词长的序列,每个词是 4 维。
#4. 计算注意力分数 (q @ k.T)
# q 的形状: [2, 3, 4, 4]
# k.transpose 的形状: [2, 3, 4, 4]
scores = torch.matmul(transposed_q, transposed_k.transpose(-2, -1))
# scores 的形状: [2, 3, 4, 4] -> (BATCH, NHEAD, SEQ_LEN, SEQ_LEN)
# 它计算了每个头内部,每个词对其他所有词的注意力分数。
# 5. 加权求和 Value
# attn_weights 的形状: [2, 3, 4, 4]
# v 的形状: [2, 3, 4, 4]
# context = torch.matmul(attn_weights, transposed_v) -> [2, 3, 4, 4]
# 6. 拼接多头 (view)
# 首先,把维度换回来
context_transposed = context.transpose(1, 2) # -> [2, 4, 3, 4]
# 然后,用 view 把最后两个维度“粘合”起来 (3 * 4 = 12)
final_output = context_transposed.contiguous().view(2, 4, 12)
print(f"\n多头注意力输出形状: {final_output.shape}")
多头注意力输出形状: torch.Size([2, 4, 12])
经过一整套复杂的操作,输出的形状变回了和输入时一样的 [2, 4, 12],
虽然形状没变,但每个向量内部的信息已经被上下文重组和丰富了。
Encoder 最终输出
数据流经 N 个 EncoderLayer,但因为每一层的输入输出形状都相同,所以最终输出形状不变。
# 假设 encoder 是一个完整的 Encoder 实例
memory = encoder(src_tokens)
print(f"\nEncoder最终输出(memory)形状: {memory.shape}")
Encoder最终输出(memory)形状: torch.Size([2, 4, 12])
这份 [2, 4, 12]
的张量就是源句子的“记忆”,将作为只读信息提供给 Decoder。
第二阶段:数据在 Decoder 中的流动
Decoder 的任务是利用 memory
,生成目标序列。
1. Decoder 初始输入
Decoder 的输入是目标语言的词元。
# 假设目标句子长度为5 (已填充)
TGT_SEQ_LEN = 5
tgt_tokens = torch.randint(1, 100, (BATCH_SIZE, TGT_SEQ_LEN))
tgt_vectors = embedding_layer(tgt_tokens) # 同样进行嵌入
print(f"Decoder输入向量形状: {tgt_vectors.shape}")
输出及解释:
Decoder输入向量形状: torch.Size([2, 5, 12])
2 个样本,每个目标序列有 5 个词,每个词 12 维。
2. Decoder 的交叉注意力 (联动核心)
在 Decoder 的交叉注意力层,数据维度会发生有趣的交互。
Query (Q): 来自 Decoder 自身,经过了带掩码自注意力层。它的形状是 [2, 5, 12]
。
Key (K) 和 Value (V): 来自 Encoder 的 memory
。它们的形状是 [2, 4, 12]
。
# --- 在 Cross-Attention 内部 ---
# 1. 变换和重塑
# q 来自 Decoder: transpose后形状为 [2, 3, 5, 4] (BATCH, NHEAD, TGT_LEN, D_HEAD)
# k 来自 Encoder: transpose后形状为 [2, 3, 4, 4] (BATCH, NHEAD, SRC_LEN, D_HEAD)
# v 来自 Encoder: transpose后形状为 [2, 3, 4, 4] (BATCH, NHEAD, SRC_LEN, D_HEAD)
# 2. 计算注意力分数
# scores = q @ k.T
# q 的形状: [2, 3, 5, 4]
# k.transpose 的形状: [2, 3, 4, 4]
# scores 的形状: [2, 3, 5, 4] -> (BATCH, NHEAD, TGT_LEN, SRC_LEN)
# 它的含义是:对于目标句子的每个词(5个),计算它对源句子每个词(4个)的注意力。
# 3. 加权求和 Value
# attn_weights 的形状: [2, 3, 5, 4]
# v 的形状: [2, 3, 4, 4]
# context = attn_weights @ v -> [2, 3, 5, 4] -> (BATCH, NHEAD, TGT_LEN, D_HEAD)
# 注意!输出的序列长度(5)由 Query (即tgt)决定,而向量维度(4)由 Value (即src)决定。
# 4. 拼接多头
# final_output = context...view(2, 5, 12)
交叉注意力的输出形状,其序列长度 (5
) 跟随 Query
(Decoder),其向量维度 (12
) 跟随 Value
(Encoder)。最终输出形状为 [2, 5, 12]
,与进入交叉注意力层之前的 Query
形状一致。
3. Decoder 最终输出
数据流经 N 个 DecoderLayer 后,进入最后的线性层,预测下一个词。
Python
# 假设 decoder 是一个完整的 Decoder 实例
# decoder_output 的形状是 [2, 5, 12] (BATCH, TGT_LEN, D_MODEL)
decoder_output = decoder(tgt_tokens, memory)
# 假设 output_linear 是一个 nn.Linear(D_MODEL, VOCAB_SIZE) 层
final_logits = output_linear(decoder_output)
print(f"\nDecoder最终输出(logits)形状: {final_logits.shape}")
输出及解释:
Decoder最终输出(logits)形状: torch.Size([2, 5, 100])
-
[2, 5, 100]
代表:2 个样本,对于目标序列的 5 个位置,每个位置都生成了一个 100 维的向量,这个向量是模型对词汇表中每个词的预测得分。