torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> Tensor
Parameters:
- query (Tensor) – Query tensor; shape (batch_size, ..., head_size, token_size, embedding_size)
- key (Tensor) – Key tensor; shape (batch_size, ..., head_size, token_size, embedding_size)
- value (Tensor) – Value tensor; shape (batch_size, ..., head_size, token_size, embedding_size)
- attn_mask (optional Tensor) – Attention mask; a minimal usage sketch follows this list
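
For orientation, here is a minimal usage sketch of this API. The tensor sizes (batch, heads, tokens, head dimension) are illustrative values chosen for this sketch, not anything prescribed above.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumed): batch=2, heads=8, tokens=16, head_dim=64
batch_size, head_size, token_size, embedding_size = 2, 8, 16, 64

query = torch.randn(batch_size, head_size, token_size, embedding_size)
key = torch.randn(batch_size, head_size, token_size, embedding_size)
value = torch.randn(batch_size, head_size, token_size, embedding_size)

# Causal self-attention; the output shape matches the query shape
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```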
Analysis of the implementation logic of scaled_dot_product_attention
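
As a hedged sketch of the logic this section refers to: semantically, the function computes softmax(QK^T * scale + mask) V, where scale defaults to 1/sqrt(embedding_size). The pure-PyTorch sketch below illustrates that computation under the standard formulation; it ignores dropout, grouped-query attention, and the fused/optimized backends, and is not the actual kernel implementation.

```python
import math
import torch

def sdpa_sketch(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # Sketch of the math behind F.scaled_dot_product_attention
    # (no dropout, no enable_gqa, no fused kernels).
    L, S = query.size(-2), key.size(-2)
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))

    # Additive bias that encodes the causal and/or user-provided mask.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        causal = torch.tril(torch.ones(L, S, dtype=torch.bool, device=query.device))
        attn_bias = attn_bias.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask

    # Scaled dot-product scores, softmax over the key dimension, then weight the values.
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale + attn_bias, dim=-1)
    return attn_weight @ value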
