一、LLaMA的核心改进全景
Meta开源的LLaMA模型凭借其卓越的性能表现成为大模型发展的重要里程碑。相较于标准Transformer架构,LLaMA主要在以下几个方面进行了关键改进:
- 位置编码升级:采用旋转位置编码(Rotary Position Embedding, RoPE)
- 归一化革新:对每个 Transformer 子层的输入进行归一化(Pre-normalization)而非传统Transformer结构中对输出进行归一化(Post - normalization),并使用RMS-Norm替代传统LayerNorm。
- 激活函数优化:引入 SwiGLU 激活函数取代 ReLU 非线性函数,以提高性能。
- 注意力优化(LLaMA 2):引入分组查询注意力(Grouped Query Attention)
这些改进显著提升了模型的计算效率和长文本处理能力,今天我们来学习一下SwiGLU激活函数。
其余部件的学习链接持续更新中,欢迎关注:
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之旋转编码RoPE(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之均方根误差标准化RMSNorm(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之SwiGLU激活函数(含代码实现)
- 一杯咖啡的时间学习大模型(LLM):LLaMA解读之分组查询注意力(Grouped Query Attention)(含代码实现)
二、SwiGLU激活函数
2.1 改进动机
传统Transformer中广泛使用ReLU(Rectified Linear Unit)激活函数,其定义为:
ReLU ( x ) = max ( 0 , x ) \text{ReLU}(x) = \max(0, x) ReLU(x)=max(0,x)
虽然ReLU计算简单且能缓解梯度消失问题,但它存在以下局限性:
- 神经元死亡:负值输入梯度恒为0,导致部分神经元永久失活。
- 表达能力有限:输出仅为单侧线性函数,难以捕捉复杂非线性关系。
SwiGLU(Swish-Gated Linear Unit)通过结合门控机制和Swish激活函数,显著提升了模型的表达能力。其核心思想是:
- 动态门控:通过可学习的参数控制信息流动。
- 平滑非线性:Swish函数( Swish ( x ) = x ⋅ σ ( β x ) \text{Swish}(x) = x \cdot \sigma(\beta x) Swish(x)=x⋅σ(βx), β \beta β为可学习参数)相比ReLU更平滑,梯度更稳定。
2.2 数学原理
SwiGLU是GLU(Gated Linear Unit)的改进版本。给定输入张量 X X X,其计算过程如下:
SwiGLU ( X ) = ( Swish ( X W + b ) ) ⊗ ( X V + c ) \text{SwiGLU}(X) = (\text{Swish}(XW + b)) \otimes (XV + c) SwiGLU(X)=(Swish(XW+b))⊗(XV+c)
其中:
- W , V W, V W,V 为可学习的权重矩阵, b , c b, c b,c 为偏置项。
- ⊗ \otimes ⊗ 表示逐元素乘法(Hadamard积)。
- Swish \text{Swish} Swish 函数通常取 β = 1 \beta=1 β=1,即 Swish ( x ) = x ⋅ sigmoid ( x ) \text{Swish}(x) = x \cdot \text{sigmoid}(x) Swish(x)=x⋅sigmoid(x)。
与标准GLU(使用双曲正切函数)相比,SwiGLU通过Swish函数引入了更强的非线性,同时保留了门控机制对信息流的动态调节能力。
2.3 源码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
# 初始化两个线性变换层
self.w = nn.Linear(dim_in, dim_out, bias=False)
self.v = nn.Linear(dim_in, dim_out, bias=False)
self.bias = nn.Parameter(torch.zeros(dim_out)) # 可学习的偏置项
def forward(self, x):
# 计算Swish(Wx)和Vx,并进行逐元素相乘
swish = F.silu(self.w(x)) # F.silu即Swish函数
gate = self.v(x)
return swish * gate + self.bias # 添加偏置项
# 示例用法
if __name__ == "__main__":
batch_size = 2
seq_len = 128
dim_in = 512
dim_out = 1024
swiglu = SwiGLU(dim_in, dim_out)
x = torch.randn(batch_size, seq_len, dim_in)
output = swiglu(x)
print("输入形状:", x.shape) # [2, 128, 512]
print("输出形状:", output.shape) # [2, 128, 1024]