AttnMap类

1. 类定义与作用

功能:实现一个多层的跨注意力编码器,用于处理两个序列之间的交互(如文本-文本、文本-图像对齐),并返回每一层的隐藏状态和注意力图。
应用场景:问答系统、文本匹配、跨模态检索等需要双序列交互的任务。
核心设计:堆叠多个独立的跨注意力层(CrossAttentionLayer),逐层融合两个序列的信息。


2. 初始化方法 (__init__)

def __init__(self, config, layer_num):
    super(AttnMap, self).__init__()
    layer = CrossAttentionLayer(config)  # 创建基础跨注意力层
    self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(layer_num)])  # 深拷贝多层
关键组件
  1. 基础层创建
    CrossAttentionLayer(config):定义单个跨注意力层,其内部应包含:
    交叉注意力机制:计算序列1(s1)对序列2(s2)的注意力。
    残差连接与层归一化:稳定训练过程。
    前馈网络:增强非线性表达能力。

  2. 多层堆叠
    • 使用 copy.deepcopy 复制基础层,确保每层参数独立(避免共享权重)。
    • 通过 nn.ModuleList 管理多个层,使PyTorch能正确追踪梯度。

参数说明

config:模型配置(隐藏层维度、注意力头数等)。
layer_num:跨注意力层的堆叠数量。


3. 前向传播 (forward方法)

def forward(self, s1_hidden_states, s2_hidden_states, s2_attention_mask, output_all_encoded_layers=True):
    all_encoder_layers = []  # 存储所有层的s1隐藏状态
    all_attn_maps = []       # 存储所有层的注意力图
    for layer_module in self.layer:
        s1_hidden_states, attn_map = layer_module(s1_hidden_states, s2_hidden_states, s2_attention_mask)
        if output_all_encoded_layers:
            all_encoder_layers.append(s1_hidden_states)
            all_attn_maps.append(attn_map)
    if not output_all_encoded_layers:
        all_encoder_layers.append(s1_hidden_states)
        all_attn_maps.append(attn_map)
    return all_encoder_layers, all_attn_maps
输入参数

s1_hidden_states:主序列的隐藏状态(如问题或查询),形状为 [batch_size, seq_len_s1, hidden_size]
• **s2_hidden_states``**:参考序列的隐藏状态(如文档或候选内容),形状为 [batch_size, seq_len_s2, hidden_size]`。

s2_attention_mask:参考序列的注意力掩码,标识有效内容(如1为有效,0为填充),形状为 [batch_size, 1, 1, seq_len_s2]
output_all_encoded_layers:是否输出所有层的中间结果(默认True)。

处理流程
  1. 逐层处理
    • 每层接受当前的 s1_hidden_states,计算其对 s2_hidden_states 的交叉注意力。
    • 更新后的 s1_hidden_states 传递到下一层,逐步融合 s2 的信息。
    • 每层生成注意力图 attn_map,形状为 [batch_size, num_heads, seq_len_s1, seq_len_s2]

  2. 结果收集
    output_all_encoded_layers=True:保存每一层的输出和注意力图。
    output_all_encoded_layers=False:仅保存最后一层的结果。

输出

all_encoder_layers:各层更新后的 s1_hidden_states 列表。
all_attn_maps:各层的交叉注意力图列表。


4. 关键设计细节

组件作用
深拷贝跨注意力层确保各层参数独立,避免梯度共享导致模型退化。
逐层更新s1_hidden_states通过多层堆叠逐步细化主序列的表示,增强跨序列交互能力。
注意力图保存支持可视化分析(如对齐热力图)或后续任务(如基于注意力的特征选择)。

5. 与标准BERT编码器的差异

特性BertCrossEncoder_AttnMap标准BERT编码器 (BertEncoder)
注意力类型跨序列注意力(s1s2的交叉注意力)自注意力(序列内部注意力)
输入两个序列的隐藏状态单个序列的隐藏状态
输出主序列的更新状态 + 交叉注意力图同一序列的更新状态

6. 示例应用流程

假设任务为问答匹配(Question-Answer Matching):

  1. 输入
    s1_hidden_states:问题的BERT编码结果,形状 [2, 32, 768]
    s2_hidden_states:候选答案的BERT编码结果,形状 [2, 128, 768]
    s2_attention_mask:屏蔽答案中的填充部分。
  2. 处理
    • 经过3层跨注意力编码,每层计算问题对答案的注意力,更新问题的表示。
  3. 输出
    all_encoder_layers:3层更新后的问题表示,形状均为 [2, 32, 768]
    all_attn_maps:3层的注意力图,形状为 [2, 12, 32, 128](假设12个注意力头)。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值