flash-linear-attention中的Chunkwise并行算法的理解

这里提一下,我维护的几三个记录个人学习笔记以及社区中其它大佬们的优秀博客链接的仓库都获得了不少star,感谢读者们的认可,我也会继续在开源社区多做贡献。github主页:https://github.com/BBuf ,欢迎来踩

在这里插入图片描述

0x0. 前言

我之前解读过causal linear attention的cuda实现,文章见:https://zhuanlan.zhihu.com/p/673896906 ,也是在评论区通过@sonta 了解到了flash-linear-attention的Chunkwise并行实现。上篇文章https://mp.weixin.qq.com/s/H6wWBxwIJNCzkIlH_uIuiw中说到后续想继续解析一下chunk_rwkv6的实现,chunk_rwkv6的实现思路仍然是沿用flash-linear-attention中的Chunkwise并行思路,由于之前没有认真看过这个Chunkwise的算法所以读起来有点困难,这里需要用普通并行以及RNN递归的视角去看待才能理解这个算法流程。这篇文章就从 Gated Linear Attention Transformers with Hardware-Efficient Training (https://arxiv.org/pdf/2312.06635) 这篇Paper对线性Attention的Chunwise并行讲解和伪代码入手深入理解下这个方法,另外我们也会在后面深入分析下代码的实现。这篇Paper的作者也是flash-linear-attention的作者。

0x1. Paper部分

Paper部分这里只关注Background里面和Linear Attention相关的两节。这里对其进行翻译和解读。

在这里插入图片描述

我们首先简要介绍一下线性注意力层的背景。对于符号表示,我们使用黑体大写字母表示矩阵(例如,S、Q),黑体小写字母表示向量(例如, q t q_t qt k t k_t kt),斜体大写字母表示可学习的参数矩阵(例如, W K W_K WK)。通常我们使用相同的字母表示矩阵的行,例如, q t q_t qt 表示矩阵 Q Q Q 的第 t t t 行。

在这里插入图片描述

2.1 并行和递归形式

标准的Transformers采用softmax注意力机制,该机制接受输入序列 X ∈ R L × d X \in \mathbb{R}^{L \times d} XRL×d(其中 L L L 是长度, d d d 是隐藏维度)并通过以下方式计算输出 O ∈ R L × d O \in \mathbb{R}^{L \times d} ORL×d

Q , K , V = X W Q , X W K , X W V , O = softmax ( ( Q K T ) ⊙ M ) V , Q, K, V = XW_Q, XW_K, XW_V, O = \text{softmax}\left((QK^T) \odot M\right) V, Q,K,V=XWQ,XWK,XWV,O=softmax((QKT)M)V,

其中 W Q , W K , W V ∈ R d × d W_Q, W_K, W_V \in \mathbb{R}^{d \times d} WQ,WK,WVRd×d 是可学习的矩阵, M ∈ { − ∞ , 1 } L × L M \in \{-\infty, 1\}^{L \times L} M{ ,1}L×L 是一个掩码,用于防止模型关注未来的token,即 M i j = 1 M_{ij} = 1 Mij=1 i ≥ j i \geq j ij M i j = − ∞ M_{ij} = -\infty Mij= i < j i < j i<j。 (这里我们假设一个简单的单头注意力。)上述的并行注意力形式可以在给定完整输入 X X X 的情况下并行计算 O O O,从而实现高效训练。然而,在推理过程中,Transformer必须使用以下递归形式:

q t , k t , v t = x t W Q , x t W K , x t W V q_t, k_t, v_t = x_t W_Q, x_t W_K, x_t W_V qt,kt,vt=xtWQ,xtWK,xtWV

o t = ∑ i = 1 t exp ⁡ ( q t k i T ) v i ∑ i = 1 t exp ⁡ ( q t k i T ) o_t = \frac{\sum_{i=1}^{t} \exp(q_t k_i^T) v_i}{\sum_{i=1}^{t} \exp(q_t k_i^T)} ot=i=1texp(qtkiT)i=1texp(qtkiT)

### Self-Attention Mechanism Architecture Diagram Self-attention机制的核心在于其能够并行处理输入序列中的所有元素,并通过计算查询(Q)、键(K)和值(V)之间的关系来捕获全局依赖性。以下是self-attention机制架构的一个简化描述及其对应的流程图实现方式。 #### 多头自注意力机制概述 多头自注意力机制通过分解输入向量到多个子空间中,分别学习不同的特征表示,从而增强模型的表达能力[^3]。具体来说,它会生成 \( h \) 个独立的注意力头,每个头都会单独计算自己的 \( Q, K, V \),并通过线性变换得到最终的结果。这些结果会被拼接在一起形成完整的输出。 下面是multi-head self-attention 的伪代码展示: ```python import torch.nn as nn class MultiHeadedAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadedAttention, self).__init__() assert d_model % num_heads == 0 self.d_k = d_model // num_heads self.num_heads = num_heads self.linears = clones(nn.Linear(d_model, d_model), 4) def forward(self, query, key, value, mask=None): nbatches = query.size(0) # 将query、key 和value 分解成num_heads份 query, key, value = [ l(x).view(nbatches, -1, self.num_heads, self.d_k).transpose(1, 2) for l, x in zip(self.linears, (query, key, value)) ] # 应用缩放点积注意算法 scores = attention(query, key, value, mask=mask, dropout=self.dropout) # 拼接各个头部的结果 concat_scores = scores.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) return self.linears[-1](concat_scores) ``` 此代码片段展示了如何构建一个多头注意力层,其中包含了关键步骤如线性映射、分割以及最后的重新组合过程。 对于想要了解具体的 **self-attention mechanism architecture diagram**, 下面是一个常见的图形化解释: ![Self Attention Mechanism](https://miro.medium.com/max/700/1*9cNzJZfEaFmO8sYHkSbBjA.png)[^3] 该图表清晰地描绘了从输入经过线性转换成为\( Q,K,V\) 向量组,再到应用scaled dot-product attention操作直至最终输出的整体流经路径。 --- ####
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值