大模型学习 (Datawhale_Happy-LLM)笔记10: 动手实现一个 LLaMA2 大模型
动手实现一个 LLaMA2 大模型
Meta(原Facebook)于 2023 年 2 月发布第一款基于Transformer结构的大型语言模型LLaMA, 并于同年7月发布同系列模型LLaMA2。 我们现在就来尝试动手写一个 LLaMA2 模型。
1. 定义超参数 (需要手动设定而非通过训练数据自动学习的参数)
首先我们需要定义一些超参数,这些超参数包括模型的大小、层数、头数、词嵌入维度、隐藏层维度等等。这些超参数可以根据实际情况进行调整。这里我们自定义一个 ModelConfig 类,我们可以通过继承这个类来方便的使用 transformer 库中的一些功能,也方便在后续导出 Hugging Face 模型。
# 须要导入的库
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
Model Config 类
class ModelConfig(PretrainedConfig):
model_type = "Tiny-K"
def __init__(
self,
dim: int=768, # 模型维度
n_layers: int=12, # Transformer的层数
n_heads: int=16, # 注意力机制的头数
n_kv_heads: int=8, # 键值头的数量
vocab_size: int=6144, # 词汇表的大小
hidden_dim: int=None, # 隐藏层维度
multiple_of: int=64,
norm_eps: float=1e-5, # 归一化层的eps
max_seq_len: int=512, # 最大序列长度
dropout: float=0.0, # dropout 概率
flash_attn: bool=True, # 是否使用 flash attention
**kwargs,
):
self.dim = dim
self.n_layers = n_layers
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.multiple_of = multiple_of
self.norm_eps = norm_eps
self.max_seq_len = max_seq_len
self.dropout = dropout
self.flash_attn = flash_attn
super().__init__(**kwargs)
args = ModelConfig()
2. 构建 RMSNorm (Root Mean Square Norm, RMSNorm)
RMSnorm可以用如下数学公式表示:
RMSNorm(x)=x1n∑i+1nxi2+ϵ⋅γ\displaystyle RMSNorm(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i+1}^nx_{i}^2+\epsilon}}·\gammaRMSNorm(x)=n1∑i+1nxi2+ϵx⋅γ
其中:
- xi 是输入向量的第 i 个元素x_i\ 是输入向量的第\ i\ 个元素xi 是输入向量的第 i 个元素
- γ 是可学习的缩放参数(对应代码中的self.weight)\gamma\ 是可学习的缩放参数 (对应代码中的 self.weight)γ 是可学习的缩放参数(对应代码中的self.weight)
- n 是输入向量的维度向量
- ϵ 是一个小常数,用于数值稳定性(避免出现除以零的情况)\epsilon\ 是一个小常数,用于数值稳定性(避免出现除以零的情况)ϵ 是一个小常数,用于数值稳定性(避免出现除以零的情况)
# RMSNorm
class RMSNorm(nn.Module):
def __init__(self, dim:int, eps:float):
super().__init__()
# eps 是为了防止除以 0 的情况
self.eps = eps
# weight 是一个可学习的参数,全部初始化为 1
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算 RMSNorm 的核心部分
# x.pow(2).mean(-1, keepdim=True) 计算了输入 x 的平凡的均值
# torch.rsqrt 是平方根的倒数, 这样就得到了 RMSNorm 的分母部分,再加上 eps 防止分母为 0
# 最后乘以 x,得到 RMSNorm 的结果
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(

最低0.47元/天 解锁文章
1057

被折叠的 条评论
为什么被折叠?



