flash_attn安装

我们在安装flash_attn库的时候,如果直接是使用pip install,老是容易卡住,等多久都没有成功。

这里给出两种安装方式

第一种安装方式:

访问该网站:

https://github.com/Dao-AILab/flash-attention/releases/

找到对应torch、python、cuda版本的flash_attn进行下载,并上传到服务器

#例如python3.10 torch2.4 cuda12
pip install flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

可以在环境中运行如下指令判断是选择TRUE还是FALSE。

python  -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)"

第二种安装方式:

同样是要访问网站:

https://github.com/Dao-AILab/flash-attention/releases/

查看要安装的版本:一般是选择最新的版本:比如是2.7.4.post1。将下面脚本中的版本号修改一下:

flash_attn_version="2.7.4.post1"

import platform
import sys

import torch


def get_cuda_version():
    if torch.cuda.is_available():
        cuda_version = torch.version.cuda
        return f"cu{cuda_version.replace('.', '')[:2]}"  # 例如:cu121
    return "cpu"


def get_torch_version():
    return f"torch{torch.__version__.split('+')[0]}"[:-2]  # 例如:torch2.2


def get_python_version():
    version = sys.version_info
    return f"cp{version.major}{version.minor}"  # 例如:cp310


def get_abi_flag():
    return "abiTRUE" if torch._C._GLIBCXX_USE_CXX11_ABI else "abiFALSE"


def get_platform():
    system = platform.system().lower()
    machine = platform.machine().lower()
    if system == "linux" and machine == "x86_64":
        return "linux_x86_64"
    elif system == "windows" and machine == "amd64":
        return "win_amd64"
    elif system == "darwin" and machine == "x86_64":
        return "macosx_x86_64"
    else:
        raise ValueError(f"Unsupported platform: {system}_{machine}")


def generate_flash_attn_filename(flash_attn_version="2.7.4.post1"):
    cuda_version = get_cuda_version()
    torch_version = get_torch_version()
    python_version = get_python_version()
    abi_flag = get_abi_flag()
    platform_tag = get_platform()

    filename = (
        f"flash_attn-{flash_attn_version}+{cuda_version}{torch_version}cxx11{abi_flag}-"
        f"{python_version}-{python_version}-{platform_tag}.whl"
    )
    return filename


if __name__ == "__main__":
    try:
        filename = generate_flash_attn_filename()
        print("Generated filename:", filename)
    except Exception as e:
        print("Error generating filename:", e)

然后在环境中运行这个脚本会自动得到你的环境对应的版本,例如:

flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

然后访问网站下载对应的版本名字即可。

上传服务器,安装方式同一。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

花开明山

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值