文章目录
一、整体功能
本文介绍了一个基于 RNN 的注意力机制实现,具体应用于机器翻译任务中。该机制通过动态聚焦源语言(如“欢迎、来、北京”)的不同词,辅助解码器(如预测“to”)生成更准确的目标语言词汇。
RNN中的注意力机制详解:网页链接
核心流程:
- 编码器(Encoder):将源语言句子转换为一系列隐状态(
h1, h2, ..., hn
)。 - 解码器(Decoder):在生成目标语言词汇时,利用注意力机制关注源语言中的关键信息。
- 注意力机制:计算解码器当前状态与编码器各隐状态的相关性(注意力分数),并加权得到上下文向量,最终融合解码器状态和上下文向量进行输出。
示意图解析:
- 左侧(蓝色框):编码器处理源语言(“欢迎、来、北京”)的过程。
- 中间(绿色框):注意力机制的具体计算步骤。
- 右侧(粉色框):解码器根据注意力结果生成目标语言(“Welcome to Beijing”)的过程。
二、类初始化
def __init__(self, query_dim, key_dim, value_seq_len, value_dim, output_dim):
super().__init__()
# 注意力权重计算层 - 对应图中 “神经网络线性变换(Q拼K后)” 等操作,计算注意力分数
self.attention_weights = nn.Linear(query_dim + key_dim, value_seq_len)
# 输出映射层 - 对应图中 后续“神经网络线性变换(Q拼接C后)”,将融合后的特征映射到指定输出维度
self.output_projection = nn.Linear(query_dim + value_dim, output_dim)
self.attention_weights
:对应图里对拼接后的Q
(解码器上一步输出语义,如Welcome
的词向量 )和K
(编码器隐层输出,如h1
等 )进行线性变换的操作,用于计算注意力分数,决定对源语言各词的关注程度。self.output_projection
:对应图中注意力输出与Q
再次拼接后,进行线性变换以适配解码器输入的环节,整合信息并转换维度。
三、前向传播流程(forward
函数)
3.1 扩展 query
以便拼接(对应图中 Q
与 K
交互准备)
# 1. 扩展query以便与key拼接
batch_size = query.size(0)
query_expanded = query.expand(-1, value.size(1), -1) # [batch, seq_len, query_dim]
- 作用:
query
初始是解码器当前时间步状态(如形状[1, 1, 3]
),要和长度为source_seq_len
(源语言词数量,如 3 )的key
拼接,需扩展query
维度,让其在序列长度维度(第 2 维 )与key
一致,为后续拼接Q
和K
做准备,对应图中Q
参与和K
交互的前置处理。
3.2 拼接 Q
和 K
并计算注意力分数(对应图中 Q
拼 K
与线性变换)
# 2. 拼接Q和K并计算注意力分数
qk_concat = torch.cat([query_expanded, key], dim=2) # [batch, seq_len, query_dim+key_dim]
attention_scores = self.attention_weights(qk_concat) # [batch, seq_len, seq_len]
- 操作解析:
torch.cat([query_expanded, key], dim=2)
:把扩展后的Q
(解码器状态,代表目标语言已生成部分语义 )和K
(编码器隐层输出,代表源语言各词语义 )在特征维度(第 3 维 )拼接,对应图里Q
与K
的拼接操作,融合二者初始特征。self.attention_weights(qk_concat)
:用线性层对拼接后的特征进行变换,输出注意力分数。此时输出形状[batch, seq_len, seq_len]
,可理解为计算出Q
与K
各位置间的关联分数,对应图中 “神经网络线性变换” 环节,初步得到注意力相关数值。
3.3 提取对角元素并获注意力分布(对应图中 Softmax
层等)
# 3. 提取对角元素作为位置注意力分数
attention_scores = attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1) # [batch, 1, seq_len]
# 4. 应用softmax获得注意力分布
attention_dist = F.softmax(attention_scores, dim=-1) # [batch, 1, seq_len]
- 操作解析:
attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1)
:从attention_scores
中提取对角元素,是因为我们关注Q
与K
对应位置(如解码器当前状态与源语言各词位置匹配 )的关联,提取后形状变为[batch, 1, seq_len]
,对应聚焦源语言每个词与当前解码器状态的关联。F.softmax(attention_scores, dim=-1)
:对提取后的分数做Softmax
归一化,得到注意力分布attention_dist
,值在 0 - 1 之间且和为 1 ,对应图中 “Softmax 层” ,表示解码器预测当前词时,对源语言各词的关注权重,比如运行结果里对 “欢迎、来、北京” 的权重分布[0.3784, 0.0705, 0.5511]
。
3.4 计算上下文向量(对应图中注意力输出与 V
交互)
# 5. 计算上下文向量(注意力加权的值)
context_vector = torch.bmm(attention_dist, value) # [batch, 1, value_dim]
- 作用:
value
是编码器最后一层输出(如形状[1, 3, 3]
,代表 “欢迎、来、北京” 各自的编码语义 ),用注意力分布attention_dist
对value
加权求和(bmm
是批量矩阵乘法 ),得到context_vector
,融合源语言各词语义且突出注意力权重高的词,对应图中注意力输出与V
(编码器隐层输出 )的交互,提取对生成当前目标词重要的源语言语义。
3.5 合并与投影(对应图中 Q
拼接 C
及后续线性变换)
# 6. 合并上下文向量和查询向量并投影到输出维度
combined = torch.cat([query, context_vector], dim=2) # [batch, 1, query_dim+value_dim]
output = self.output_projection(combined) # [batch, 1, output_dim]
- 操作解析:
torch.cat([query, context_vector], dim=2)
:把解码器原始状态query
(代表目标语言已生成部分的语义 )和上下文向量context_vector
(融合源语言关键语义 )在特征维度拼接,对应图中Q
与注意力输出(如C
)的拼接,整合目标语言已有信息和源语言关键信息。self.output_projection(combined)
:用线性层对拼接后的特征变换,输出output
,将维度转换为output_dim
(如示例中为 3 ),对应图中最后 “神经网络线性变换” ,得到可输入解码器 GRU 层的向量,辅助预测下一个目标词(如to
)。
四、主函数
if __name__ == '__main__':
# 设置随机种子以便结果可复现
torch.manual_seed(44)
# 模拟"欢迎来北京"翻译为"Welcome to Beijing"的场景
source_seq_len = 3 # 源语言序列长度
hidden_dim = 3 # 隐藏层维度
# 随机生成编码器输出(值向量V)
encoder_outputs = torch.randn(1, source_seq_len, hidden_dim)
# 随机生成解码器当前状态(查询向量Q)
decoder_state = torch.randn(1, 1, hidden_dim)
# 随机生成编码器键向量(K)
encoder_keys = torch.randn(1, source_seq_len, hidden_dim)
# 创建注意力模型
attention = TranslationAttention(
query_dim=hidden_dim,
key_dim=hidden_dim,
value_seq_len=source_seq_len,
value_dim=hidden_dim,
output_dim=hidden_dim
)
# 计算注意力
output, attn_weights = attention(decoder_state, encoder_keys, encoder_outputs)
# 打印结果
print("注意力权重分布(表示解码器在预测'to'时对源语言各词的关注程度):")
for i, weight in enumerate(attn_weights[0, 0]):
print(f"词 {i+1}: {weight.item():.4f}")
print("\n上下文向量形状:", output.shape)
print("注意力分布形状:", attn_weights.shape)
- 这里通过随机生成数据模拟编码器、解码器的中间状态(如
encoder_outputs
对应图中编码器最后一层输出V
,decoder_state
对应Q
等 ),调用模型后得到的attn_weights
就是图里注意力机制计算出的对源语言各词的关注权重,context
是整合后的上下文向量,完整模拟了图中从编码器输出、解码器状态输入,到注意力计算、结果输出的流程。
五、完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class TranslationAttention(nn.Module):
"""
翻译任务中的注意力机制,用于捕捉源语言和目标语言之间的语义关联
"""
def __init__(self, query_dim, key_dim, value_seq_len, value_dim, output_dim):
"""
初始化注意力机制模型。
参数:
query_dim (int): 查询向量(Q)的维度,通常为解码器隐藏状态的维度。
key_dim (int): 键向量(K)的维度,通常为编码器隐藏状态的维度。
value_seq_len (int): 值序列(V)的长度,即源语言句子的长度。
value_dim (int): 值向量(V)的维度,通常与键向量维度相同。
output_dim (int): 输出向量的维度,通常与解码器输入维度相同。
"""
super().__init__()
self.attention_weights = nn.Linear(query_dim + key_dim, value_seq_len)
self.output_projection = nn.Linear(query_dim + value_dim, output_dim)
def forward(self, query, key, value):
"""
前向传播过程。
参数:
query: 当前解码器状态 [batch_size, 1, query_dim]
key: 编码器键向量 [batch_size, seq_len, key_dim]
value: 编码器值向量 [batch_size, seq_len, value_dim]
返回:
output: 最终输出向量 [batch_size, 1, output_dim]
attention_dist: 注意力分布 [batch_size, 1, seq_len]
"""
# 1. 扩展query以便与key拼接
batch_size = query.size(0)
query_expanded = query.expand(-1, value.size(1), -1) # [batch, seq_len, query_dim]
# 2. 拼接Q和K并计算注意力分数
qk_concat = torch.cat([query_expanded, key], dim=2) # [batch, seq_len, query_dim+key_dim]
attention_scores = self.attention_weights(qk_concat) # [batch, seq_len, seq_len]
# 3. 提取对角元素作为位置注意力分数
attention_scores = attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1) # [batch, 1, seq_len]
# 4. 应用softmax获得注意力分布
attention_dist = F.softmax(attention_scores, dim=-1) # [batch, 1, seq_len]
# 5. 计算上下文向量(注意力加权的值)
context_vector = torch.bmm(attention_dist, value) # [batch, 1, value_dim]
# 6. 合并上下文向量和查询向量并投影到输出维度
combined = torch.cat([query, context_vector], dim=2) # [batch, 1, query_dim+value_dim]
output = self.output_projection(combined) # [batch, 1, output_dim]
return output, attention_dist
if __name__ == '__main__':
# 设置随机种子以便结果可复现
torch.manual_seed(44)
# 模拟"欢迎来北京"翻译为"Welcome to Beijing"的场景
source_seq_len = 3 # 源语言序列长度
hidden_dim = 3 # 隐藏层维度
# 随机生成编码器输出(值向量V)
encoder_outputs = torch.randn(1, source_seq_len, hidden_dim)
# 随机生成解码器当前状态(查询向量Q)
decoder_state = torch.randn(1, 1, hidden_dim)
# 随机生成编码器键向量(K)
encoder_keys = torch.randn(1, source_seq_len, hidden_dim)
# 创建注意力模型
attention = TranslationAttention(
query_dim=hidden_dim,
key_dim=hidden_dim,
value_seq_len=source_seq_len,
value_dim=hidden_dim,
output_dim=hidden_dim
)
# 计算注意力
output, attn_weights = attention(decoder_state, encoder_keys, encoder_outputs)
# 打印结果
print("注意力权重分布(表示解码器在预测'to'时对源语言各词的关注程度):")
for i, weight in enumerate(attn_weights[0, 0]):
print(f"词 {i+1}: {weight.item():.4f}")
print("\n上下文向量形状:", output.shape)
print("注意力分布形状:", attn_weights.shape)
综上,代码通过类结构定义和前向传播逻辑,完整实现了图中注意力机制的核心流程,从 Q
、K
、V
交互,到注意力分数计算、分布归一化,再到上下文向量生成与输出适配,与图中展示的注意力过程一一对应,为翻译任务里解码器利用源语言语义提供了机制支持。