llama3 结构详解

1. Llama3 整体结构

  llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。

在这里插入图片描述

注 1 注_1 1: 下一节中会参照上图中 红色圆形序号 讲解各模块。
注 2 注_2 2: llama3的RoPE算法被拆成了3个方法来实现,上图中的模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。

2. 模块详解

2.1 模块1: Embeddings

  llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。

2.2 模块2: RoPE

  前文中已经提及llama3的RoPE算法被拆成了3个方法来实现,模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。本小节具体按照RoPE的原始论文来讲解,主要阐述RoPE的算法原理。

2.2.1 从一个2维的例子说起 RoPE

  我们知道,寻找位置编码的基本思路是 输入位置编码经过特征提取的核心算法后的值,应能反应出两个位置之间的先后顺序(这点不是必要的)和相对位置信息。(《Transformer(二)–论文理解:transformer 结构详解》 2.1节 中有简单说明),RoPE的原始论文中给出了一个数学表达,如下式:
< f q ( x m , m ) , f k ( x n , n ) > = g ( x m , x n , m − n ) (2.1) <f_q(x_m,m),f_k(x_n,n)>=g(x_m,x_n,m-n) \tag{2.1} <fq(xm,m),fk(xn,n)>=g(xm,xn,mn)(2.1)

   f q ( x m , m ) f_q(x_m,m) fq(xm,m) f k ( x n , n ) f_k(x_n,n) fk(xn,n)分别表示 q m q_m qm k n k_n kn。式子的左侧为点积形式,之所以为点积是因为tansformer中使用的attention score计算方法通常为点积。右侧 g ( x m , x n , m − n ) g(x_m,x_n,m-n) g(xm,xn,mn)表示计算结果是与 x m , x n , m − n x_m,x_n,m-n xm,xn,mn相关的。在这里 m − n m-n mn的绝对值能反应出位置的距离,大小反应出前后顺序。 我们的目的就是找到一个这样的变换函数 f { q , k } f_{\{q,k\}} f{q,k}能表达 f q ( x m , m ) f_q(x_m,m) fq(xm,m) f k ( x n , n ) f_k(x_n,n) fk(xn,n),使 f q f_q fq f k f_k fk做点积操作后能保留 m − n m-n mn的信息。当然我们找到了,见公式2.2

  RoPE的论文中是先从2D情况下举例说明我们找到的 f ( x ) f(x) f(x)的,如下,当 d = 2 d=2 d=2时:

f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) θ ] (2.2) f_q(x_m,m) = (\pmb{W}_{q}x_m)e^{im\theta} \\ f_k(x_n,n) = (\pmb{W}_{k}x_n)e^{in\theta} \\ g(x_m,x_n,m-n) = Re[(\pmb{W_q}x_m)(\pmb{W}_kx_n)^{*}e^{i(m-n)\theta}] \tag{2.2} fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ](2.2)

  其中 R e [ ⋅ ] Re[ \cdot ] Re[]是复数的实部, ( W k x n ) ∗ (\pmb{W}_{k}x_n)^{*} (Wkxn)表示 ( W k x n ) (\pmb{W}_{k}x_n) (Wkxn)的共轭复数。
θ ∈ R \theta \in \mathbb{R} θR 是一个预设的非零常数。我们可以进一步将 f { q , k } f_{\{q,k\}} f{q,k}写成乘法矩阵:
f { q , k } ( x m , m ) = ( c o s   m θ − s i n   m θ s i n   m θ c o s   m θ ) ( W { q , k } ( 11 ) W { q , k } ( 12 ) W { q , k } ( 21 ) W { q , k } ( 22 ) ) ( x m ( 1 ) x m ( 2 ) ) (2.3) f_{\{q,k\}}(x_m,m)= \left( \begin{matrix} cos\ m\theta & -sin\ m\theta \\ sin\ m\theta & cos\ m\theta \\ \end{matrix} \right) \left( \begin{matrix} W^{(11)}_{\{q,k\}} & W^{(12)}_{\{q,k\}} \\ W^{(21)}_{\{q,k\}} & W^{(22)}_{\{q,k\}} \\ \end{matrix} \right) \left( \begin{matrix} x^{(1)}_{m} \\ x^{(2)}_{m} \end{matrix} \right) \tag{2.3} f{q,k}(xm,m)=(cos mθsin mθsin mθcos mθ)(W{q,k}(11)W{q,k}(21)W{q,k}(12)W{q,k}(22))(xm(1)xm(2))(2.3)

其中, ( x m ( 1 ) , x m ( 2 ) ) (x^{(1)}_{m},x^{(2)}_{m}) (xm(1),xm(2)) x m x_m xm在二维坐标系中的表示。同样的, g g g也可以看作一个矩阵,因此可以在2维情况下求解公式(2.1)。

2.2.2 RoPE的一般形式

  为了将我们在2D中的结果推广到任意的 x i ∈ R d x_i \in \mathbb{R}^d xiRd,我们将d维空间划分为d/2个子空间,并根据内积的线性性质将它们组合起来,将 f { q , k } ( x m , n ) f_{\{q,k\}}(x_m,n) f{q,k}(xm,n)转化为:
f { q , k } ( x m , m ) = R Θ , m d W { q , k } x m (2.4) f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m \tag{2.4} f{q,k}(xm,m)=RΘ,mdW{q,k}xm(2.4)

  其中, W { q , m } \pmb{W}_{\{q,m\}} W{q,m} 表示与query和key 所对应的转换矩阵 , x m x_m xm 为输入向量, R Θ , m d \pmb{R}^d_{\Theta,m} RΘ,md为旋转矩阵,具体如下:
R Θ , m d = ( c o s   m θ 1 − s i n   m θ 1 0 0 ⋯ 0 0 s i n   m θ 1 c o s   m θ 1 0 0 ⋯ 0 0 0 0 c o s   m θ 2 − s i n   m θ 2 ⋯ 0 0 0 0 s i n   m θ 2 c o s   m θ 2 ⋯ 0 0 ⋮ ⋮ ⋮ ⋮ ⋱ ⋮ ⋮ 0 0 0 0 ⋯ c o s   m θ d / 2 − s i n   m θ d / 2 0 0 0 0 ⋯ s i n   m θ d / 2 c o s   m θ d / 2 ) (2.5) \pmb{R}^{d}_{\Theta,m}= \left( \begin{matrix} cos\ m\theta_1 & -sin\ m\theta_1 &0 &0 & \cdots &0 &0 \\ sin\ m\theta_1 & cos\ m\theta_1 &0 &0 & \cdots &0 &0 \\ 0 & 0 & cos\ m\theta_2 & -sin\ m\theta_2 & \cdots &0 &0 \\ 0 & 0 & sin\ m\theta_2 & cos\ m\theta_2 & \cdots &0 &0 \\ \vdots & \vdots &\vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 &0 &0 & \cdots & cos\ m\theta_{d/2} & -sin\ m\theta_{d/2} \\ 0 & 0 &0 &0 & \cdots & sin\ m\theta_{d/2} & cos\ m\theta_{d/2} \\ \end{matrix} \right) \tag{2.5} RΘ,md= cos mθ1sin mθ10000sin mθ1cos mθ1000000cos mθ2sin mθ20000sin mθ2cos mθ2000000cos mθd/2sin mθd/20000sin mθd/2cos mθd/2 (2.5)

Θ = { θ i = 1000 0 − 2 ( i − 1 ) / d , i ∈ [ 1 , 2 , . . . , d / 2 ] } (2.6) \Theta=\{ \theta_i = 10000^{-2(i-1)/d}, i \in [1,2,...,d/2] \} \tag{2.6} Θ={θi=100002(i1)/d,i[1,2,...,d/2]}(2.6)

2.2.3 RoPE的理解

  这里我们把我们求出的 f { q , k } ( x m , m ) = R Θ , m d W { q , k } x m f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m f{q,k}(xm,m)=RΘ,mdW{q,k}xm代入attention score的计算公式
a m , n = exp ⁡ ( q m T k n d ) ∑ j = 1 N exp ⁡ ( q m T k j d ) (2.7) a_{m,n}=\frac{\exp{(\frac{q^{T}_mk_n}{\sqrt{d}})}}{\sum^N_{j=1}{\exp{(\frac{q^{T}_mk_j}{\sqrt{d}})}}} \tag{2.7} am,n=j=1Nexp(d qmTkj)exp(d qmTkn)(2.7)

这里我们只需要看 q m T k m q^T_{m}k_m qmTkm即可,公式的其余部分不会改变结果形式。把公式2.4代入2.7

q m T k n = ( R Θ , m d W q x m ) T ( R Θ , n d W k x n ) = x T W q R Θ , n − m d W k x n (2.8) q^{T}_{m}k_n=(\pmb{R}^d_{\Theta,m}\pmb{W}_qx_m)^T(\pmb{R}^d_{\Theta,n}\pmb{W}_kx_n)=x^T\pmb{W}_{q}R^d_{\Theta,n-m}\pmb{W}_kx_n \tag{2.8} qmTkn=(RΘ,mdWqxm)T(RΘ,ndWkxn)=xTWqRΘ,nmdWkxn(2.8)

其中, R Θ , n − m d = ( R Θ , m d ) T R Θ , n d \pmb{R}^d_{\Theta,n-m} = (\pmb{R}^d_{\Theta,m})^T\pmb{R}^d_{\Theta,n} RΘ,nmd=(RΘ,md)TRΘ,nd,注意 R Θ d \pmb{R}^d_{\Theta} RΘd是一个正交矩阵,这保证了位置信息在处理过程中的稳定性。此外,由于 R Θ d \pmb{R}^d_{\Theta} RΘd的稀疏性,式(2.8)的计算效率不高,作者在理论上提供了另一种实现。

2.3 模块3: Transformer Block

  Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。Transformer Block有模块4、5、6、7组成,具体内容见对应模块。

2.4 模块4: RMSNorm

  RSMNorm 是在 layer normalization 基础上优化而来,所以先简单回顾下layer normalization。(详细介绍见《Transformer(二)–论文理解:transformer 结构详解》 2.4节)
  layer normalization 是根据下面的公式对 x x x的分布进行调整。
x = a ∗ x − x ‾ s t d + e p s + b (2.9) x = a * \frac{x - \overline{x}}{std + eps} + b \tag{2.9} x=astd+epsxx+b(2.9)
其中, x ‾ \overline{x} x是均值, s t d std std是标准差, e p s eps eps为一个很小的数,防止分母为零。 a a a b b b为参数, b b b可以为零。
  我们现在来看看RMSNorm做了什么优化呢,其实他对上面的试子 x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=astd+epsxx+b进行了简化。RMSNorm的计算公式如下:
x ‾ i = x i R M S ( x ) g i , w h e r e R M S ( x ) = 1 n Σ i = 1 n x i 2 (2.10) \overline{x}_i=\frac{x_i}{RMS(x)}g_{i}, \quad where \quad RMS(x) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{x^{2}_{i}}} \tag{2.10} xi=RMS(x)xigi,whereRMS(x)=n1Σi=1nxi2 (2.10)

  上式中 g i g_i gi为权重参数,可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( x i − x ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({x_i - \overline{x})}} std=n1Σi=1n(xix) )。这种简化在计算效率上有一定提高,且原始论文也说了,在效果上没有明显影响。

下面附上meta llama3中RMSNorm的源码,方便大家理解。

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2.5 模块5: Attention

  llama3的attention模块主要做了4部分工作,分别是RoPE计算、分注意力分组机制实现、点积注意力计算 及 kv缓存策略实现。其中RoPE的计算在模块2中已经讲解,这里不在赘述。下文对GQA,点积注意力计算及KV缓存进行简单的讲解。

2.5.1 分组注意力机制(GQA)

  llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。

  我们知道,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
在这里插入图片描述

2.5.2 注意力计算(Scaled Dot-Product Attention)

  llama3 计算attention score时,使用了与《attention is all you need》一文中相同的计算方法,即点积注意力方法(Scaled Dot-Product Attention),由于Scaled Dot-Product Attention在《Transformer(二)–论文理解:transformer 结构详解》 一文中的2.2.1章节有详细的讲解,这里就不再展开。

2.5.3 KV缓存

   llama3在计算 attention 时采用了kv cache策略。此策略的思想是缓存每个时间步的key和value的值,在推理阶段,由于模型是自回归模式生成文本,所以当我们对过往时间步有缓存结果时,会减少计算量,提高解码效率。

下面是llama3中Attention类的源码,大家可以参考理解

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
		.
		.
		.
    

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        # 以下是Scaled Dot-Product Attention的计算
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

2.6 模块6: ADD

   此模块做了个类似残差的操作,但与残差不同的是,不是用输入减去输出,而是用输入加上输出。具体操作就是把模块4的输入与模块5的输出做加法运算。

2.7 模块7: FFN

  由3个Linear组成的FeedForward网络,这里的激活函数使用的siLU。siLU的数学公式如下:
s i l u ( x ) = x ∗ σ ( x ) ,    w h e r e   σ ( x )   i s   t h e   l o g i s t i c   s i g m o i d . silu(x)=x*\sigma(x), \ \ where\ \sigma(x)\ is\ the\ logistic\ sigmoid. silu(x)=xσ(x),  where σ(x) is the logistic sigmoid.

函数的激活曲线如下图:
在这里插入图片描述
在里注意下,siLU 还有一个名字叫“swish function”,这个在 pytorch 的官方文档中有说明。

下面给出主要源码。


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        .
        .
        .
  

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

2.8 模块8: Linear

  此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。

 self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值