Flash Attention源码学习之原理(第一篇)

背景与原理

Self-Attention 公式如下
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k} } )V Attention(Q,K,V)=softmax(dk QKT)V
其中:
Q - Query, 查询张量,Shape是[Batch, Seq_len, Head_num, Hide_dim]
K - Key, 索引张量,Shape是[Batch, Seq_len_kv, Head_num, Hide_dim]
V - Value, 内存张量,Shape是[Batch, Seq_len_kv, Head_num, Hide_dim]
d k d_k dk - 特征维度,主要是为了控制矩阵点乘后的范围,以确保softmax的梯度稳定性,经常取值等于Hide_dim

传统的Self-Attention有三种注意力机制:

  1. MHA (Multi-head Attention)
    MHA(Multi-head Attention)是Google团队在 2017 年在Attention Is All You Need提出的一种NLP经典模型,首次提出并使用了 Self-Attention 机制,也就是 Multi Head Attention,是标准的多头注意力机制,有H个Query、Key 和 Value 矩阵。
  2. MQA (Multi-Query Attention)
    MQA也是Google团队在2019年,在Fast Transformer Decoding: One Write-Head is All You Need提出的,是MHA的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量,以此来达到提升推理速度,但是会带来精度上的损失。
  3. GQA (Group-Query Attention)
    GQA(分组查询注意力)同样,也是Google在2023年,于[GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints]( GQA(分组查询注意力)同样,也是Google在2023年,于GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints提出的一种MHA变体,GQA将查询头分成G组,对于Query是每个头单独保留一份参数,每个组共享一个Key 和 Value 矩阵。GQA-G是指具有G组的grouped-query attention。)提出的一种MHA变体,GQA将查询头分成G组,对于Query是每个头单独保留一份参数,每个组共享一个Key 和 Value 矩阵。GQA-G是指具有G组的grouped-query attention。 在这里插入图片描述

2. Flash Attention公式推导

论文《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》
FlashAttention是一种加速注意力计算方法,目前已经应用在:GPT-3、Falcon2(阿联酋大模型)、Llama2、Megatron-LM、GPT-4等知名LLM上。Flash Attention已经集成到了pytorch2.0中,可以很便捷的调用。

FlashAttention旨在加速注意力计算并减少内存占用。FlashAttention利用底层硬件的内存层次知识,例如GPU的内存层次结构,来提高计算速度和减少内存访问开销。 FlashAttention的核心原理是通过将输入分块并在每个块上执行注意力操作,从而减少对高带宽内存(HBM)的读写操作。具体而言,FlashAttention使用平铺和重计算等经典技术,将输入块从HBM加载到SRAM(快速缓存),在SRAM上执行注意力操作,并将结果更新回HBM。FlashAttention减少了内存读写量,从而实现了2-4倍的时钟时间加速。

原论文中的公式:
在这里插入图片描述
下面我们用大白话来重新推导一遍去前向(反向也是一样的方式,不再重复赘叙)

  1. 首先是对 Q , K , V , l , m Q,K, V,l,m Q,K,V,l,m进行分块
  2. 在外循环中变量K, V每一块,在内循环中变量Q
  3. 将分块后的Q, K,V的小块加载的SRAM中
  4. 计算每一个小块Q, K的矩阵乘 S i j = Q i K j T S_{ij} = Q_i K_j^T Sij=QiKjT
  5. 计算 M A S K ( S i j ) MASK(S_{ij}) MASK(Sij)
  6. 使用online softmax进行softmax的计算,减少数据搬移到HBM的成本
    在这里插入图片描述
    于是 softmax 的计算可以优化成:(这个就 Online Softmax)
    在这里插入图片描述
  7. 计算 Q i = ( Q i + P i j ~ ∗ V j ) ∗ ( e ( m i j ~ − m i n e w ) ) Q_i = (Q_i + \widetilde{P_{ij}}*V_j)* (e^(\widetilde{m_{ij}} - m_i^{new})) Qi=(Qi+Pij Vj)(e(mij minew)),并将结果写回HBM
    计算示意图如下:
    在这里插入图片描述
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值