文章目录
1. Llama3 整体结构
llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。
注 1 注_1 注1: 下一节中会参照上图中 红色圆形序号 讲解各模块。
注 2 注_2 注2: llama3的RoPE算法被拆成了3个方法来实现,上图中的模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。
2. 模块详解
2.1 模块1: Embeddings
llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。
2.2 模块2: RoPE
前文中已经提及llama3的RoPE算法被拆成了3个方法来实现,模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。本小节具体按照RoPE的原始论文来讲解,主要阐述RoPE的算法原理。
2.2.1 从一个2维的例子说起 RoPE
我们知道,寻找位置编码的基本思路是 输入位置编码经过特征提取的核心算法后的值,应能反应出两个位置之间的先后顺序(这点不是必要的)和相对位置信息。(《Transformer(二)–论文理解:transformer 结构详解》 2.1节 中有简单说明),RoPE的原始论文中给出了一个数学表达,如下式:
<
f
q
(
x
m
,
m
)
,
f
k
(
x
n
,
n
)
>
=
g
(
x
m
,
x
n
,
m
−
n
)
(2.1)
<f_q(x_m,m),f_k(x_n,n)>=g(x_m,x_n,m-n) \tag{2.1}
<fq(xm,m),fk(xn,n)>=g(xm,xn,m−n)(2.1)
f q ( x m , m ) f_q(x_m,m) fq(xm,m)和 f k ( x n , n ) f_k(x_n,n) fk(xn,n)分别表示 q m q_m qm和 k n k_n kn。式子的左侧为点积形式,之所以为点积是因为tansformer中使用的attention score计算方法通常为点积。右侧 g ( x m , x n , m − n ) g(x_m,x_n,m-n) g(xm,xn,m−n)表示计算结果是与 x m , x n , m − n x_m,x_n,m-n xm,xn,m−n相关的。在这里 m − n m-n m−n的绝对值能反应出位置的距离,大小反应出前后顺序。 我们的目的就是找到一个这样的变换函数 f { q , k } f_{\{q,k\}} f{q,k}能表达 f q ( x m , m ) f_q(x_m,m) fq(xm,m)与 f k ( x n , n ) f_k(x_n,n) fk(xn,n),使 f q f_q fq与 f k f_k fk做点积操作后能保留 m − n m-n m−n的信息。当然我们找到了,见公式2.2
RoPE的论文中是先从2D情况下举例说明我们找到的 f ( x ) f(x) f(x)的,如下,当 d = 2 d=2 d=2时:
f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) θ ] (2.2) f_q(x_m,m) = (\pmb{W}_{q}x_m)e^{im\theta} \\ f_k(x_n,n) = (\pmb{W}_{k}x_n)e^{in\theta} \\ g(x_m,x_n,m-n) = Re[(\pmb{W_q}x_m)(\pmb{W}_kx_n)^{*}e^{i(m-n)\theta}] \tag{2.2} fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)θ](2.2)
其中
R
e
[
⋅
]
Re[ \cdot ]
Re[⋅]是复数的实部,
(
W
k
x
n
)
∗
(\pmb{W}_{k}x_n)^{*}
(Wkxn)∗表示
(
W
k
x
n
)
(\pmb{W}_{k}x_n)
(Wkxn)的共轭复数。
θ
∈
R
\theta \in \mathbb{R}
θ∈R 是一个预设的非零常数。我们可以进一步将
f
{
q
,
k
}
f_{\{q,k\}}
f{q,k}写成乘法矩阵:
f
{
q
,
k
}
(
x
m
,
m
)
=
(
c
o
s
m
θ
−
s
i
n
m
θ
s
i
n
m
θ
c
o
s
m
θ
)
(
W
{
q
,
k
}
(
11
)
W
{
q
,
k
}
(
12
)
W
{
q
,
k
}
(
21
)
W
{
q
,
k
}
(
22
)
)
(
x
m
(
1
)
x
m
(
2
)
)
(2.3)
f_{\{q,k\}}(x_m,m)= \left( \begin{matrix} cos\ m\theta & -sin\ m\theta \\ sin\ m\theta & cos\ m\theta \\ \end{matrix} \right) \left( \begin{matrix} W^{(11)}_{\{q,k\}} & W^{(12)}_{\{q,k\}} \\ W^{(21)}_{\{q,k\}} & W^{(22)}_{\{q,k\}} \\ \end{matrix} \right) \left( \begin{matrix} x^{(1)}_{m} \\ x^{(2)}_{m} \end{matrix} \right) \tag{2.3}
f{q,k}(xm,m)=(cos mθsin mθ−sin mθcos mθ)(W{q,k}(11)W{q,k}(21)W{q,k}(12)W{q,k}(22))(xm(1)xm(2))(2.3)
其中, ( x m ( 1 ) , x m ( 2 ) ) (x^{(1)}_{m},x^{(2)}_{m}) (xm(1),xm(2))是 x m x_m xm在二维坐标系中的表示。同样的, g g g也可以看作一个矩阵,因此可以在2维情况下求解公式(2.1)。
2.2.2 RoPE的一般形式
为了将我们在2D中的结果推广到任意的
x
i
∈
R
d
x_i \in \mathbb{R}^d
xi∈Rd,我们将d维空间划分为d/2个子空间,并根据内积的线性性质将它们组合起来,将
f
{
q
,
k
}
(
x
m
,
n
)
f_{\{q,k\}}(x_m,n)
f{q,k}(xm,n)转化为:
f
{
q
,
k
}
(
x
m
,
m
)
=
R
Θ
,
m
d
W
{
q
,
k
}
x
m
(2.4)
f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m \tag{2.4}
f{q,k}(xm,m)=RΘ,mdW{q,k}xm(2.4)
其中,
W
{
q
,
m
}
\pmb{W}_{\{q,m\}}
W{q,m} 表示与query和key 所对应的转换矩阵 ,
x
m
x_m
xm 为输入向量,
R
Θ
,
m
d
\pmb{R}^d_{\Theta,m}
RΘ,md为旋转矩阵,具体如下:
R
Θ
,
m
d
=
(
c
o
s
m
θ
1
−
s
i
n
m
θ
1
0
0
⋯
0
0
s
i
n
m
θ
1
c
o
s
m
θ
1
0
0
⋯
0
0
0
0
c
o
s
m
θ
2
−
s
i
n
m
θ
2
⋯
0
0
0
0
s
i
n
m
θ
2
c
o
s
m
θ
2
⋯
0
0
⋮
⋮
⋮
⋮
⋱
⋮
⋮
0
0
0
0
⋯
c
o
s
m
θ
d
/
2
−
s
i
n
m
θ
d
/
2
0
0
0
0
⋯
s
i
n
m
θ
d
/
2
c
o
s
m
θ
d
/
2
)
(2.5)
\pmb{R}^{d}_{\Theta,m}= \left( \begin{matrix} cos\ m\theta_1 & -sin\ m\theta_1 &0 &0 & \cdots &0 &0 \\ sin\ m\theta_1 & cos\ m\theta_1 &0 &0 & \cdots &0 &0 \\ 0 & 0 & cos\ m\theta_2 & -sin\ m\theta_2 & \cdots &0 &0 \\ 0 & 0 & sin\ m\theta_2 & cos\ m\theta_2 & \cdots &0 &0 \\ \vdots & \vdots &\vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 &0 &0 & \cdots & cos\ m\theta_{d/2} & -sin\ m\theta_{d/2} \\ 0 & 0 &0 &0 & \cdots & sin\ m\theta_{d/2} & cos\ m\theta_{d/2} \\ \end{matrix} \right) \tag{2.5}
RΘ,md=
cos mθ1sin mθ100⋮00−sin mθ1cos mθ100⋮0000cos mθ2sin mθ2⋮0000−sin mθ2cos mθ2⋮00⋯⋯⋯⋯⋱⋯⋯0000⋮cos mθd/2sin mθd/20000⋮−sin mθd/2cos mθd/2
(2.5)
Θ = { θ i = 1000 0 − 2 ( i − 1 ) / d , i ∈ [ 1 , 2 , . . . , d / 2 ] } (2.6) \Theta=\{ \theta_i = 10000^{-2(i-1)/d}, i \in [1,2,...,d/2] \} \tag{2.6} Θ={θi=10000−2(i−1)/d,i∈[1,2,...,d/2]}(2.6)
2.2.3 RoPE的理解
这里我们把我们求出的
f
{
q
,
k
}
(
x
m
,
m
)
=
R
Θ
,
m
d
W
{
q
,
k
}
x
m
f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m
f{q,k}(xm,m)=RΘ,mdW{q,k}xm代入attention score的计算公式
a
m
,
n
=
exp
(
q
m
T
k
n
d
)
∑
j
=
1
N
exp
(
q
m
T
k
j
d
)
(2.7)
a_{m,n}=\frac{\exp{(\frac{q^{T}_mk_n}{\sqrt{d}})}}{\sum^N_{j=1}{\exp{(\frac{q^{T}_mk_j}{\sqrt{d}})}}} \tag{2.7}
am,n=∑j=1Nexp(dqmTkj)exp(dqmTkn)(2.7)
这里我们只需要看 q m T k m q^T_{m}k_m qmTkm即可,公式的其余部分不会改变结果形式。把公式2.4代入2.7
q m T k n = ( R Θ , m d W q x m ) T ( R Θ , n d W k x n ) = x T W q R Θ , n − m d W k x n (2.8) q^{T}_{m}k_n=(\pmb{R}^d_{\Theta,m}\pmb{W}_qx_m)^T(\pmb{R}^d_{\Theta,n}\pmb{W}_kx_n)=x^T\pmb{W}_{q}R^d_{\Theta,n-m}\pmb{W}_kx_n \tag{2.8} qmTkn=(RΘ,mdWqxm)T(RΘ,ndWkxn)=xTWqRΘ,n−mdWkxn(2.8)
其中, R Θ , n − m d = ( R Θ , m d ) T R Θ , n d \pmb{R}^d_{\Theta,n-m} = (\pmb{R}^d_{\Theta,m})^T\pmb{R}^d_{\Theta,n} RΘ,n−md=(RΘ,md)TRΘ,nd,注意 R Θ d \pmb{R}^d_{\Theta} RΘd是一个正交矩阵,这保证了位置信息在处理过程中的稳定性。此外,由于 R Θ d \pmb{R}^d_{\Theta} RΘd的稀疏性,式(2.8)的计算效率不高,作者在理论上提供了另一种实现。
2.3 模块3: Transformer Block
Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。Transformer Block有模块4、5、6、7组成,具体内容见对应模块。
2.4 模块4: RMSNorm
RSMNorm 是在 layer normalization 基础上优化而来,所以先简单回顾下layer normalization。(详细介绍见《Transformer(二)–论文理解:transformer 结构详解》 2.4节)
layer normalization 是根据下面的公式对
x
x
x的分布进行调整。
x
=
a
∗
x
−
x
‾
s
t
d
+
e
p
s
+
b
(2.9)
x = a * \frac{x - \overline{x}}{std + eps} + b \tag{2.9}
x=a∗std+epsx−x+b(2.9)
其中,
x
‾
\overline{x}
x是均值,
s
t
d
std
std是标准差,
e
p
s
eps
eps为一个很小的数,防止分母为零。
a
a
a、
b
b
b为参数,
b
b
b可以为零。
我们现在来看看RMSNorm做了什么优化呢,其实他对上面的试子
x
=
a
∗
x
−
x
‾
s
t
d
+
e
p
s
+
b
x = a * \frac{x - \overline{x}}{std + eps} + b
x=a∗std+epsx−x+b进行了简化。RMSNorm的计算公式如下:
x
‾
i
=
x
i
R
M
S
(
x
)
g
i
,
w
h
e
r
e
R
M
S
(
x
)
=
1
n
Σ
i
=
1
n
x
i
2
(2.10)
\overline{x}_i=\frac{x_i}{RMS(x)}g_{i}, \quad where \quad RMS(x) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{x^{2}_{i}}} \tag{2.10}
xi=RMS(x)xigi,whereRMS(x)=n1Σi=1nxi2(2.10)
上式中 g i g_i gi为权重参数,可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( x i − x ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({x_i - \overline{x})}} std=n1Σi=1n(xi−x))。这种简化在计算效率上有一定提高,且原始论文也说了,在效果上没有明显影响。
下面附上meta llama3中RMSNorm的源码,方便大家理解。
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
2.5 模块5: Attention
llama3的attention模块主要做了4部分工作,分别是RoPE计算、分注意力分组机制实现、点积注意力计算 及 kv缓存策略实现。其中RoPE的计算在模块2中已经讲解,这里不在赘述。下文对GQA,点积注意力计算及KV缓存进行简单的讲解。
2.5.1 分组注意力机制(GQA)
llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。
我们知道,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
2.5.2 注意力计算(Scaled Dot-Product Attention)
llama3 计算attention score时,使用了与《attention is all you need》一文中相同的计算方法,即点积注意力方法(Scaled Dot-Product Attention),由于Scaled Dot-Product Attention在《Transformer(二)–论文理解:transformer 结构详解》 一文中的2.2.1章节有详细的讲解,这里就不再展开。
2.5.3 KV缓存
llama3在计算 attention 时采用了kv cache策略。此策略的思想是缓存每个时间步的key和value的值,在推理阶段,由于模型是自回归模式生成文本,所以当我们对过往时间步有缓存结果时,会减少计算量,提高解码效率。
下面是llama3中Attention类的源码,大家可以参考理解
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
.
.
.
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
# 以下是Scaled Dot-Product Attention的计算
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
2.6 模块6: ADD
此模块做了个类似残差的操作,但与残差不同的是,不是用输入减去输出,而是用输入加上输出。具体操作就是把模块4的输入与模块5的输出做加法运算。
2.7 模块7: FFN
由3个Linear组成的FeedForward网络,这里的激活函数使用的siLU。siLU的数学公式如下:
s
i
l
u
(
x
)
=
x
∗
σ
(
x
)
,
w
h
e
r
e
σ
(
x
)
i
s
t
h
e
l
o
g
i
s
t
i
c
s
i
g
m
o
i
d
.
silu(x)=x*\sigma(x), \ \ where\ \sigma(x)\ is\ the\ logistic\ sigmoid.
silu(x)=x∗σ(x), where σ(x) is the logistic sigmoid.
函数的激活曲线如下图:
在里注意下,siLU 还有一个名字叫“swish function”,这个在 pytorch 的官方文档中有说明。
下面给出主要源码。
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
.
.
.
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
2.8 模块8: Linear
此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)