介绍llama2|带有SwiGlu的FeedForward

概要

    观察llama2的每个layer模型结构结果发现,经过RMS norm的Group-Head Attention的输出做了残差连接后,又做了一次经过RMS norm的FeedForward的残差连接。我们知道,FeedForward(又名MLP即多层感知机)通过非线性转换来更好地表达特征。本篇文章解密llama2的FeedForward的实现。

Swish激活函数

   swish激活函数是silu激活函更一般的形式,它的表达式为

f(x)=xsigmoid(\beta x)

上图中左图为swish激活函数在β分别取不同的的值时的函数曲线,可以发现,当\beta=1时,swish激活函数等价于silu激活函数,当\beta=10时,该激活函数逼近于relu激活函数。

    同时,我们发现该激活函数连续并且可微,而relu激活函数在0点不可微。

    在负数域取值为趋近于0的负数,不会像relu激活函数在负数域取值为0,不会使得负值神经元失活,并且具有非单调的特性,此外,\beta可以作为优化的超参。

Glu门控线性单元

    glu(gated linear unit)门控线性单元的结构如下图所示,它能够将输入分别经过权重为W的linear层变换得到A、权重为V的linear层变换得到B,然后对A用激活函数\sigma激活B哈达玛积。公式如下:

glu(X)=\sigma (WX+b)\otimes (VX+c)

    经过linear变换得到的A的激活值\otimes经过另外的linear变化得到的B,能实现门控的效果。

带有SwiGlu单元的MLP

    llama2在FeedForward network(或称为MLP)中使用SwiGLUSwiGLU相当于glu门控线性单元中的激活函数换为swish激活函数,即下式:

SwiGlu(X)=Swish(WX+b)\otimes (VX+c)

可以按下图理解,右图为huggingface的源码整体实现

SwiGlu的MLP源码

    transformers的llama源码中对于swiglu的MLP的源码实现如下:

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj

    其中,源码中采用silu激活函数代替swish激活函数(相当于\beta =1)。

小结

    本文分别从理论和源码的角度分析了llama2中带有SwiGlu的MLP层,希望对你有所启发。

    点赞、关注和收藏,后续有更多深度学习的解析文章!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值