flash attention是一个用于加速模型训练推理的可选项,且仅适用于Turing、Ampere、Ada、Hopper架构的Nvidia GPU显卡(如H100、A100、RTX X090、T4)
1.首先检查一下GPU是否支持:FlashAttention
import torch
def supports_flash_attention(device_id: int):
"""Check if a GPU supports FlashAttention."""
major, minor = torch.cuda.get_device_capability(device_id)
# Check if the GPU

最低0.47元/天 解锁文章
2500





