文章目录
一、整体功能
本文介绍了一个基于 RNN 的注意力机制实现,具体应用于机器翻译任务中。该机制通过动态聚焦源语言(如“欢迎、来、北京”)的不同词,辅助解码器(如预测“to”)生成更准确的目标语言词汇。
RNN中的注意力机制详解:网页链接
核心流程:
- 编码器(Encoder):将源语言句子转换为一系列隐状态(
h1, h2, ..., hn)。 - 解码器(Decoder):在生成目标语言词汇时,利用注意力机制关注源语言中的关键信息。
- 注意力机制:计算解码器当前状态与编码器各隐状态的相关性(注意力分数),并加权得到上下文向量,最终融合解码器状态和上下文向量进行输出。
示意图解析:

- 左侧(蓝色框):编码器处理源语言(“欢迎、来、北京”)的过程。
- 中间(绿色框):注意力机制的具体计算步骤。
- 右侧(粉色框):解码器根据注意力结果生成目标语言(“Welcome to Beijing”)的过程。
二、类初始化
def __init__(self, query_dim, key_dim, value_seq_len, value_dim, output_dim):
super().__init__()
# 注意力权重计算层 - 对应图中 “神经网络线性变换(Q拼K后)” 等操作,计算注意力分数
self.attention_weights = nn.Linear(query_dim + key_dim, value_seq_len)
# 输出映射层 - 对应图中 后续“神经网络线性变换(Q拼接C后)”,将融合后的特征映射到指定输出维度
self.output_projection = nn.Linear(query_dim + value_dim, output_dim)
self.attention_weights:对应图里对拼接后的Q(解码器上一步输出语义,如Welcome的词向量 )和K(编码器隐层输出,如h1等 )进行线性变换的操作,用于计算注意力分数,决定对源语言各词的关注程度。self.output_projection:对应图中注意力输出与Q再次拼接后,进行线性变换以适配解码器输入的环节,整合信息并转换维度。
三、前向传播流程(forward 函数)
3.1 扩展 query 以便拼接(对应图中 Q 与 K 交互准备)
# 1. 扩展query以便与key拼接
batch_size = query.size(0)
query_expanded = query.expand(-1, value.size(1), -1) # [batch, seq_len, query_dim]
- 作用:
query初始是解码器当前时间步状态(如形状[1, 1, 3]),要和长度为source_seq_len(源语言词数量,如 3 )的key拼接,需扩展query维度,让其在序列长度维度(第 2 维 )与key一致,为后续拼接Q和K做准备,对应图中Q参与和K交互的前置处理。
3.2 拼接 Q 和 K 并计算注意力分数(对应图中 Q 拼 K 与线性变换)
# 2. 拼接Q和K并计算注意力分数
qk_concat = torch.cat([query_expanded, key], dim=2) # [batch, seq_len, query_dim+key_dim]
attention_scores = self.attention_weights(qk_concat) # [batch, seq_len, seq_len]
- 操作解析:
torch.cat([query_expanded, key], dim=2):把扩展后的Q(解码器状态,代表目标语言已生成部分语义 )和K(编码器隐层输出,代表源语言各词语义 )在特征维度(第 3 维 )拼接,对应图里Q与K的拼接操作,融合二者初始特征。self.attention_weights(qk_concat):用线性层对拼接后的特征进行变换,输出注意力分数。此时输出形状[batch, seq_len, seq_len],可理解为计算出Q与K各位置间的关联分数,对应图中 “神经网络线性变换” 环节,初步得到注意力相关数值。
3.3 提取对角元素并获注意力分布(对应图中 Softmax 层等)
# 3. 提取对角元素作为位置注意力分数
attention_scores = attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1) # [batch, 1, seq_len]
# 4. 应用softmax获得注意力分布
attention_dist = F.softmax(attention_scores, dim=-1) # [batch, 1, seq_len]
- 操作解析:
attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1):从attention_scores中提取对角元素,是因为我们关注Q与K对应位置(如解码器当前状态与源语言各词位置匹配 )的关联,提取后形状变为[batch, 1, seq_len],对应聚焦源语言每个词与当前解码器状态的关联。F.softmax(attention_scores, dim=-1):对提取后的分数做Softmax归一化,得到注意力分布attention_dist,值在 0 - 1 之间且和为 1 ,对应图中 “Softmax 层” ,表示解码器预测当前词时,对源语言各词的关注权重,比如运行结果里对 “欢迎、来、北京” 的权重分布[0.3784, 0.0705, 0.5511]。
3.4 计算上下文向量(对应图中注意力输出与 V 交互)
# 5. 计算上下文向量(注意力加权的值)
context_vector = torch.bmm(attention_dist, value) # [batch, 1, value_dim]
- 作用:
value是编码器最后一层输出(如形状[1, 3, 3],代表 “欢迎、来、北京” 各自的编码语义 ),用注意力分布attention_dist对value加权求和(bmm是批量矩阵乘法 ),得到context_vector,融合源语言各词语义且突出注意力权重高的词,对应图中注意力输出与V(编码器隐层输出 )的交互,提取对生成当前目标词重要的源语言语义。
3.5 合并与投影(对应图中 Q 拼接 C 及后续线性变换)
# 6. 合并上下文向量和查询向量并投影到输出维度
combined = torch.cat([query, context_vector], dim=2) # [batch, 1, query_dim+value_dim]
output = self.output_projection(combined) # [batch, 1, output_dim]
- 操作解析:
torch.cat([query, context_vector], dim=2):把解码器原始状态query(代表目标语言已生成部分的语义 )和上下文向量context_vector(融合源语言关键语义 )在特征维度拼接,对应图中Q与注意力输出(如C)的拼接,整合目标语言已有信息和源语言关键信息。self.output_projection(combined):用线性层对拼接后的特征变换,输出output,将维度转换为output_dim(如示例中为 3 ),对应图中最后 “神经网络线性变换” ,得到可输入解码器 GRU 层的向量,辅助预测下一个目标词(如to)。
四、主函数
if __name__ == '__main__':
# 设置随机种子以便结果可复现
torch.manual_seed(44)
# 模拟"欢迎来北京"翻译为"Welcome to Beijing"的场景
source_seq_len = 3 # 源语言序列长度
hidden_dim = 3 # 隐藏层维度
# 随机生成编码器输出(值向量V)
encoder_outputs = torch.randn(1, source_seq_len, hidden_dim)
# 随机生成解码器当前状态(查询向量Q)
decoder_state = torch.randn(1, 1, hidden_dim)
# 随机生成编码器键向量(K)
encoder_keys = torch.randn(1, source_seq_len, hidden_dim)
# 创建注意力模型
attention = TranslationAttention(
query_dim=hidden_dim,
key_dim=hidden_dim,
value_seq_len=source_seq_len,
value_dim=hidden_dim,
output_dim=hidden_dim
)
# 计算注意力
output, attn_weights = attention(decoder_state, encoder_keys, encoder_outputs)
# 打印结果
print("注意力权重分布(表示解码器在预测'to'时对源语言各词的关注程度):")
for i, weight in enumerate(attn_weights[0, 0]):
print(f"词 {i+1}: {weight.item():.4f}")
print("\n上下文向量形状:", output.shape)
print("注意力分布形状:", attn_weights.shape)
- 这里通过随机生成数据模拟编码器、解码器的中间状态(如
encoder_outputs对应图中编码器最后一层输出V,decoder_state对应Q等 ),调用模型后得到的attn_weights就是图里注意力机制计算出的对源语言各词的关注权重,context是整合后的上下文向量,完整模拟了图中从编码器输出、解码器状态输入,到注意力计算、结果输出的流程。
五、完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class TranslationAttention(nn.Module):
"""
翻译任务中的注意力机制,用于捕捉源语言和目标语言之间的语义关联
"""
def __init__(self, query_dim, key_dim, value_seq_len, value_dim, output_dim):
"""
初始化注意力机制模型。
参数:
query_dim (int): 查询向量(Q)的维度,通常为解码器隐藏状态的维度。
key_dim (int): 键向量(K)的维度,通常为编码器隐藏状态的维度。
value_seq_len (int): 值序列(V)的长度,即源语言句子的长度。
value_dim (int): 值向量(V)的维度,通常与键向量维度相同。
output_dim (int): 输出向量的维度,通常与解码器输入维度相同。
"""
super().__init__()
self.attention_weights = nn.Linear(query_dim + key_dim, value_seq_len)
self.output_projection = nn.Linear(query_dim + value_dim, output_dim)
def forward(self, query, key, value):
"""
前向传播过程。
参数:
query: 当前解码器状态 [batch_size, 1, query_dim]
key: 编码器键向量 [batch_size, seq_len, key_dim]
value: 编码器值向量 [batch_size, seq_len, value_dim]
返回:
output: 最终输出向量 [batch_size, 1, output_dim]
attention_dist: 注意力分布 [batch_size, 1, seq_len]
"""
# 1. 扩展query以便与key拼接
batch_size = query.size(0)
query_expanded = query.expand(-1, value.size(1), -1) # [batch, seq_len, query_dim]
# 2. 拼接Q和K并计算注意力分数
qk_concat = torch.cat([query_expanded, key], dim=2) # [batch, seq_len, query_dim+key_dim]
attention_scores = self.attention_weights(qk_concat) # [batch, seq_len, seq_len]
# 3. 提取对角元素作为位置注意力分数
attention_scores = attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1) # [batch, 1, seq_len]
# 4. 应用softmax获得注意力分布
attention_dist = F.softmax(attention_scores, dim=-1) # [batch, 1, seq_len]
# 5. 计算上下文向量(注意力加权的值)
context_vector = torch.bmm(attention_dist, value) # [batch, 1, value_dim]
# 6. 合并上下文向量和查询向量并投影到输出维度
combined = torch.cat([query, context_vector], dim=2) # [batch, 1, query_dim+value_dim]
output = self.output_projection(combined) # [batch, 1, output_dim]
return output, attention_dist
if __name__ == '__main__':
# 设置随机种子以便结果可复现
torch.manual_seed(44)
# 模拟"欢迎来北京"翻译为"Welcome to Beijing"的场景
source_seq_len = 3 # 源语言序列长度
hidden_dim = 3 # 隐藏层维度
# 随机生成编码器输出(值向量V)
encoder_outputs = torch.randn(1, source_seq_len, hidden_dim)
# 随机生成解码器当前状态(查询向量Q)
decoder_state = torch.randn(1, 1, hidden_dim)
# 随机生成编码器键向量(K)
encoder_keys = torch.randn(1, source_seq_len, hidden_dim)
# 创建注意力模型
attention = TranslationAttention(
query_dim=hidden_dim,
key_dim=hidden_dim,
value_seq_len=source_seq_len,
value_dim=hidden_dim,
output_dim=hidden_dim
)
# 计算注意力
output, attn_weights = attention(decoder_state, encoder_keys, encoder_outputs)
# 打印结果
print("注意力权重分布(表示解码器在预测'to'时对源语言各词的关注程度):")
for i, weight in enumerate(attn_weights[0, 0]):
print(f"词 {i+1}: {weight.item():.4f}")
print("\n上下文向量形状:", output.shape)
print("注意力分布形状:", attn_weights.shape)
综上,代码通过类结构定义和前向传播逻辑,完整实现了图中注意力机制的核心流程,从 Q、K、V 交互,到注意力分数计算、分布归一化,再到上下文向量生成与输出适配,与图中展示的注意力过程一一对应,为翻译任务里解码器利用源语言语义提供了机制支持。

1016

被折叠的 条评论
为什么被折叠?



