文章目录
1. 前言
前一篇我们讲解了Flash Attention的原理,后续会陆陆续续推出Flash Attention的源码学习笔记,本篇主要讲解如何在本地Flash Attention的安装使用与接口参数说明
2. 编译与安装
检查CUDA版本,CUDA版本需要在在11.6及以上。如果版本较低,请自行从cuda-toolkit下载安装
nvidia-smi
安装pytorch依赖,注意需要选择带有CUDA版本的torch
pip3 install torch torchvision torchaudio
编译安装flash attention:
git clone --recursive https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
python setup.py install
3. Flash Attention 仓目录结构
├── assets # 资源目录,存放一些图片文件
├── benchmarks # 性能测试
├── csrc # 基于CUTLASS与CK为基础实现的flash attention算子kernel,以及一些前后处理的kernel,比如LayerNorm
├── flash_attn # python interface
├── hopper # 基于Hoper框架 实现的kernel
├── tests # 测试
└── training # 训练测试
后续主要以讲解Hoper的实现源码细节为主,所以集中在Hoper目录的讲解,Hoper 头文件引用关系
4 Flash attention Interface 参数说明
flash_attn_interface.py中定义了 flash attention的一些常见函数,下面对这些函数的参数做一个详细说明
4.1 flash_attn_func
函数flash_attn_func
用于定长序列输出的注意力函数, 根据对Q, K, V 的输入方式的不同,提供了三个接口
- qkv 打包成一个张量的flash attention 函数输入, 要求QKV的形状一样,通过将 query、key 和 value 打包成一个张量作为输入,避免了显式连接 Q、K、V 的梯度,从而提高了计算效率。
def flash_attn_qkvpacked_func(
qkv, # query、key 和 value 张量,形状为 (batch_size, seqlen, 3, nheads, headdim)
dropout_p=0.0, # 设置dropout的概率值,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。在训练(training)时,可以设置一个小于 1 的值,例如 0.1,表示将有 10% 的神经元被随机丢弃
softmax_scale=None, # softmax之前的缩放系数,默认为1 / sqrt(headdim),例如如果 headdim=64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126
causal=False, # 是否应用注意力掩码,如果设置为True,则查询只能关注之前的输出,无法关注未来的输出 ,这在诸如语言模型等自回归任务中非常有用
window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 i, 关注窗口为 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键
softcap=0.0, # 0.0 means deactivated. Anything > 0 activates softcapping attention
alibi_slopes=None, # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads=16,则 alibi_slopes 可以是形状为 (16,) 的张量,每个元素对应一个注意力头的偏置斜率。如果设置了该参数,则会对查询 i 和键 j 之间的注意力分数加上一个偏置 (-alibi_slope * |i + seqlen_k - seqlen_q - j|),这种偏置可以鼓励模型关注更靠近查询的键,或者更远离查询的键,具体取决于 alibi_slopes 的值
deterministic=False, # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。在评估(evaluation)时,通常不需要使用确定性实现,因为评估时不需要计算梯度
return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
""""
返回值:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: 可选的,如果return_attn_probs=true, 则会返回softmax_lse, 形状是(nheads, total_q_seqlen)。多矩阵 QK^T * scaling的每行求logsumexp
S_dmask: 可选的,如果return_attn_probs=true, 则返回,形状是batch_size, nheads, seqlen, seqlen)
""""
- kv打包成一个张量的flash attention 函数输入, 同样也是为了提高性能
def flash_attn_kvpacked_func(
q, # query 张量,形状是(batch_size, seqlen, nheads, headdim)
kv, # key 和 value 张量, 形状是(batch_size, seqlen, 2, nheads_k, headdim)
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
""""
返回值:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: 可选的,如果return_attn_probs=true, 则会返回softmax_lse, 形状是(nheads, total_q_seqlen)。多矩阵 QK^T * scaling的每行求logsumexp
S_dmask: 可选的,如果return_attn_probs=true, 则返回,形状是batch_size, nheads, seqlen, seqlen)
""""
- Q, K, V 分成三个输入,当KV的
nheads_k
小于Q, 则是MQA/GQA
. 需要注意Q的nhead
必须能被KV的头整除,即nhead/nheads_k
为整数
def flash_attn_func(
q, # query 张量,形状是(batch_size, seqlen, nheads, headdim)
k, # key 张量,形状是(batch_size, seqlen, nheads_k, headdim)
v, # value 张量,形状是(batch_size, seqlen, nheads_k, headdim)
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
""""
返回值:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: 可选的,如果return_attn_probs=true, 则会返回softmax_lse, 形状是(nheads, total_q_seqlen)。多矩阵 QK^T * scaling的每行求logsumexp
S_dmask: 可选的,如果return_attn_probs=true, 则返回,形状是batch_size, nheads, seqlen, seqlen)
""""
4. 2 flash_attn_varlen_func
函数flash_attn_varlen_func
用于可变长的注意力输出,同样的根据Q, K, V参数的输入方式,也有三种接口
- Q, K, V 打包成一个张量
def flash_attn_func(
qkv, # query、key 和 value 张量,形状为 (total, 3, nheads, headdim)
cu_seqlens, # query、key 和 value的序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 10, 20, 32, 42],表示批次中有 4 个序列,第一个序列长度为 10,第二个序列长度为 10,第三个序列长度为 12,第四个序列长度为 10。这些累积长度用于索引 Q, K, V 张量
max_seqlen, # 序列的最大长度
dropout_p=0.0, # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。
softmax_scale=None, # softmax之前的缩放系数,默认为1 / sqrt(headdim),例如如果 headdim=64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126
causal=False, # 是否应用注意力掩码,如果设置为True,则查询只能关注之前的输出,无法关注未来的输出 ,这在诸如语言模型等自回归任务中非常有用
window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 i, 关注窗口为 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键
softcap=0.0, # 0.0 means deactivated. Anything > 0 activates softcapping attention
alibi_slopes=None,# 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads=16,则 alibi_slopes 可以是形状为 (16,) 的张量,每个元素对应一个注意力头的偏置斜率。如果设置了该参数,则会对查询 i 和键 j 之间的注意力分数加上一个偏置 (-alibi_slope * |i + seqlen_k - seqlen_q - j|),这种偏置可以鼓励模型关注更靠近查询的键,或者更远离查询的键,具体取决于 alibi_slopes 的值
deterministic=False, # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。在评估(evaluation)时,通常不需要使用确定性实现,因为评估时不需要计算梯度
return_attn_probs=False,# 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
""""
返回值:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: 可选的,如果return_attn_probs=true, 则会返回softmax_lse, 形状是(nheads, total_q_seqlen)。多矩阵 QK^T * scaling的每行求logsumexp
S_dmask: 可选的,如果return_attn_probs=true, 则返回,形状是batch_size, nheads, seqlen, seqlen)
""""
- K, V 打包成一个张量
def flash_attn_varlen_kvpacked_func(
q, # query 张量,形状是(total_q, nheads, headdim),例如 (1024, 16, 64),其中 total_q=1024 是批次中 query token 的总数,nheads=16 是注意力头数,headdim=64 是每个注意力头的维度
kv, # key 和 value 张量, 形状是(total_k, 2, nheads_k, headdim)
cu_seqlens_q, # query 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32
cu_seqlens_k, # key 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 15, 30, 45, 60],用于索引 kv 张量
max_seqlen_q, # 批次中 query 序列的最大长度,例如 12,表示批次中最长的 query 序列长度为 12
max_seqlen_k, # 批次中 key 序列的最大长度,例如 15,表示批次中最长的 key 序列长度为 15
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
""""
返回值:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: 可选的,如果return_attn_probs=true, 则会返回softmax_lse, 形状是(nheads, total_q_seqlen)。多矩阵 QK^T * scaling的每行求logsumexp
S_dmask: 可选的,如果return_attn_probs=true, 则返回,形状是batch_size, nheads, seqlen, seqlen)
""""
- Q, K, V 分开输入
def flash_attn_varlen_func(
q, # query 张量,形状是(total_q, nheads, headdim)
k, # key 张量,形状是(total_k, nheads_k, headdim)
v, # value 张量,形状是(total_k, nheads_k, headdim)
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
block_table=None,
):
""""
返回值:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: 可选的,如果return_attn_probs=true, 则会返回softmax_lse, 形状是(nheads, total_q_seqlen)。多矩阵 QK^T * scaling的每行求logsumexp
S_dmask: 可选的,如果return_attn_probs=true, 则返回,形状是batch_size, nheads, seqlen, seqlen)
""""
4. 2 flash_attn_with_kvcache
带有KV cache的flash attention 接口, 该函数用于在推理(inference)过程中计算注意力层的输出,同时支持使用 key 和 value 缓存,以及旋转位置嵌入等技术, 并且支持更新 key 和 value 缓存:
def flash_attn_with_kvcache(
q, # query 张量,形状是(batch_size, seqlen, nheads, headdim)
k_cache, # key 缓存张量,形状为 (batch_size_cache, seqlen_cache, nheads_k, headdim) 或 (num_blocks, page_block_size, nheads_k, headdim)
v_cache, # value 缓存张量,形状与 k_cache 相同
k=None, # 可选的,新 key 张量,形状为 (batch_size, seqlen_new, nheads_k, headdi
v=None, # 可选的, 新 value 张量,形状与 k 相同
rotary_cos=None, # 可选的, 旋转位置嵌入余弦值,形状为 (seqlen_ro, rotary_dim / 2)
rotary_sin=None, # 可选的,旋转位置嵌入正弦值,形状与 rotary_cos 相同
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, # 缓存序列长度,可以是整数或张量
cache_batch_idx: Optional[torch.Tensor] = None, # 缓存批次索引张量,形状为 (batch_size,)
cache_leftpad: Optional[torch.Tensor] = None, # KV 缓存的开始的index, 形状为 (batch_size,), 默认为0
block_table: Optional[torch.Tensor] = None, # 可选的块表张量,形状为 (batch_size, max_num_blocks_per_seq)
softmax_scale=None, # softmax 缩放系数,默认为 1 / sqrt(headdim)
causal=False, # 是否进行因果注意力掩码
window_size=(-1, -1), # 滑动窗口大小,(-1, -1)表示无限上下文窗口
softcap=0.0, # 0.0 means deactivated. Anything > 0 activates softcapping attention
rotary_interleaved=True, # 是否交错旋转位置嵌入
alibi_slopes=None, # 可选的,注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads)
num_splits=0, # 将 key/value 沿序列维度分割的数量,0 表示自动确定
return_softmax_lse=False, # 布尔值,判断是否返回注意力分数的logsumexp,主要用于反向计算
):
""""
返回值:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: 可选的,如果 return_softmax_lse=true, 则会返回softmax_lse, 形状是(batch_size, nheads, seqlen)
""""