11. Computing the attention_mask in LlamaModel's forward pass
A causal LM's attention mask blocks the strict upper triangle of the attention matrix, so each token can attend only to itself and the tokens before it. The relevant code in transformers:
class LlamaModel(LlamaPreTrainedModel):
    ...
    def forward(...):
        ...
        # Expand the user-supplied 2D attention_mask into a 4D additive mask
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        ...
        # The same mask is passed to every decoder layer
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            position_ids=position_ids,
            ...
        )
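Most of the work happens inside _update_causal_mask. As a rough illustration, here is a minimal sketch, not the actual transformers implementation, of how a 2D padding mask can be expanded into the 4D additive mask shown below; make_causal_mask and its signature are made up for this example:

import torch

def make_causal_mask(attention_mask: torch.Tensor, dtype=torch.float32):
    # attention_mask: (batch, seq_len), True/1 for real tokens, False/0 for padding
    batch, seq_len = attention_mask.shape
    min_value = torch.finfo(dtype).min  # -3.4028e+38 for float32
    # Causal part: block the strict upper triangle (future positions)
    mask = torch.full((seq_len, seq_len), min_value, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    mask = mask[None, None, :, :].expand(batch, 1, seq_len, seq_len).clone()
    # Padding part: block every column that corresponds to a padding token
    padding = ~attention_mask.bool()
    return mask.masked_fill(padding[:, None, None, :], min_value)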
For example:
input_ids:
tensor([[2, 2, 2, 3, 3, 0],
        [3, 3, 3, 0, 0, 0]])
attention_mask:
tensor([[ True,  True,  True,  True,  True, False],
        [ True,  True,  True, False, False, False]])
The computed causal_mask, of shape (batch, 1, seq_len, seq_len) so it broadcasts over attention heads (-3.4028e+38 is torch.finfo(torch.float32).min):
tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38]]],
[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38]]]])
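Running the make_causal_mask sketch above on this attention_mask should reproduce this tensor: the causal part blocks the strict upper triangle in both batch elements, and the padding part additionally blocks column 6 of the first sequence and columns 4-6 of the second in every row.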
This causal_mask is then passed all the way down to the attention module:
class LlamaAttention(nn.Module):
    def forward(...):
        ...
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            # Additive mask: blocked positions get a huge negative score
            attn_weights = attn_weights + causal_mask
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
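Since exp of the filler value underflows to zero, the softmax assigns zero probability to every blocked position. A small runnable demo of this effect, reusing the hypothetical make_causal_mask sketch and random tensors in place of real model projections:

import math
import torch
import torch.nn as nn

batch, heads, seq_len, head_dim = 2, 1, 6, 4
torch.manual_seed(0)
query_states = torch.randn(batch, heads, seq_len, head_dim)
key_states = torch.randn(batch, heads, seq_len, head_dim)

attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0, 0]])
causal_mask = make_causal_mask(attention_mask)  # sketch from earlier

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask  # blocked scores become hugely negative
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)

print(attn_weights[0, 0, 0])  # token 0 attends only to itself: [1., 0., 0., 0., 0., 0.]
print(attn_weights[1, 0, 5])  # padded columns 3-5 of the second sequence get weight 0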
