1. 类定义与作用
• 功能:实现一个多层的跨注意力编码器,用于处理两个序列之间的交互(如文本-文本、文本-图像对齐),并返回每一层的隐藏状态和注意力图。
• 应用场景:问答系统、文本匹配、跨模态检索等需要双序列交互的任务。
• 核心设计:堆叠多个独立的跨注意力层(CrossAttentionLayer
),逐层融合两个序列的信息。
2. 初始化方法 (__init__
)
def __init__(self, config, layer_num):
super(AttnMap, self).__init__()
layer = CrossAttentionLayer(config) # 创建基础跨注意力层
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(layer_num)]) # 深拷贝多层
关键组件
-
基础层创建:
•CrossAttentionLayer(config)
:定义单个跨注意力层,其内部应包含:
◦ 交叉注意力机制:计算序列1(s1
)对序列2(s2
)的注意力。
◦ 残差连接与层归一化:稳定训练过程。
◦ 前馈网络:增强非线性表达能力。 -
多层堆叠:
• 使用copy.deepcopy
复制基础层,确保每层参数独立(避免共享权重)。
• 通过nn.ModuleList
管理多个层,使PyTorch能正确追踪梯度。
参数说明
• config
:模型配置(隐藏层维度、注意力头数等)。
• layer_num
:跨注意力层的堆叠数量。
3. 前向传播 (forward
方法)
def forward(self, s1_hidden_states, s2_hidden_states, s2_attention_mask, output_all_encoded_layers=True):
all_encoder_layers = [] # 存储所有层的s1隐藏状态
all_attn_maps = [] # 存储所有层的注意力图
for layer_module in self.layer:
s1_hidden_states, attn_map = layer_module(s1_hidden_states, s2_hidden_states, s2_attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(s1_hidden_states)
all_attn_maps.append(attn_map)
if not output_all_encoded_layers:
all_encoder_layers.append(s1_hidden_states)
all_attn_maps.append(attn_map)
return all_encoder_layers, all_attn_maps
输入参数
• s1_hidden_states
:主序列的隐藏状态(如问题或查询),形状为 [batch_size, seq_len_s1, hidden_size]
。
• **s2_hidden_states``**:参考序列的隐藏状态(如文档或候选内容),形状为
[batch_size, seq_len_s2, hidden_size]`。
• s2_attention_mask
:参考序列的注意力掩码,标识有效内容(如1
为有效,0
为填充),形状为 [batch_size, 1, 1, seq_len_s2]
。
• output_all_encoded_layers
:是否输出所有层的中间结果(默认True)。
处理流程
-
逐层处理:
• 每层接受当前的s1_hidden_states
,计算其对s2_hidden_states
的交叉注意力。
• 更新后的s1_hidden_states
传递到下一层,逐步融合s2
的信息。
• 每层生成注意力图attn_map
,形状为[batch_size, num_heads, seq_len_s1, seq_len_s2]
。 -
结果收集:
•output_all_encoded_layers=True
:保存每一层的输出和注意力图。
•output_all_encoded_layers=False
:仅保存最后一层的结果。
输出
• all_encoder_layers
:各层更新后的 s1_hidden_states
列表。
• all_attn_maps
:各层的交叉注意力图列表。
4. 关键设计细节
组件 | 作用 |
---|---|
深拷贝跨注意力层 | 确保各层参数独立,避免梯度共享导致模型退化。 |
逐层更新s1_hidden_states | 通过多层堆叠逐步细化主序列的表示,增强跨序列交互能力。 |
注意力图保存 | 支持可视化分析(如对齐热力图)或后续任务(如基于注意力的特征选择)。 |
5. 与标准BERT编码器的差异
特性 | BertCrossEncoder_AttnMap | 标准BERT编码器 (BertEncoder ) |
---|---|---|
注意力类型 | 跨序列注意力(s1 对s2 的交叉注意力) | 自注意力(序列内部注意力) |
输入 | 两个序列的隐藏状态 | 单个序列的隐藏状态 |
输出 | 主序列的更新状态 + 交叉注意力图 | 同一序列的更新状态 |
6. 示例应用流程
假设任务为问答匹配(Question-Answer Matching):
- 输入:
•s1_hidden_states
:问题的BERT编码结果,形状[2, 32, 768]
。
•s2_hidden_states
:候选答案的BERT编码结果,形状[2, 128, 768]
。
•s2_attention_mask
:屏蔽答案中的填充部分。 - 处理:
• 经过3层跨注意力编码,每层计算问题对答案的注意力,更新问题的表示。 - 输出:
•all_encoder_layers
:3层更新后的问题表示,形状均为[2, 32, 768]
。
•all_attn_maps
:3层的注意力图,形状为[2, 12, 32, 128]
(假设12个注意力头)。