RNN中的注意力机制代码实现


一、整体功能

本文介绍了一个基于 RNN 的注意力机制实现,具体应用于机器翻译任务中。该机制通过动态聚焦源语言(如“欢迎、来、北京”)的不同词,辅助解码器(如预测“to”)生成更准确的目标语言词汇。

RNN中的注意力机制详解网页链接

核心流程:

  1. 编码器(Encoder):将源语言句子转换为一系列隐状态(h1, h2, ..., hn)。
  2. 解码器(Decoder):在生成目标语言词汇时,利用注意力机制关注源语言中的关键信息。
  3. 注意力机制:计算解码器当前状态与编码器各隐状态的相关性(注意力分数),并加权得到上下文向量,最终融合解码器状态和上下文向量进行输出。

示意图解析:

注意力机制示意图

  • 左侧(蓝色框):编码器处理源语言(“欢迎、来、北京”)的过程。
  • 中间(绿色框):注意力机制的具体计算步骤。
  • 右侧(粉色框):解码器根据注意力结果生成目标语言(“Welcome to Beijing”)的过程。

二、类初始化

def __init__(self, query_dim, key_dim, value_seq_len, value_dim, output_dim):
    super().__init__()
    # 注意力权重计算层 - 对应图中 “神经网络线性变换(Q拼K后)” 等操作,计算注意力分数
    self.attention_weights = nn.Linear(query_dim + key_dim, value_seq_len)
    # 输出映射层 - 对应图中 后续“神经网络线性变换(Q拼接C后)”,将融合后的特征映射到指定输出维度
    self.output_projection = nn.Linear(query_dim + value_dim, output_dim)
  • self.attention_weights:对应图里对拼接后的 Q(解码器上一步输出语义,如 Welcome 的词向量 )和 K(编码器隐层输出,如 h1 等 )进行线性变换的操作,用于计算注意力分数,决定对源语言各词的关注程度。
  • self.output_projection:对应图中注意力输出与 Q 再次拼接后,进行线性变换以适配解码器输入的环节,整合信息并转换维度。

三、前向传播流程(forward 函数)

3.1 扩展 query 以便拼接(对应图中 QK 交互准备)

# 1. 扩展query以便与key拼接
batch_size = query.size(0)
query_expanded = query.expand(-1, value.size(1), -1)  # [batch, seq_len, query_dim]
  • 作用:query 初始是解码器当前时间步状态(如形状 [1, 1, 3] ),要和长度为 source_seq_len(源语言词数量,如 3 )的 key 拼接,需扩展 query 维度,让其在序列长度维度(第 2 维 )与 key 一致,为后续拼接 QK 做准备,对应图中 Q 参与和 K 交互的前置处理。

3.2 拼接 QK 并计算注意力分数(对应图中 QK 与线性变换)

# 2. 拼接Q和K并计算注意力分数
qk_concat = torch.cat([query_expanded, key], dim=2)  # [batch, seq_len, query_dim+key_dim]
attention_scores = self.attention_weights(qk_concat)  # [batch, seq_len, seq_len]
  • 操作解析:
    • torch.cat([query_expanded, key], dim=2):把扩展后的 Q(解码器状态,代表目标语言已生成部分语义 )和 K(编码器隐层输出,代表源语言各词语义 )在特征维度(第 3 维 )拼接,对应图里 QK 的拼接操作,融合二者初始特征。
    • self.attention_weights(qk_concat):用线性层对拼接后的特征进行变换,输出注意力分数。此时输出形状 [batch, seq_len, seq_len] ,可理解为计算出 QK 各位置间的关联分数,对应图中 “神经网络线性变换” 环节,初步得到注意力相关数值。

3.3 提取对角元素并获注意力分布(对应图中 Softmax 层等)

# 3. 提取对角元素作为位置注意力分数
attention_scores = attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1)  # [batch, 1, seq_len]
# 4. 应用softmax获得注意力分布
attention_dist = F.softmax(attention_scores, dim=-1)  # [batch, 1, seq_len]
  • 操作解析:
    • attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1):从 attention_scores 中提取对角元素,是因为我们关注 QK 对应位置(如解码器当前状态与源语言各词位置匹配 )的关联,提取后形状变为 [batch, 1, seq_len] ,对应聚焦源语言每个词与当前解码器状态的关联。
    • F.softmax(attention_scores, dim=-1):对提取后的分数做 Softmax 归一化,得到注意力分布 attention_dist ,值在 0 - 1 之间且和为 1 ,对应图中 “Softmax 层” ,表示解码器预测当前词时,对源语言各词的关注权重,比如运行结果里对 “欢迎、来、北京” 的权重分布 [0.3784, 0.0705, 0.5511]

3.4 计算上下文向量(对应图中注意力输出与 V 交互)

# 5. 计算上下文向量(注意力加权的值)
context_vector = torch.bmm(attention_dist, value)  # [batch, 1, value_dim]
  • 作用:value 是编码器最后一层输出(如形状 [1, 3, 3] ,代表 “欢迎、来、北京” 各自的编码语义 ),用注意力分布 attention_distvalue 加权求和(bmm 是批量矩阵乘法 ),得到 context_vector ,融合源语言各词语义且突出注意力权重高的词,对应图中注意力输出与 V(编码器隐层输出 )的交互,提取对生成当前目标词重要的源语言语义。

3.5 合并与投影(对应图中 Q 拼接 C 及后续线性变换)

# 6. 合并上下文向量和查询向量并投影到输出维度
combined = torch.cat([query, context_vector], dim=2)  # [batch, 1, query_dim+value_dim]
output = self.output_projection(combined)  # [batch, 1, output_dim]
  • 操作解析:
    • torch.cat([query, context_vector], dim=2):把解码器原始状态 query(代表目标语言已生成部分的语义 )和上下文向量 context_vector(融合源语言关键语义 )在特征维度拼接,对应图中 Q 与注意力输出(如 C )的拼接,整合目标语言已有信息和源语言关键信息。
    • self.output_projection(combined):用线性层对拼接后的特征变换,输出 output ,将维度转换为 output_dim(如示例中为 3 ),对应图中最后 “神经网络线性变换” ,得到可输入解码器 GRU 层的向量,辅助预测下一个目标词(如 to )。

四、主函数

if __name__ == '__main__':
    # 设置随机种子以便结果可复现
    torch.manual_seed(44)

    # 模拟"欢迎来北京"翻译为"Welcome to Beijing"的场景
    source_seq_len = 3  # 源语言序列长度
    hidden_dim = 3      # 隐藏层维度

    # 随机生成编码器输出(值向量V)
    encoder_outputs = torch.randn(1, source_seq_len, hidden_dim)

    # 随机生成解码器当前状态(查询向量Q)
    decoder_state = torch.randn(1, 1, hidden_dim)

    # 随机生成编码器键向量(K)
    encoder_keys = torch.randn(1, source_seq_len, hidden_dim)

    # 创建注意力模型
    attention = TranslationAttention(
        query_dim=hidden_dim,
        key_dim=hidden_dim,
        value_seq_len=source_seq_len,
        value_dim=hidden_dim,
        output_dim=hidden_dim
    )

    # 计算注意力
    output, attn_weights = attention(decoder_state, encoder_keys, encoder_outputs)

    # 打印结果
    print("注意力权重分布(表示解码器在预测'to'时对源语言各词的关注程度):")
    for i, weight in enumerate(attn_weights[0, 0]):
        print(f"词 {i+1}: {weight.item():.4f}")
    
    print("\n上下文向量形状:", output.shape)
    print("注意力分布形状:", attn_weights.shape)
  • 这里通过随机生成数据模拟编码器、解码器的中间状态(如 encoder_outputs 对应图中编码器最后一层输出 Vdecoder_state 对应 Q 等 ),调用模型后得到的 attn_weights 就是图里注意力机制计算出的对源语言各词的关注权重,context 是整合后的上下文向量,完整模拟了图中从编码器输出、解码器状态输入,到注意力计算、结果输出的流程。

五、完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F


class TranslationAttention(nn.Module):
    """
    翻译任务中的注意力机制,用于捕捉源语言和目标语言之间的语义关联
    """

    def __init__(self, query_dim, key_dim, value_seq_len, value_dim, output_dim):
        """
        初始化注意力机制模型。

        参数:
            query_dim (int): 查询向量(Q)的维度,通常为解码器隐藏状态的维度。
            key_dim (int): 键向量(K)的维度,通常为编码器隐藏状态的维度。
            value_seq_len (int): 值序列(V)的长度,即源语言句子的长度。
            value_dim (int): 值向量(V)的维度,通常与键向量维度相同。
            output_dim (int): 输出向量的维度,通常与解码器输入维度相同。
        """
        super().__init__()
        self.attention_weights = nn.Linear(query_dim + key_dim, value_seq_len)
        self.output_projection = nn.Linear(query_dim + value_dim, output_dim)

    def forward(self, query, key, value):
        """
        前向传播过程。

        参数:
            query: 当前解码器状态 [batch_size, 1, query_dim]
            key: 编码器键向量 [batch_size, seq_len, key_dim]
            value: 编码器值向量 [batch_size, seq_len, value_dim]

        返回:
            output: 最终输出向量 [batch_size, 1, output_dim]
            attention_dist: 注意力分布 [batch_size, 1, seq_len]
        """
        # 1. 扩展query以便与key拼接
        batch_size = query.size(0)
        query_expanded = query.expand(-1, value.size(1), -1)  # [batch, seq_len, query_dim]

        # 2. 拼接Q和K并计算注意力分数
        qk_concat = torch.cat([query_expanded, key], dim=2)  # [batch, seq_len, query_dim+key_dim]
        attention_scores = self.attention_weights(qk_concat)  # [batch, seq_len, seq_len]

        # 3. 提取对角元素作为位置注意力分数
        attention_scores = attention_scores.diagonal(dim1=1, dim2=2).unsqueeze(1)  # [batch, 1, seq_len]

        # 4. 应用softmax获得注意力分布
        attention_dist = F.softmax(attention_scores, dim=-1)  # [batch, 1, seq_len]

        # 5. 计算上下文向量(注意力加权的值)
        context_vector = torch.bmm(attention_dist, value)  # [batch, 1, value_dim]

        # 6. 合并上下文向量和查询向量并投影到输出维度
        combined = torch.cat([query, context_vector], dim=2)  # [batch, 1, query_dim+value_dim]
        output = self.output_projection(combined)  # [batch, 1, output_dim]

        return output, attention_dist


if __name__ == '__main__':
    # 设置随机种子以便结果可复现
    torch.manual_seed(44)

    # 模拟"欢迎来北京"翻译为"Welcome to Beijing"的场景
    source_seq_len = 3  # 源语言序列长度
    hidden_dim = 3      # 隐藏层维度

    # 随机生成编码器输出(值向量V)
    encoder_outputs = torch.randn(1, source_seq_len, hidden_dim)

    # 随机生成解码器当前状态(查询向量Q)
    decoder_state = torch.randn(1, 1, hidden_dim)

    # 随机生成编码器键向量(K)
    encoder_keys = torch.randn(1, source_seq_len, hidden_dim)

    # 创建注意力模型
    attention = TranslationAttention(
        query_dim=hidden_dim,
        key_dim=hidden_dim,
        value_seq_len=source_seq_len,
        value_dim=hidden_dim,
        output_dim=hidden_dim
    )

    # 计算注意力
    output, attn_weights = attention(decoder_state, encoder_keys, encoder_outputs)

    # 打印结果
    print("注意力权重分布(表示解码器在预测'to'时对源语言各词的关注程度):")
    for i, weight in enumerate(attn_weights[0, 0]):
        print(f"词 {i+1}: {weight.item():.4f}")
    
    print("\n上下文向量形状:", output.shape)
    print("注意力分布形状:", attn_weights.shape)

综上,代码通过类结构定义和前向传播逻辑,完整实现了图中注意力机制的核心流程,从 QKV 交互,到注意力分数计算、分布归一化,再到上下文向量生成与输出适配,与图中展示的注意力过程一一对应,为翻译任务里解码器利用源语言语义提供了机制支持。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值