核心功能是通过计算输入序列中不同位置之间的注意力权重,生成融合全局信息的上下文表示:
输入参数
def forward(self, hidden_states, attention_mask):
# hidden_states: [batch_size, seq_len, hidden_size]
# attention_mask: [batch_size, 1, 1, seq_len](值为0或-10000)
• hidden_states
:输入序列的隐藏状态,形状为 [batch_size, seq_len, hidden_size]
。
• attention_mask
:注意力掩码,标识填充位置(-10000
表示掩码,0
表示有效位置)。
解析
1. 生成查询(Q)、键(K)、值(V)
mixed_query_layer = self.query(hidden_states) # [N, L, hidden_size] → [N, L, all_head_size]
mixed_key_layer = self.key(hidden_states) # 同上
mixed_value_layer = self.value(hidden_states) # 同上
• 作用:通过线性变换将输入映射到查询、键、值空间。
• 参数:
• self.query
, self.key
, self.value
:线性层(nn.Linear
),输出维度为 all_head_size = num_heads * head_size
。
2. 调整维度以支持多头计算
query_layer = self.transpose_for_scores(mixed_query_layer) # [N, L, all_head_size] → [N, num_heads, L, head_size]
key_layer = self.transpose_for_scores(mixed_key_layer) # 同上
value_layer = self.transpose_for_scores(mixed_value_layer) # 同上
• transpose_for_scores
方法:
• 输入:[batch_size, seq_len, all_head_size]
。
• 操作:将 all_head_size
拆分为 [num_heads, head_size]
,并转置维度为 [batch_size, num_heads, seq_len, head_size]
。
• 目的:分割为多个注意力头,便于并行计算。
3. 计算注意力分数
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [N, num_heads, L, L]
attention_scores = attention_scores / math.sqrt(self.attention_head_size) # 缩放
• 点积计算:计算每个位置查询与键的相似度。
• 缩放:防止点积结果过大导致梯度不稳定。
4. 应用注意力掩码
attention_scores = attention_scores + attention_mask # 掩码填充位置(-10000 → Softmax后趋近0)
• 掩码作用:将填充位置的注意力分数设为极小值,使 Softmax 后权重趋近于零。
5. 归一化与 Dropout
attention_probs = nn.Softmax(dim=-1)(attention_scores) # 沿最后一