一杯咖啡的时间学习大模型(LLM):LLaMA解读之SwiGLU激活函数

一、LLaMA的核心改进全景

Meta开源的LLaMA模型凭借其卓越的性能表现成为大模型发展的重要里程碑。相较于标准Transformer架构,LLaMA主要在以下几个方面进行了关键改进:

  1. 位置编码升级:采用旋转位置编码(Rotary Position Embedding, RoPE)
  2. 归一化革新:对每个 Transformer 子层的输入进行归一化(Pre-normalization)而非传统Transformer结构中对输出进行归一化(Post - normalization),并使用RMS-Norm替代传统LayerNorm。
  3. 激活函数优化:引入 SwiGLU 激活函数取代 ReLU 非线性函数,以提高性能。
  4. 注意力优化(LLaMA 2):引入分组查询注意力(Grouped Query Attention)

这些改进显著提升了模型的计算效率和长文本处理能力,今天我们来学习一下SwiGLU激活函数

其余部件的学习链接持续更新中,欢迎关注:

  1. 一杯咖啡的时间学习大模型(LLM):LLaMA解读之旋转编码RoPE(含代码实现)
  2. 一杯咖啡的时间学习大模型(LLM):LLaMA解读之均方根误差标准化RMSNorm(含代码实现)
  3. 一杯咖啡的时间学习大模型(LLM):LLaMA解读之SwiGLU激活函数(含代码实现)
  4. 一杯咖啡的时间学习大模型(LLM):LLaMA解读之分组查询注意力(Grouped Query Attention)(含代码实现)

二、SwiGLU激活函数

2.1 改进动机

传统Transformer中广泛使用ReLU(Rectified Linear Unit)激活函数,其定义为:

ReLU ( x ) = max ⁡ ( 0 , x ) \text{ReLU}(x) = \max(0, x) ReLU(x)=max(0,x)

虽然ReLU计算简单且能缓解梯度消失问题,但它存在以下局限性:

  • 神经元死亡:负值输入梯度恒为0,导致部分神经元永久失活。
  • 表达能力有限:输出仅为单侧线性函数,难以捕捉复杂非线性关系。

SwiGLU(Swish-Gated Linear Unit)通过结合门控机制Swish激活函数,显著提升了模型的表达能力。其核心思想是:

  • 动态门控:通过可学习的参数控制信息流动。
  • 平滑非线性:Swish函数( Swish ( x ) = x ⋅ σ ( β x ) \text{Swish}(x) = x \cdot \sigma(\beta x) Swish(x)=xσ(βx) β \beta β为可学习参数)相比ReLU更平滑,梯度更稳定。

2.2 数学原理

SwiGLU是GLU(Gated Linear Unit)的改进版本。给定输入张量 X X X,其计算过程如下:

SwiGLU ( X ) = ( Swish ( X W + b ) ) ⊗ ( X V + c ) \text{SwiGLU}(X) = (\text{Swish}(XW + b)) \otimes (XV + c) SwiGLU(X)=(Swish(XW+b))(XV+c)

其中:

  • W , V W, V W,V 为可学习的权重矩阵, b , c b, c b,c 为偏置项。
  • ⊗ \otimes 表示逐元素乘法(Hadamard积)。
  • Swish \text{Swish} Swish 函数通常取 β = 1 \beta=1 β=1,即 Swish ( x ) = x ⋅ sigmoid ( x ) \text{Swish}(x) = x \cdot \text{sigmoid}(x) Swish(x)=xsigmoid(x)

与标准GLU(使用双曲正切函数)相比,SwiGLU通过Swish函数引入了更强的非线性,同时保留了门控机制对信息流的动态调节能力。

2.3 源码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        # 初始化两个线性变换层
        self.w = nn.Linear(dim_in, dim_out, bias=False)
        self.v = nn.Linear(dim_in, dim_out, bias=False)
        self.bias = nn.Parameter(torch.zeros(dim_out))  # 可学习的偏置项

    def forward(self, x):
        # 计算Swish(Wx)和Vx,并进行逐元素相乘
        swish = F.silu(self.w(x))  # F.silu即Swish函数
        gate = self.v(x)
        return swish * gate + self.bias  # 添加偏置项

# 示例用法
if __name__ == "__main__":
    batch_size = 2
    seq_len = 128
    dim_in = 512
    dim_out = 1024

    swiglu = SwiGLU(dim_in, dim_out)
    x = torch.randn(batch_size, seq_len, dim_in)
    output = swiglu(x)
    print("输入形状:", x.shape)       # [2, 128, 512]
    print("输出形状:", output.shape)  # [2, 128, 1024]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值