【Triton 教程】低内存 Dropout

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

更多 Triton 中文文档可访问 →https://triton.hyper.ai/

在本教程中,您将编写一个内存高效的 Dropout 实现,其状态将由单个 int32 seed 组成。这与传统 Dropout 实现不同,传统实现通常由与输入 shape 相同的位掩码张量组成。

在这过程中,您将学习到以下内容:

  • PyTorch 中 原生实现 Dropout 的局限性。

  • Triton 中的并行伪随机数生成。

简介

Dropout 是在 [SRIVASTAVA2014] 中引入的一种技术,用于改善低数据条件下深度神经网络的性能,通常用于正则化。它接受一个向量作为输入,并生成相同 shape 的输出向量。输出中的每个标量都有概率 p 被设为零,否则直接从输入复制。这使得网络在仅有输入的 1−p 标量时也能表现良好。

在评估阶段,为了充分利用网络的能力,将 p 设为 0。但是简单地将 p 设为 0 会增加输出的范数,可能会人为地降低输出的 softmax temperature。为了防止这种情况发生,输出被缩放为 1/(1-p),这使得无论 dropout 概率如何都能保持一致的范数。

Baseline

首先看一下 baseline 的实现。

import tabulate
import torch


import triton
import triton.language as tl


@triton.jit
def _dropout(
    x_ptr,      # 输入指针
    x_keep_ptr, # pointer to a mask of 0s and 1s 由 01 组成的掩码的指针
    output_ptr, # pointer to the output 输出指针
    n_elements, # number of elements in the `x` tensor `x` 张量的元素数量
    p,          # probability that an element of `x` is changed to zero 元素 `x` 被设置为 0 的概率
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    # 加载数据
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    # 下一行是上段描述的关键部分
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    # 写回输出
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
# 输入张量
x = torch.randn(size=(10, )).cuda()
# Dropout mask
# Dropout 掩码
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist(),
]))

Out:
在这里插入图片描述

种子化 Dropout

上述 Dropout 实现效果良好,但管理 Dropout 状态可能会变得复杂,特别是在考虑反向传播和重新计算/检查点场景时。在这里,我们描述一种替代实现,它具有以下优点:

  1. 更小的内存占用。
  2. 较少的数据移动。
  3. 简化了在多次调用内核函数时持久化随机性的管理。

生成 Triton 中的伪随机数很简单!在本教程中,我们将使用 triton.language.rand 函数,该函数基于给定的种子和一组 int32 偏移量生成一个块的均匀分布的 float32 值,范围在 (0, 1) 内。但如果你需要,Triton 也提供其他随机数生成策略。

注意 Triton 的 PRNG 实现基于 Philox 算法(详见 [SALMON2011])。

现在将所有内容整合起来。

@triton.jit
def _seeded_dropout(
    x_ptr,
    output_ptr,
    n_elements,
    p,
    seed,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    # 计算由此实例处理的元素的内存偏移量
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    # 从 x 读取数据
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    # 随机修剪
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    # 写回
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output




x = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
# 与基线相比 - dropout 掩码从未被实例化!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)


print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (seed = 123)"] + output.tolist(),
    ["output (seed = 123)"] + output2.tolist(),
    ["output (seed = 512)"] + output3.tolist(),
]))

Out:

在这里插入图片描述

大功告成!我们现在有了一个 Triton 内核,可以在给定相同种子的情况下应用一致的 dropout 掩码。与传统的 dropout 实现相比,这种方法减少了内存开销并简化了状态管理。

练习

  1. 扩展内核以处理矩阵,并使用一个种子向量 — 每行一个种子。
  2. 添加对 striding 的支持。
  3. (挑战)实现稀疏 Johnson-Lindenstrauss 变换的内核,每次使用种子动态生成投影矩阵。

参考文献

  • [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011

  • [SRIVASTAVA2014] Nitish Srivastava et al., “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014

​Download Jupyter notebook: 04-low-memory-dropout.ipynb

Download Python source code: 04-low-memory-dropout.py

Download zipped: 04-low-memory-dropout.zip

<think>好的,用户现在想了解如何安装FlashAttention 2。首先,我需要回想一下之前用户询问的是Git的安装,现在转向深度学习相关的库安装,可能用户正在进入模型优化或需要高效注意力机制的开发工作。 用户可能的环境是什么?他们可能在使用PyTorch,因为FlashAttention 2主要支持PyTorch。需要确认他们的CUDA版本是否兼容,因为FlashAttention 2需要CUDA 11.4以上,特别是11.6和11.7更稳定。如果用户CUDA版本过低,可能需要先升级,这时候应该提醒他们检查nvidia-smi和nvcc --version的区别,避免混淆运行时和编译时的版本。 接下来,安装步骤可能需要先安装依赖项,比如ninja,用于构建。然后是PyTorch,如果用户没有安装或者版本不对,需要指导他们正确安装。特别是PyTorch 2.0以上,需要对应CUDA版本,比如通过pip命令指定cu118之类的。 然后是安装FlashAttention 2本身,使用pip install命令,但需要注意是否从源码编译,可能需要添加--no-build-isolation选项,或者遇到问题时的解决方法,比如安装pyproject.toml中的依赖。如果用户的环境不支持高版本的PyTorch,可能需要源码编译,这时候要详细说明克隆仓库、安装依赖、编译安装的步骤。 验证安装部分,需要提供一个简单的测试代码,让用户确认是否安装成功。如果出现错误,比如SM版本不支持,可能需要检查GPU架构,这时候需要解释如何查找GPU的Compute Capability,以及可能需要的环境变量设置,如FLASH_ATTENTION_INTERNAL_FAST_PATH=1。 常见问题方面,用户可能会遇到版本不兼容,比如PyTorch或CUDA版本过低,这时候需要明确最低要求。编译错误的话,可能需要检查gcc版本,安装正确的CUDA工具包,或者权限问题。性能没提升可能需要检查是否启用了正确的前缀,或者输入是否符合要求,比如对齐到特定长度。 此外,用户可能对内核安装失败不太清楚,可能需要更详细的解释,比如为什么需要安装fused_attention等内核,以及如何通过环境变量禁用某些功能作为临时解决方案。 最后,需要提醒用户,安装过程可能因环境而异,建议参考官方文档,特别是如果遇到问题的话。保持步骤清晰,分点说明,确保用户能一步步跟随,并处理可能出现的错误情况。</think>以下是 **FlashAttention 2** 的详细安装教程,分步骤说明环境要求、安装方法和常见问题解决: --- ### **1. 环境要求** - **Python**: ≥ 3.8 - **PyTorch**: ≥ 2.0.0(需与CUDA版本匹配) - **CUDA**: ≥ 11.4(推荐 11.6/11.7/11.8) - **GPU架构**: SM70+(例如 V100、A100、RTX 30/40系列等) - **操作系统**: Linux 或 Windows(Linux兼容性更好) --- ### **2. 安装步骤** #### **步骤 1:检查CUDA和PyTorch版本** 1. **查看CUDA版本**: ```bash nvcc --version # 查看CUDA编译器版本(编译时依赖) nvidia-smi # 查看驱动支持的CUDA版本(运行时依赖) ``` 2. **确认PyTorch版本**: - 若未安装PyTorch,根据CUDA版本安装: ```bash # 示例:CUDA 11.8 pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --index-url https://download.pytorch.org/whl/cu118 ``` - 若已安装,检查版本: ```bash python -c "import torch; print(torch.__version__)" ``` --- #### **步骤 2:安装依赖项** ```bash # 安装必要工具(Linux需安装ninja) pip install ninja packaging ``` --- #### **步骤 3:安装FlashAttention 2** **方法 1:通过PyPI直接安装** ```bash pip install flash-attn==2.5.6 --no-build-isolation ``` - 若报错,尝试先安装 `flash-attn` 的依赖: ```bash pip install -U flash-attn --no-cache-dir ``` **方法 2:从源码编译安装(推荐)** 1. 克隆仓库: ```bash git clone https://github.com/Dao-AILab/flash-attention cd flash-attention ``` 2. 安装依赖: ```bash pip install -e . # 安装主库 cd csrc/rotary && pip install -e . # 安装旋转位置编码内核 ``` --- #### **步骤 4:验证安装** 运行以下Python代码测试是否成功: ```python import torch from flash_attn import flash_attn_qkvpacked_func # 生成随机输入(形状:[batch_size, seq_len, 3, num_heads, head_dim]) qkv = torch.randn(2, 1024, 3, 12, 64, device="cuda", dtype=torch.float16) output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True) print(output.shape) # 应输出 torch.Size([2, 1024, 12, 64]) ``` --- ### **3. 常见问题解决** #### **问题 1:CUDA版本不兼容** - **现象**:`error: identifier "xxx" is undefined` 或 `SMXX arch not supported` - **解决**: 1. 检查GPU架构是否支持(如RTX 3090为SM86,需FlashAttention≥2.3.1)。 2. 设置环境变量强制启用兼容模式: ```bash export FLASH_ATTENTION_INTERNAL_FAST_PATH=1 ``` #### **问题 2:内核编译失败** - **现象**:`RuntimeError: Failed to compile CUDA kernels` - **解决**: 1. 升级 `nvcc` 和 `gcc`(需≥gcc 9.3)。 2. 确保已安装完整CUDA工具包: ```bash sudo apt install cuda-nvcc-11-8 # 根据CUDA版本调整 ``` #### **问题 3:性能未提升** - **现象**:速度与普通Attention无差异 - **解决**: 1. 确保输入序列长度是 **128的倍数**(FlashAttention 2对长序列优化更明显)。 2. 使用半精度(`torch.float16` 或 `bfloat16`)。 --- ### **4. 高级配置** - **禁用某些内核**:若特定内核导致崩溃,可通过环境变量禁用: ```bash export FLASH_ATTENTION_DISABLE_TRITON=1 # 禁用Triton内核 ``` - **启用内存高效模式**:减少显存占用(速度略有下降): ```python from flash_attn import flash_attn_func output = flash_attn_func(q, k, v, causal=True, softmax_scale=None) ``` --- ### **5. 官方参考** - GitHub仓库:$[https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)$ - 文档:$[https://github.com/Dao-AILab/flash-attention/tree/main#installation-and-features](https://github.com/Dao-AILab/flash-attention/tree/main#installation-and-features)$ --- 按照以上步骤操作后,即可在项目中调用 `flash_attn` 加速Transformer模型训练/推理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值