A Close Reading of the fused_recurrent_rwkv6 Triton Implementation in flash-linear-attention

0x0. Preface

Picking up where the previous post, "Accelerating the Linear Attention Computation of the RWKV6 Model on GPU", left off, this post analyzes the forward implementations of fused_recurrent_rwkv6 and chunk_rwkv6 in the flash-linear-attention library (https://github.com/sustcsonglin/flash-linear-attention), and also serves as continued practice in writing CUDA-style kernels with Triton. We start with the fused_recurrent_rwkv6 implementation here; the chunk_rwkv6 implementation will be covered in a later post when I get to it.

0x1. Naive Python implementation of fused_recurrent_rwkv6

As before, let's start from the naive Python implementation: https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/rwkv6/recurrent_naive.py . The computation performed by fused_recurrent_rwkv6 corresponds to the following basic Python procedure:

import torch

def naive_recurrent_rwkv6(
    q,
    k,
    v,
    w,
    u,
    initial_state=None,
    output_final_state=False
):
    # Record the original dtype of the input tensor q.
    orig_dtype = q.dtype
    # Cast all inputs to 32-bit floats for the recurrence.
    q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
    # Shape information of the query tensor.
    batch_size, n_heads, seq_len, d_head_k = q.shape
    # Head dimension of the value tensor.
    _, _, _, d_head_v = v.shape
    # Initialize the recurrent state to zeros, shape (B, H, K, V) with K = d_head_k
    # and V = d_head_v, allocated on the same device as q.
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    # Initialize the output tensor to zeros, same shape as the value tensor v.
    o = torch.zeros_like(v)

    # If an initial_state is provided, start the recurrent state h from it:
    if initial_state is not None:
        h += initial_state

    # Iterate over the sequence, processing one position per step:
    for i in range(seq_len):
        q_i = q[:, :, i, :] # query at position i, shape [B, H, K]
        k_i = k[:, :, i]    # key at position i, shape [B, H, K]
        v_i = v[:, :, i, :] # value at position i, shape [B, H, V]
        # Decay at position i, passed through the exponential. Shape [B, H, K]
        w_i = w[:, :, i].exp()
        # Key-value outer product at position i, done as a broadcasted elementwise multiply.
        # Shapes: [B, H, K, 1] * [B, H, 1, V] -> [B, H, K, V]
        kv_i = k_i[..., None] * v_i[..., None, :]
        # Attention-weighted output at position i; everything is a broadcasted elementwise op.
        # h has shape [B, H, K, V]
        # u[None, ..., None] has shape [1, H, K, 1]
        # q_i[..., None] has shape [B, H, K, 1]
        # h + u[None, ..., None] * kv_i has shape:
        # [B, H, K, V] + [1, H, K, 1] * [B, H, K, V] ->
        # [B, H, K, V] + [B, H, K, V] ->
        # [B, H, K, V]
        # so o_i has shape [B, H, K, V]
        o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
        # Write the output for position i by summing out the key dimension.
        # o[:, :, i] has shape [B, H, V]; o_i.sum(-2) also has shape [B, H, V]
        o[:, :, i] = o_i.sum(-2)
        # Update the recurrent state h.
        # h has shape [B, H, K, V]
        # w_i[..., None] has shape [B, H, K, 1]
        # kv_i has shape [B, H, K, V]
        # h * w_i[..., None] has shape [B, H, K, V], again a broadcasted elementwise op
        h = h * w_i[..., None] + kv_i
    return o.to(orig_dtype)
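
To restate what the loop above computes (my own summary, not text from the repo): for a single (batch, head) pair, write the recurrent state as $h_t \in \mathbb{R}^{K \times V}$ and the per-position vectors as $q_t, k_t, w_t, u \in \mathbb{R}^{K}$, $v_t \in \mathbb{R}^{V}$. Each iteration then performs

$$o_t = \left(h_{t-1} + \mathrm{diag}(u)\, k_t v_t^{\top}\right)^{\top} q_t, \qquad h_t = \mathrm{diag}\!\left(\exp(w_t)\right) h_{t-1} + k_t v_t^{\top}$$

i.e. one outer product plus elementwise scaling per step, costing O(K·V) per (batch, head) and O(L·K·V) over the whole sequence. This per-position recurrence is the computation that the fused Triton kernel implements in a single fused pass.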

The tensors q, k, v, w, u are defined as follows:

B = 4 # batch size
H = 4 # number of heads
L = 1024 # sequence length
D = 100 # dimension of each head
dtype = torch.float32 # tensor data type: 32-bit floating point
# q, k, v are the query, key, and value tensors, each of shape (B, H, L, D),
# randomly initialized and placed on the GPU.
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, D).cuda().to(dtype)
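
The snippet is cut off at this point and the definitions of w and u are missing. Based purely on the shapes naive_recurrent_rwkv6 expects (w is indexed as w[:, :, i] alongside k, so (B, H, L, D); u is indexed as u[None, ..., None], so (H, D)), a plausible completion plus a call to the reference function could look like the following. This is my own sketch for illustration, not the repo's exact test code:

# Assumed: w is the per-position log-decay, shape (B, H, L, D).
# logsigmoid keeps it negative, so w.exp() inside the recurrence lies in (0, 1).
w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(dtype).requires_grad_(True)
# Assumed: u is the per-head bonus applied to the current token, shape (H, D).
u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(True)

# Run the naive reference implementation.
o = naive_recurrent_rwkv6(q, k, v, w, u)
print(o.shape)  # torch.Size([4, 4, 1024, 100])

With this reference in hand, one can check it against the library's Triton kernel (fused_recurrent_rwkv6 from fla.ops.rwkv6) for numerical agreement; I omit that comparison here since the exact call signature and return values vary across library versions.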