Abstract&Introduction&Related Work
- 研究任务
语言模型的基础架构 - 已有方法和相关工作
- S4,H3,Hyena,Linear Transformer
- 用核函数近似注意力,以便将自回归推理重写为循环形式
- 回归到使用循环模型进行高效推理,但牺牲了训练并行性。为了弥补这一点,使用元素级操作[PAA+23]进行加速,但同时损害了表示能力和性能
- 尝试用其他机制取代注意力,例如S4[GGR21]及其变体[DFS+22,PMN+23]
- 面临挑战
- 创新思路
- RetNet = linear attention + rope + 显式衰减(即 γ \gamma γ)
- 实验结论
实现了不可能三角,实现了O(1)推理
Retentive Networks
先来看一下Retention跟Attention的区别,首先第一眼感觉retention有点像RNN和LSTM
attention的计算方式是QK做矩阵乘法,使用query和key计算权重分布,对value加权
retention使用了一个线性衰减参数 γ \gamma γ,使用了一个状态向量S
给定输入 X ∈ R ∣ x ∣ × d m o d e l X\in\mathbb{R}^{|x|\times d_{\mathrm{model}}} X∈R∣x∣×dmodel,我们将其投影到一维函数 v ( n ) = X n ⋅ w V v(n) = X_n · w_V v(n)=Xn⋅wV。考虑一个序列建模问题,通过状态 s n s_n sn 将 v ( n ) v(n) v(n)映射为 o ( n ) o(n) o(n), 为简单起见,用 v n 、 o n v_n、o_n vn、on 表示 v ( n ) v(n) v(n) 和 o ( n ) o(n) o(n), 以递归方式形式化映射过程:
s n = A s n − 1 + K n T v n , A ∈ R d × d , K n ∈ R 1 × d o n = Q n s n = ∑ m = 1 n Q n A n − m K m T v m , Q n ∈ R 1 × d \begin{aligned}s_n&=As_{n-1}+K_n^\mathsf{T}v_n,&A\in\mathbb{R}^{d\times d},K_n\in\mathbb{R}^{1\times d}\\o_n&=Q_ns_n=\sum_{m=1}^nQ_nA^{n-m}K_m^\mathsf{T}v_m,&Q_n\in\mathbb{R}^{1\times d}\end{aligned} snon=Asn−1+KnTvn,=Qnsn=m=1∑nQnAn−mKmTvm,A∈Rd×d,Kn∈R1×dQn∈R1×d
将vn映射到状态向量sn,并通过线性变换来递归地编码序列信息。接下来使投影 Q n Q_n Qn、 K n K_n Kn变得与内容相关: Q = X W Q , K = X W K Q=XW_{Q},\quad K=XW_{K} Q=XW