代码会涉及 KV Cache 和 RoPE旋转位置编码以及基于调整RoPE旋转角度的长度外推方法
KV Cache知识请看:https://zhuanlan.zhihu.com/p/684078417
RoPE旋转位置编码以及基于调整RoPE旋转角度的长度外推方法请看 https://zhuanlan.zhihu.com/p/684078417
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.configuration_utils import PretrainedConfig
from .activations import ACT2FN
from .configuration_llama import LlamaConfig
from .utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
replace_return_docstrings,
)
# 判断是否使用 flashattention
if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
def _get_unpad_data(padding_mask):
"""处理填充掩码数据,并返回一些处理后的信息"""
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) # 计算了每个样本序列中的非填充标记的数量
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() # 将填充掩码张量展平为一维,并找到其中所有非零元素的索引。torch.nonzero 函数返回所有非零元素的索引,然后使用 flatten() 方法将结果展平为一维张量。这些索引表示了每个样本序列中非填充标记的位置
max_seqlen_in_batch = seqlens_in_batch.max().item() # 这一行代码找到批次中最长的样本序列的长度。首先,使用 max() 方法找到 seqlens_in_batch 张量中的最大值,然后使用 item() 方法将结果转换为 Python 标量。
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) # 计算了累积序列长度,并进行了填充操作。首先,使用 torch.cumsum 函数计算了 seqlens_in_batch 张量在 dim=0 维度上的累积和,即得到了累积的序列长度。然后,使用 F.pad 函数在结果的左侧填充一个元素值为0的元素,以便与索引对齐。填充的宽度为 (1, 0)
return (
indices, # 包含了所有非填充标记的位置索引的一维张量
cu_seqlens, # 包含了累积序列长度的张量,经过了左侧填充
max_seqlen_in_batch, # 批次中最长的样本序列的长度
)
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""用于生成causal mask, 确保在自注意力机制中, 模型只能关注当前位置及之前的位置, 以避免信息泄露和未来信息访问。
得到的掩码张量,mask部分是当前数据类型的-inf,非mask部分是0,下三角元素都为零,对角线元素也为0
Make causal mask used for bi-directional self-attention.
参数: input_ids_shape是一个torch.Size对象,表示输入张量 input_ids 的形状,通常是(batch_size, sequence_length)的形式。
dtype是一个torch.dtype对象,表示返回的 causal mask 的数据类型。
device是一个torch.device对象,表示返回的 causal mask 所在的设备
past_key_values_length是一个整数,表示过去键值对的长度,用于处理有状态的自注意力。"""
"""
假设 input_ids_shape 为 (2, 4), dtype 为 torch.float32, device 为 cuda, past_key_values_length 为 2
mask:tensor([[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38]])
mask_cond: tensor([0, 1, 2, 3])
new_cond = (mask_cond + 1).view(mask.size(-1), 1): [[1],
[2],
[3],
[4]]
mask_cond < new_cond: [[True, False, False, False],
[True, True, False, False],
[True, True, True, False],
[True, True, True, True]]
mask after applying condition: tensor([[0., -3.4028e+38, -3.4028e+38, -3.4028e+38],
[0., 0., -3.4028e+38, -3.4028e+38],
[0., 0., 0., -3.4028e+38],
[0., 0., 0., 0.]])
如果存在过去键值对的长度,添加额外的掩码:
mask after adding past_key_values_length: tensor([[0., 0., 0., -3.4028e+38, -3.4028e+38, -3.4028e+38],
[0., 0., 0., 0., -3.4028e+38, -3.4028e+38],
[0., 0., 0., 0., 0., -3.4028e+38],
[0., 0., 0., 0., 0., 0.]])
final causal mask: tensor([[[[0., 0., 0., -3.4028e+38, -3.4028e+38, -3.4028e+38],
[0., 0., 0., 0., -3.4028e+38, -3.4028e+38],
[0., 0., 0., 0., 0., -3.4028e+38],
[0., 0., 0., 0., 0., 0.]]],
[[[0., 0., 0., -3.4028e+38, -3.4028e+38, -3.4028e+38],
[0., 0., 0., 0., -3.4028e+38, -3.4028e+38],
[0., 0., 0., 0., 0., -3.4028e+38],
[0., 0., 0., 0., 0., 0.]]]])
"""
# input_ids_shape是典型的输入格式.是做embedding layer输入之前的张量格式
# input_ids_shape就是batch["input_ids"].shape,这里batch就是train_dataloader或eval_dataloader的一个批次数据
bsz, tgt_len = input_ids_shape # 从input_ids_shape中获取batch size(bsz)和目标长度(tgt_len)。
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) # 创建一个形状为(tgt_len, tgt_len)的张量,其中所有元素初始化为dtype数据类型的最小值,表示初始的causal mask
# 下面两行代码,用于得到一个下三角为0的矩阵, 其实可以通过torch.tril或triu来实现
mask_cond = torch.arange(mask.size(-1), device=device) # 创建一个长度为tgt_len的张量,从0到tgt_len-1。将用于构造causal mask的下三角部分
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) # 将mask张量的下三角部分全部设置为0,实现causal mask,确保每个位置只能注意到之前和当前的位置
mask = mask.to(dtype) # 将mask张量的数据类型转换为dtype,以匹配输入参数的数据类型
# 在推理使用KV cache时就会使用KV cache,KV cache解释看: https://zhuanlan.zhihu.com/p/684078417
if past_key_values_length > 0: # 如果有过去键值对的长度,就在mask的最后一个维度上连接一个形状为(tgt_len, past_key_values_length)的全零张量。这个操作是为了考虑有状态的自注意力机制,在处理过去的键值对时需要增加一些额外的掩码
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
# mask[None, None, :, :]添加两个维度, 然后复制batch size份
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) # 最后,将mask张量的维度扩展为(bsz, 1, tgt_len, tgt_len + past_key_values_length),其中bsz是批次大小,tgt_len是序列长度,past_key_values_length是过去键值对的长度。这个扩展操作是为了适应多个样本的输入,每个样本都有自己独立的causal mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
在自注意力机制中,我们希望模型在计算注意力分数时只关注当前位置及之前的位置,而不考虑之后的位置,以避免信息泄露和未来信息的访问。
因此,我们需要将掩码扩展为一个因果掩码,在这个因果掩码中,每个位置只能关注到当前位置及之前的位置。 通过生成补码张量 inverted_mask,
我们可以方便地将原掩码中的0变为1,1变为0,并在后续操作中使用这个补码来生成最终的扩展掩码,以确保在注意力计算中只能关注到当前位置及之前的位置。
"""
# mask的shape是(batch_size, source sequence length)
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
# 将形状为 [bsz, seq_len] 的注意力掩码扩展为形状为 [bsz, 1, tgt_seq_len, src_seq_len] 的注意力掩码
# 在掩码张量的第二和第三维度上添加了两个新的维度,将其形状变为 [bsz, 1, 1, src_len]。然后,我们使用 expand 方法将掩码张量在第三维度(目标序列长度)上进行复制扩展,以匹配目标序列的长度。最后,我们将扩展后的张量转换为指定的数据类型 dtype
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
# 用1减去expanded_mask, 于是mask部分变成1.0, 而非mask部分则变为0,
inverted_mask = 1.0 - expanded_mask # 生成一个与扩展掩码相同形状的张量,其中的值是扩展掩码的补码,即原掩码中为0的位置变为1,为1的位置变为0
# 最终生成了一个形状为 [bsz, 1, tgt_seq_len, src_seq_len] 的扩展掩码张量,并确保了在原掩码为0的位置上填充了最小值,以便在注意力计算中正确屏蔽无效位置的信息
# 将inverted_mask中的1用最小值替换,也就是mask用最小值替换, 非Mask部分用0替换,然后结果可以和QK^T的结果直接相加,进行位置掩码操作
return inverted_mask.fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # 使用 masked_fill 方法根据补码张量 inverted_mask,将其转换为布尔类型张量,并在其为 True(即原掩码为0的位置)的位置上填充为指定数据类型 dtype 的最小值
class LlamaRNSNorm(nn.Module):
"""LlamaRNSNorm is equivalent(等价于) to T5LayerNorm"""
def __init__(self, hidden_size, eps=1e-6):
# RMSNorm 相当于 LayerNorm 去掉了μ这一项, 直接除以 σ
super().__init__()
# 权重,和hidden states相同尺寸.具体的计算过程中,进行相乘
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps # eps防止取倒数之后分母为0
def forward(self, hidden_states):
# 先保存输入hidden_state的数据类型
# 中间计算过程使用torch.float32数据类型
# 在计算最后会将其恢复为hidden_state的数据类型
input_dtype = hidden_states.dtype
# 将隐藏层张量转换为float32数据类型
hidden_states = hidden_states.to(torch.float32)
# 在最后一个维度计算均方值
variance = hidden_states.power(2).mean(-1, keepdim=True)
# rsqrt是开根号求倒数
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# weight 是末尾乘的可训练参数,即 g_i
return self.weight * hidden_states.to(input_dtype)
class LlamaRotaryEmbedding(nn.Module):
""" 使用绝对编码方式完成了相对位置编码.有利于提升模型的长度外推能力.
LlamaRotaryEmbedding的 key 和 query 最终的位置编码如下, 乘以位置矩阵,
当前embedding只是返回了 cos 和 sin矩阵,后面的实现在 apply_rotary_pos_emb
[q0 [cosmθ0 [-q1 [sinmθ0
q1 cosmθ0 q0 sinmθ0
q2 cosmθ1 -q3 sinmθ1
q3 cosmθ1 q2 sinmθ1
. x . + . x .
. . . .
. . . .
qd-2 cosmθ(d/2)-1 -qd-1 sinmθ(d/2)-1
qd-1] cosmθ(d/2)-1] qd-2] sinmθ(d/2)-1]
公式的 θi = 1/10000^(2i/dim), m = (0, 1, ..., seq_len-1)
文本外推
训练数据的长度较短,推理数据的长度较长.
这个短是多少,一般是8k或4k,看模型大小和显存
这个长是多少,一般是32k或100k,甚至是1000k(希望能够达到吧!)
Position Interpolation: 目标长度是原来的n倍,则旋转弧度减小至原来的1/n。
NTK-Aware Interpolation:增大RoPE的base,保留高频信息;高频分量旋转速度降幅低,低频分量旋转速度降幅高;在高频部分进行外推,低频部分进行内插。
NTK-by-parts Interpolation:不改变高频部分,仅缩小低频部分的旋转弧度。
Dynamic NTK Interpolation:推理长度小于等于训练长度时,不进行插值;推理长度大于训练长度时,每一步都通过NTK-Aware插值动态放大base。
YaRN: NTK-by-parts Interpolation与注意力分布修正策略的结合,通过温度系数修正注意力分布。
具体RoPE外推方法可以参考链接:https://zhuanlan.zhihu.com/p/670280576
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
# 初始化基本参数
# dim是head_dim的大小,也就是每个head的长度. Llama使用GQA方法,因此query的head数量是key和value的整数倍
self.dim = dim
# 模型支持最大长度,不是说不能运算更大长度的输入,而是这种输入超过一定长度,模型的性能,急剧下降
self.max_position_embeddings = max_position_embeddings
# 这里的base是\theta中的10000,在有些模型中这个值非常大,例如Qwen中base=1000000
# 很大的话,也是有利于长度外推的,这个在后面的dynamic ntk中可以看到
self.base = base
# 计算 θi = 1/10000^(2i/dim)
inf_freq = 1.0 / (self.base ** (torch.arrange(0, self.dim, 2).float().to(device) / self.dim))
# 注册逆频率为buffer,意味着它是一个持久的状态但不会被认为是模型的参数,不会被反向传播优化
self.register_buffer("inf_freq", inf_freq, persistent=False)
# 在初始化时,计算并缓存 cos 和 sin 函数值,以便在前向传播时快速使用
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq_device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""参数: seq_len表示缓存的序列长度, device表示缓存所在的设备, dtype表示缓存的数据类型。"""
self.max_seq_len_cached = seq_len
# 得到上面公式中的m,是一个长度序列,以1为步进.如果不以1为步进呢?
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
# 得到 mθ的乘积,然后再做cos和sin运算
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# 将计算得到的频率值复制一份并拼接,形成完整的嵌入矩阵
# 此处和原始公式不同,θ0 和 θ0 不再相邻,而是分在向量的前半部分和后半部分
# 论文是 cosmθ0, cosmθ0, cosmθ1, cosmθ1, ..., cosmθd/(2-1), cosmθd/(2-1)
# 这里是 cosmθ0, cosmθ1, ..., cosmθd/(2-1), cosmθ0, cosmθ1,..., cosmθd/(2-1)
emb = torch.cat((freqs, freqs), dim=-1)
# 构造cos和sin缓存。这些缓存用于在前向传播过程中计算旋转嵌入的cos和sin值
# 分别得到cos和sin计算张量,并将维度进行unsqueeze(),使得结果具有形状[1, 1, seq_len, dim],方便和q,k,v进行广播计算
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# 如果设置了seq_len,且seq_len比默认的max_seq_len_cached大,则重新计算cos和sin
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
# 根据当前序列长度返回对应的cos和sin值
return(
self.cos_cached[:, :, seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, seq_len, ...].to(dtype=x.dtype),
)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# 将线性缩放加入到旋转位置编码中,使得在Llama的旋转嵌入过程中,可以通过scaling_factor参数对输入序列的位置进行线性缩放,以适应不同长度的序列。
# 意思很明确,就是输入序列越长,base就需要越大,才能得到较好的性能. 底层原因是,base越大,插值后缩小旋转弧度,达到长度扩展的目标.
# RoPE旋转位置编码以及基于调整RoPE旋转角度的长度外推方法看 https://zhuanlan.zhihu.com/p/684078417
def __init__(self, dim, max_position_embedding=2048, base=10000, device=None, scaling_factor=1.0):
# 调用父类LlamaRotaryEmbedding的构造函数,以确保正确地初始化该扩展类
super().__init__(dim, max_position_embedding, base, device)
self.scaling_factor = scaling_factor # scaling_factor表示线性缩放因子(默认为1.0)。它用于对输入序列的位置进行线性缩放。
def _set_cos_sin_cache(self, seq_len, device, dtype):
# 线性插值比较简单,就是直接将原来的各整数之间插上带小数点的值
# 经过少量数据微调后,可以较好的扩展到较长的文本上
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor # 将序列t除以线性缩放因子scaling_factor,以实现线性缩放
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=1)
self.register_buffer("self.cosched", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("self.sinched", emb.sin()[None, None, :, :].to(dtype), persistent=False)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
# 引入了动态NTK(Neural Tangent Kernel)缩放。使得在Llama的旋转嵌入过程中,可以根据输入序列的长度动态调整旋转嵌入的频率
# 推理长度小于等于训练长度时,不进行插值;推理长度大于训练长度时,每一步都通过NTK-Aware插值动态放大base。
# RoPE旋转位置编码以及基于调整RoPE旋转角度的长度外推方法看 https://zhuanlan.zhihu.com/p/684078417
def __init__(self, dim, max_position_embedding=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__(dim, max_position_embedding, base, device)
self.scaling_factor = scaling_factor # scaling_factor表示动态NTK缩放因子(默认为1.0)。它用于动态调整旋转嵌入的频率。
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
# 如果提供的序列长度大于最大位置嵌入长度,则需要动态调整旋转嵌入的频率
if seq_len > self.max_position_embeddings:
# 根据动态NTK缩放因子和当前序列长度,计算新的旋转嵌入的基础值
# Dynamic NTK方法的关键计算公式,通过修改base值来改变每个位置的频率
base = self.base * (
(self.scaling_factor * seq_len / self.max_seq_len_cached) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1 / (base ** (torch.arange(0, self.dim, 2).float().to(dtype) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=1)
self.register_buffer("self.cosched", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("self.sinched", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def rotate_half(x):
"""
将输入张量的后一半维度进行旋转。它的作用是将输入张量x的维度分成两部分, 然后将后一半维度的内容
放到前一半维度的位置,同时将前一半维度的内容放到后一半维度的位置。这样就实现了旋转的效果。
原始矩阵 x_example [1.0, 2.0, 3.0, 4.0]
[5.0, 6.0, 7.0, 8.0]
应用rotate_half(x)后的矩阵 [-3.0, -4.0, 1.0, 2.0]
[-7.0, -8.0, 5.0, 6.0]
此处和原始论文推导中不同,正负号不是间隔的,而是分前半部分和后半部分。但对于结果没有影响
"""
# 三个点,称为省略号对象,用于表示多维张量中的所有前面的维度,而不需要显式指定它们
x1 = x[..., : x.shape[-1] // 2] # 获取最后一个维度的前一半
x2 = x[..., x.shape[-1] // 2 :] # 获取最后一个维度的后一半
return torch.cat((-x2, x1), dim=-1) # 在最后一个维度上拼接-x2和x1
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# cos和sin张量的前两个维度总是1(这通常是因为它们被设计为可广播到不同批次和头部的维度上),去除多余的维度,
# 可以使用.squeeze(1).squeeze(0)来移除这两个维度,得到形状为[seq_len, dim]的张量
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
# cos和sin张量根据position_ids(表示序列中每个元素的位置ID)来选择对应的位置编码,使用.unsqueeze(1)在第二个维度
#(索引为1的位置)上增加一个维度,使其形状变为[bs, 1, seq_len, dim],以便进行广播操作,以适配查询和键向量的维度。
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
# 已经准备好了旋转位置编码计算公式需要的元素,下面是使用公式分别对k和q进行旋转位置编码
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaMLP(nn.Module):
""" Llama模型中多了个线性层来对输出做门操作
1. 存在 gate_proj 和 up_proj 两个映射:这两个映射的存在暗示了可能有一种机制在控制信息流。具体到代码中,gate_proj
和 up_proj 的输出被用于生成中间状态,这种处理方式与常规 MLP 中简单的层序列不同。
2. 使用激活函数和乘法操作: gate_proj 的输出经过激活函数处理后, 与 up_proj 的输出相乘。这一步可以被视为一种门控操作,
其中 gate_proj 的输出经过激活函数后得到的结果控制着 up_proj 输出的信息量。
3. 在有分片的情况下的操作:当配置中的 pretraining_tp 参数大于1时,通过对权重进行分片和重组,实现了一种复杂的信息处理流程。
虽然这种分片操作本身不直接等同于门控机制,但它表明了在处理信息时采用了一种比标准 MLP 更复杂的策略。
"""
def __init__(self, config):
super().__init__()
self.config = config
# 隐藏层大小
self.hidden_size = config.hidden_size
# 中间层的大小, 一般是hidden_size的偶数倍
self.intermediate_size = config.internidiate_size
# 门线性层
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
# 维度上升层
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
# 维度下降层
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
# 激活函数,在门操作中使用
self.act_fn = ACT2FN[config.hidden_act] # 选择配置中指定的激活函数,将其保存在self.act_fn中。ACT2FN是一个包含不同激活函数的字典。
def forward(self, x):
# 检查是否使用模型分解或并行技术,则进行张量并行计算
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp # 计算每个预训练分片的大小
# gate_proj_slices是个列表,每个元素尺寸为[slice, self.hidden_size]
# self.gate_proj.weight的shape是[self.intermediate_size, self.hidden_size]
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) # 将gate_proj的权重在第0维(输出特征维)上分片
# up_proj_slices,每个元素尺寸为[slice, self.hidden_size]
# self.up_proj.weight的shape是[self.intermediate_size, self.hidden_size]
up_proj_slices = self.up_proj.weight.split(slice, dim=0) # 将up_proj的权重在第0维(输出特征维)上分片
# down_proj_slices是个列表,每个元素尺寸为[self.hidden_size, slice]
# self.down_proj.weight的shape是[self.hidden_size, self.intermediate_size]
down_proj_slices = self.down_proj.weight.split(slice, dim=1) # 将down_proj的权重在第1维(输入特征维)上分片
# 从上面代码看,split操作是对self.intermediate_size进行的
# 对每个分块进行线性操作,然后进行拼接
# the shape of x is: [batch_size, seq_len, hidden_size]
# F.linear(x, gate_proj_slices[i]) -> [batch_size, seq_len, slice]
# dim = -1 cat之后shape变为[batch_size, seq_len, intermediate_size]
# 将输入张量x分别与gate_proj_slices中的每个块进行线性投影,并将投影结果拼接起来,得到gate_proj。这样实现了分片投影。
# F.linear函数的工作原理是将输入数据和权重矩阵进行矩阵乘法,然后(如果提供了)加上偏置。这个过程可以用下面的公式表示:
# output =input × weight^T + bias
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
# 对中间变量进行split,dim=-1也就是dim=2
# intermediate_states的shape是:
# [batch_size, seq_len, intermediate_size] -> pretraining_tp个
# [batch_size, seq_len, slice]
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor :
""" 这个函数使用,主要是因为模型使用Group Query Attention,也就是query的head是key和value的head数量的整数倍
将输入的张量hidden_states在指定的维度上进行重复,重复的次数由参数n_rep指定。重复后的张量形状将在指定维度上增加元素,并且每个元素的重复次数为n_rep
隐藏状态从(batch,num_key_value_heads,seqlen,head_dim)到(batch,num_attention_heads,seqlen,head_dim)
参数: n_rep参数表示要重复的次数,n_rep为1,那么直接返回输入的张量x。否则,使用torch.repeat_interleave函数将x在第2个维度(索引从0开始)上重复n_rep次。
"""
# hidden_states 隐藏层的信息
batch, num_key_value_heads, slen, head_dim = hidden_states
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper
以下代码,可以通过设置参数来实现multi-head attention、multi-query attention、group query attention"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size # 隐藏层大小
self.num_heads = config.num_attention_heads # 注意力头数量
self.head_dim = self.hidden_size // self.num_heads # 每个头的维度
self.num_key_values_heads = config.num_key_value_heads # key 和 value头的数量,group query attention 和 multi-query attention 使用
self.num_key_values_groups = self.num_key_values_heads // self.num_key_values_heads # GQA的分组数量, GQA介于MHA和MQA之间
self.max_position_embeddings = config.max_position_embeddings # 最大位置长度
self.rope_theta = config.rope_theta # 基础llama是10000,很多文章中用base来表示
# hidden_size一定可以被head_dim整除,否则就会报错
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# self.num_heads 和 self.num_key_values_heads 是因为使用Group Query Attention时, query的head是key和value的head数量的整数倍
# 通过线性投影将隐藏状态投影到查询、键和值的维度
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) # q变量的映射线性层
self.k_proj = nn.Linear(self.hidden_size, self.num_key_values_heads * self.head_dim, bias=config.attention_bias) # k变量的映射线性层
self.v_proj = nn.Linear(self.hidden_size, self.num_key_values_heads * self.head_dim, bias=config.attention_bias) # v变量的映射线性层
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) # o输出变量的映射线性层
# 初始化 RoPE位置编码张量
self._init_rope()
def _init_rope(self):
if self.rope_theta is None:
# 不做rope外推,基础RoPE
self.rotary_emb = LlamaRotaryEmbedding(
dim=self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta)
else:
# 如果做rope的外推,首先提取外推需要的参数
# 外推类型,使用内插法还是动态NTK,其实目前YaRN的效果比较,同时考虑NTK和attention score的尺度变化
scaling_type = self.config.rope_scaling["type"]
# 内插值,外推长度是scaling_factor * max_position_embeddings,一般为2或4,效果尚可
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
dim=self.head_dim,
max_position_embedding=self.max_position_embeddings,
base=self.rope_theta,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
dim=self.head_dim,
max_position_embedding=self.max_position_embeddings,
base=self.rope_theta,
scaling_factor=scaling_factor,
)
# 其实新增YaRN在这里也是可以的
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
# 将shape进行变换,最终变成 [batch_size, num_heads, seq_len, head_dim]
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor, # [batch_size, seq_len, hidden_size]
attention_mask: Optional[torch.Tensor], # [batch_size, seq_len]
position_ids: Optional[torch.Tensor], # [batch_size, seq_len]
past_key_value: Optional[Tuple[torch.Tensor]] = None, # [batch_size, past_seq_len, hidden_size]
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
# 对key和value进行分片, tensor parallel
key_value_slicing = (self.num_key_values_heads * self.head_dim) // self.config.pretraining_tp
# 与 MLP 中的 slice 一样操作
# pretraining_tp个[key_value_slicing, hidden_size]
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
# hidden_states的shape是:[batch_size, seq_len, hidden_size] ->
# [batch_size, seq_len, key_value_slicing] ->
# [batch_size, seq_len, self.num_heads * self.head_dim]
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
# 通过线性投影将隐藏状态投影到查询、键和值的维度
# [batch_size, seq_len, self.num_heads * self.head_dim]
query_states = self.q_proj(self.hidden_states)
# [batch_size, seq_len, self.num_key_value_heads * self.head_dim]
key_states = self.q_proj(self.hidden_states)
# [batch_size, seq_len, self.num_key_value_heads * self.head_dim]
value_states = self.q_proj(self.hidden_states)
# 将num_heads和q_len互换维度, 把要计算的数值移动到矩阵最后2维度, 后面attention计算方便
# [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# 将num_key_value_heads和q_len互换维度, 把要计算的数值移动到矩阵最后2维度, 后面attention计算方便
# [batch_size, seq_len, num_key_values_heads, head_dim] -> [batch_size, num_key_values_heads, seq_len, head_dim]
key_states = key_states.view(bsz, q_len, self.num_key_values_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_values_heads, self.head_dim).transpose(1, 2)
# 记录 seq_len, 会有变动
kv_seq_len = key_states.shape[-2]
# 记录信息,用于KV Cache, KV cache解释看: https://zhuanlan.zhihu.com/p/684078417
# 如果有past key value,则添加到前面, kv_seq_len也有所变化
if past_key_value is not None:
# 如果有past的长度,比如kv cache,就需要做这个操作,那么就要修改长度
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
# key_states的shape是[batch_size, num_heads, seq_len, head_dim] ->
# [batch_size, num_heads, past_seq_len + seq_len, head_dim]
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
# 继续缓存张量
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
# 因为是Group Query Attention,所以复制self.num_key_value_groups份
# 目的是让query_states, key_states和value_states的heads相同
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# transformer论文中的attention操作
# 先做QK^T / sqrt(d),得到softmax操作之前的score
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# 维度尺寸要求,防止计算出错
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
# 如果有attention_mask,则在softmax之前做加法,掩码部分为-inf,未被掩码部分为0
# 最开始的两个掩码函数就是完成这个操作的
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
# 这里有个细节,在做softmax时,使用float32数据格式,计算结束后转换为前面的数据格式
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# 最后和输出张量相乘得到输出注意力
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
# 下面对输出进行形状变换,使其能够符合后面 MLP 层计算的输入形状, 其实就是计算attention的时候,把要相乘的矩阵移到最后两个维度,这里移动回来
# [batch_size, num_heads, seq_len, head_dim] -> [batch_size, seq_len, num_heads, head_dim]
attn_output = attn_output.transpose(1, 2).contiguous()
# 把多头注意力融合
# [batch_size, seq_len, num_heads, head_dim] -> [batch_size, seq_len, hidden_size]
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
# 这里再做一次张量并行
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# 上面的注意力计算,现在都会使用flash attention进行替换,在GPU上提高注意力计算速度
# 详细的替换方法可以看Qwen和LongLora的源码
class LlamaFlashAttention2(LlamaAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dime x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# 记录信息, 用于KV Cache, KV cache解释看: https://zhuanlan.zhihu.com/p/684078417
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# 记录信息, 用于KV Cache, KV cache解释看: https://zhuanlan.zhihu.com/p/684078417
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
)
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if padding_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size # 隐藏层大小
self.self_attn = ( # llama注意力层
LlamaAttention(config=config)
if not getattr(config, "_flash_attn_2_enabled", False)
else LlamaFlashAttention2(config)
)
self.mlp = LlamaMLP(config) # llama mlp层
self.input_layernorm = LlamaRNSNorm(config.hidden_size, eps=config.rms_norm_eps) # llama的输入归一化层
self.post_attention_layernorm = LlamaRNSNorm(config.hidden_size, eps=config.rms_norm_eps) # 注意力之后的归一化层
def forward(
self,
hidden_states: torch.Tensor, # [batch_size, seq_len, hidden_size]
attention_mask: Optional[torch.Tensor] = None, # [batch_size, seq_len]
position_ids: Optional[torch.LongTensor] = None, # [batch_size, seq_len]
past_key_value: Optional[Tuple[torch.Tensor]] = None, # [batch_size, past_seq_len]
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
# 保存残差连接输入
residual = hidden_states
# pre-norm
hidden_states = self.input_layernorm(hidden_states)
# Self Attention,自注意力运算。如果在GPU中,都使用flash attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
# 残差计算
hidden_states = residual + hidden_states
# fully connected
# 保存残差连接输入
residual = hidden_states
# post-norm,在进行MLP之前进行归一化计算
hidden_states = self.post_attention_layernorm(hidden_states)
# 带门控机制的全连接层
hidden_states = self.mlp(hidden_states)
# 残差计算
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
# 推理时使用kv cache
outputs += (present_key_value,)
return outputs
from transformers import PreTrainedModel
from transformers.utils import add_start_docstrings
LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
# 继承自transformers的PreTrainedModel类,可以方便的使用许多已有函数和类
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
# 初始化模型参数, Linear layer 和 embedding layer
std = self.config.initializer_range
if isinstance(module, nn.Linear):
# 对权重正态分布归一化
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
# 对偏置零初始化
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
# 对权重正态分布归一化
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpoint(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
# 基础Llama模型
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
from transformers.utils import add_start_docstrings_to_model_forward
from typing import List, Union
from transformers.modeling_outputs import BaseModelOutputWithPast
@ add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
# 继承 LlamaPreTrainedModel
# LlamaPreTrainedModel三个作用: 初始化,设置梯度检查,继承PreTrainedModel类的API
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
# 设置 pad_token_id
self.paddintg_idx = config.pad_token_id
# 设置字典大小,直接决定输入embedding_layer和输出lm_head的维度
self.vocab_size = config.vocab_size
# 设置embedding层,Llama在嵌入层之后不会添加位置编码
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRNSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
# 返回嵌入层对象,可以通过该对象获取嵌入层的权重,其实就是一张表
return self.embed_tokens
def set_input_embeddings(self, value):
# 设置嵌入层的权重
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
""" 此函数用于准备解码器的注意力掩码,支持因果掩码和扩展的注意力掩码的结合使用。"""
# create causal mask [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
combined_attention_mask = None
# 因果掩码是一个上三角矩阵,用于确保解码器在生成每个词时只能依赖于之前的词。
# 这个掩码的形状被调整为[批量大小, 1, 目标序列长度, 源序列长度],以适应解码器的要求。
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length
)
# 注意力掩码(Attention Mask):用于屏蔽(如填充词元)不应该被模型考虑的输入部分。
# 主要是在attention mask中softmax计算去除padding的影响
# 这对于处理不同长度的输入序列特别重要,以避免在计算注意力时考虑这些无关的词元。
if attention_mask is not None:
# 如果提供了额外的注意力掩码,将此掩码扩展到与因果掩码相同的形状,并将其应用于输入。
# 扩展后的注意力掩码用于指示解码器应该注意的输入序列中的特定部分,例如避免关注填充的位置。
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
# 如果已经有一个因果掩码,将扩展的注意力掩码与因果掩码相加,以结合两者的效果。
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None, # [batch_size, seq_len]
attention_mask: Optional[torch.Tensor] = None, # [batch_size, seq_len]
position_ids: Optional[torch.LongTensor] = None, # [batch_size, seq_len]
past_key_values: Optional[List[torch.FloatTensor]] = None, # [batch_size, past_seq_len]
inputs_embeds: Optional[torch.FloatTensor] = None, # [batch_size, seq_len, embedding_dim]
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# 是否输出注意力
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 是否输出隐藏层
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 是否使用cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
# 是否设置输出格式为字典
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 从解码器的输入中检索`input_ids`和`inputs_embeds`。
# 解码器的输入可以是`input_ids`(输入序列的整数表示)或`inputs_embeds`(输入序列的嵌入表示)。
# 如果同时指定了`input_ids`和`inputs_embeds`,则抛出一个错误,因为这两种输入形式是互斥的。
# 不能同时指定解码器的输入ID(`decoder_input_ids`)和解码器的输入嵌入(`decoder_inputs_embeds`)。
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
# 如果指定了`input_ids`,则通过其形状来确定批量大小和序列长度。
# `input_ids`的形状应为[批量大小, 序列长度]。
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
# 如果指定了`inputs_embeds`,则通过其形状来确定批量大小和序列长度。
# `inputs_embeds`的形状应为[批量大小, 序列长度, 嵌入维度]。
# 在这种情况下,我们只关心前两个维度(批量大小和序列长度),嵌入维度在此处不是必需的。
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
# 如果既没有指定`input_ids`也没有指定`inputs_embeds`,则抛出一个错误。
# 必须指定解码器的输入ID(`decoder_input_ids`)或解码器的输入嵌入(`decoder_inputs_embeds`)中的至少一种。
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# 初始化序列长度和过去键值对的长度
seq_length_with_past = seq_length
past_key_values_length = 0
# 记录信息, 用于KV Cache, KV cache解释看: https://zhuanlan.zhihu.com/p/684078417
# 如果提供了过去键值对(用于缓存以加速生成), 更新序列长度和过去键值对的长度
if past_key_values is not None:
# 获取第一个键值对的长度作为过去键值对的长度
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
# 如果没有提供位置ID, 则自动生成, 考虑过去键值对的长度
if position_ids is None:
# 确定设备位置,优先使用`input_ids`的设备位置
device = input_ids.device if input_ids is not None else inputs_embeds
# 生成位置ID, 从过去键值对的长度到序列总长度
position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
# 重塑位置ID以适应期望的格式, 增加第0维
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
# 确保提供的位置ID符合期望的格式和类型, [batch_size, seq_length]
position_ids = position_ids.view(-1, seq_length).long()
# 如果没有提供输入嵌入,则使用模型的嵌入层从`input_ids`生成
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# 如果没有提供注意力掩码,则创建一个默认的全1掩码,表示所有位置都可见
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
padding_mask = None
else:
# 如果注意力掩码包含0,它将用作填充掩码,否则不使用填充掩码
padding_mask = attention_mask if 0 in attention_mask else None
# 准备解码器的注意力掩码,考虑因果掩码和可能的扩展
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
# 初始隐状态就是嵌入层的输出, [batch_size, seq_len, embedding_dim]
hidden_states = inputs_embeds
# 如果启用了梯度检查点且模型处于训练模式,对`use_cache`进行调整
# 梯度检查和使用缓存是冲突,只能二选一
# 梯度检查节省显存,速度慢
# 使用缓存速度块,但是显存消耗大
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# 初始化存储隐藏状态, 自注意力权重和解码缓存的变量
all_hidden_states = () if output_hidden_states else None # 输出隐状态
all_self_attns = () if output_attentions else None # 输出注意层结果
next_decoder_cache = () if use_cache else None # kv cache
# 遍历每一层解码器
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
# 获取当前层的过去键值对, 如果做kv cache, 对每一层都要做kv cahce
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
# 如果启用梯度检查点, 定义自定义前向传播函数
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
return custom_forward
# 使用梯度检查点执行自定义前向传播
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
)
else:
# 正常执行解码器层的前向传播
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
# 更新隐藏层状态
hidden_states = layer_outputs[0] # layer_outputs = (output, self_attn_weights, present_key_value)
# 如果使用缓存,更新解码器缓存
# layer_outputs = (output, self_attn_weights, present_key_value), 实际就是保存 present_key_value
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
# 如果输出注意力权重,更新自注意力权重集合, self_attn_weights
if output_attentions:
all_self_attns += (layer_outputs[1],)
# 应用最后一层的规范化, 对输出再做归一化
hidden_states = self.norm(hidden_states)
# 如果输出隐藏状态,添加最后一层的隐藏状态到集合中
if output_hidden_states:
all_hidden_states += (hidden_states,)
# 根据是否使用缓存,准备最终的缓存输出
next_cache = next_decoder_cache if use_cache else None
# 根据是否返回字典格式,组织并返回模型的输出
if not return_dict:
# 按照元祖返回输出
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# 按照BaseModelOutputWithPast对象格式输出
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
from transformers.utils import replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
# Llama模型的因果语言模型, "因果"指的是模型生成文本的方式,即每次生成一个新的词元(例如单词或字符)
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm.head.weight"]
def __init__(self, config):
super().__init__(config)
# 初始化 Llama 模型
self.model = LlamaModel(config)
# 保存词汇表大小,用于最后的预测层
self.vocab_size = config.vocab_size
# 初始化语言模型头,将隐藏状态转换为对词汇表的预测
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# 调用post_init来执行权重初始化等后处理操作
self.post_init()
# 返回模型的输入嵌入层
def get_input_embeddings(self):
# 和基础模型的相同
return self.model.embed_tokens
# 设置模型的输入嵌入层
def set_input_embeddings(self, value):
self.model.embed_tokens = value
# 返回模型的输出嵌入层(即lm_head)
def get_output_embeddings(self):
return self.lm_head
# 设置模型的输出嵌入层
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# 设置解码器
def set_decoder(self, decoder):
self.model = decoder
# 获取解码器
def get_decoder(self):
# 返回解码器,也就是基础模型
return self.model
# 定义模型的前向传播方法
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None, # [batch_size, seq_len]
attention_mask: Optional[torch.Tensor] = None, # [batch_size, seq_len]
position_ids: Optional[torch.LongTensor] = None, # [batch_size, seq_len]
past_key_values: Optional[List[torch.FloatTensor]] = None, # [batch_size, past_seq_len]
inputs_embeds: Optional[torch.FloatTensor] = None, # [batch_size, seq_len, embedding_dim]
labels: Optional[torch.LongTensor] = None, # [batch_size, seq_len], 标签的长度要和input_ids一样
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
# 根据配置确定是否输出注意力权重、隐藏状态,以及是否以字典形式返回结果
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 调用Llama模型的前向传播,获取输出
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# 提取隐藏状态,作为语言模型头的输入
hidden_states = outputs[0]
# 如果配置了模型并行,对语言模型头进行分片处理,否则直接计算
if self.config.pretraining_tp > 1:
# 分片语言模型头权重,并对每个分片进行线性变换
# 这里是做列切割,再合并。如果是行切割,则是相加
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
# 将分片结果拼接回完整的逻辑回归输出
logits = torch.cat(logits, dim=-1)
else:
# 直接使用语言模型头计算逻辑回归
logits = self.lm_head(hidden_states)
# 转换为float32格式
logits = logits.float()
# 如果提供了标签,则计算损失
"""
当我们训练一个自回归语言模型时,我们的目标是让模型学会根据一系列给定的词(上文)预测下一个词。这种训练方式
允许模型学习语言的内在结构和词之间的关系。为了实现这一点,我们需要一种特定的数据准备方法,即错位操作,来确
保模型的预测目标(下一个词)与其输入(上文)正确对齐。
具体到操作上,假设句子中有N个词,模型为每个词生成了一个逻辑回归输出(预测分数),形成了一个[N, vocab_size]
的矩阵 (其中N是句子长度,vocab_size是词汇表大小)。为了与这些预测相匹配,我们需要N-1个目标词(因为最后一个
词没有“下一个词”)。
词汇表: ["hello", "world", "how", "are", "you"]
考虑一个简单的句子:“hello how are you”。
在自回归模型训练过程中,我们希望模型能够:
根据“hello”预测“how”
根据“hello how”预测“are”
根据“hello how are”预测“you”
换句话说,对于序列中的每个词,我们希望模型使用该词之前的所有词作为上下文来预测该词之后的词。因此,模型的输出
(逻辑回归分数或logits)需要与序列中下一个词的标签进行比较,这就是为什么我们需要错位操作。
假设 一开始的logits如下
[
[0.1, -0.2, 0.8, -0.5, 0.4], # "hello"后的预测
[0.3, 0.2, -0.1, 0.9, -0.4], # "how"后的预测
[-0.2, 0.5, 0.2, -0.1, 0.6], # "are"后的预测
[0.1, 0.3, -0.5, 0.7, -0.6], # "you"后的预测
]
labels: [0, 2, 3, 4]
错位操作的执行方式:
错位逻辑回归输出(logits):我们去除每个序列的最后一个输出,因为序列的最后一个词之后没有下一个词来预测。
这样,对于序列中的每个词(除了最后一个),logits都代表了模型对下一个词的预测。
处理后的logits如下, 删除最后一行
[
[0.1, -0.2, 0.8, -0.5, 0.4], # "hello"后的预测
[0.3, 0.2, -0.1, 0.9, -0.4], # "how"后的预测
[-0.2, 0.5, 0.2, -0.1, 0.6], # "are"后的预测
]
每一行的概率最大是: 2, 3, 4
错位标签(labels):我们去除每个序列的第一个标签,因为序列的第一个词之前没有上文。
然后,每个位置的标签代表了模型在该位置应该预测的下一个词。
处理后的labels: [2, 3, 4]
错位操作的目的:
错位操作确保了模型的每个预测(除了序列的最后一个词)都有一个对应的真实标签来进行比较。这种对齐方式使得我们可以
计算损失函数(如交叉熵损失),这是训练模型的关键。通过最小化预测和真实下一个词之间的差异,模型学习如何根据给定的上文生成文本。
这种方法训练出的模型能够基于先前的词(上文)生成连贯的文本序列,这是许多自然语言处理任务(如文本生成、机器翻译等)的基础。
"""
loss = None
if labels is not None:
# 首先,我们需要对模型的输出逻辑回归(logits)进行错位处理。在因果语言模型中,
# 我们希望模型使用序列中的前n-1个词元来预测第n个词元。因此,我们将逻辑回归输出向左移动一位,
# 这样,每个位置上的逻辑回归值就对应于使用该位置及其之前的词元作为上下文预测下一个词元的分布。
# 注意:由于序列的最后一个位置没有下一个词元,所以我们移除了逻辑回归输出的最后一个位置。
shift_logits = logits[..., :-1, :].contiguous()
# 接着,对于标签(labels),我们需要进行相应的错位处理,以使其与错位后的逻辑回归输出对齐。
# 我们去掉了标签序列的第一个位置,并保留剩下的部分,因为第一个位置没有前置词元可以用于预测。
# 这样,错位后的每个标签位置都对应于其在逻辑回归输出中的预测。
shift_labels = labels[..., 1:].contiguous()
# 使用交叉熵损失函数(CrossEntropyLoss)计算预测和真实标签之间的损失。
# 这是多分类任务中常用的损失函数,适用于语言模型的训练。
loss_fct = CrossEntropyLoss()
# 由于交叉熵损失函数期望的输入形式是二维的(即,每一行是一个样本的预测分布,每一列对应一个类别),
# 我们需要将错位后的逻辑回归输出和标签展平成二维形式。这里,逻辑回归输出被展平为一个长向量,
# 其中每个样本的预测分布按顺序排列。标签也被展平为一个一维向量,其中包含对应于每个样本预测的真实类别索引。
shift_labels = shift_labels.view(-1)
# 最后,计算展平后的逻辑回归输出和标签之间的交叉熵损失。
# 特别注意,标签中的-100值被用作一个特殊值,表示该位置的损失不应该被计算(即,忽略该位置)。
# 这在处理填充(padding)或者其他不应计入损失计算的特殊情况时非常有用。
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
# 如果不以字典形式返回,将各个输出组合成元组
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
# 以字典形式返回所有输出,包括损失、逻辑回归结果和可选的隐藏状态及注意力权重
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
"""为文本生成过程准备输入数据,确保每次生成步骤都使用正确的输入格式和必要的信息"""
# 如果存在过去的键值对,说明我们正在进行连续生成(例如,在文本生成任务中逐步生成文本)。
# 此时,只需要最近一次的输入ID,因为模型的下一个输出只依赖于最近的输入。
if past_key_values:
input_ids = input_ids[:, -1:]
# 如果未直接提供位置ID,并且存在注意力掩码,则基于注意力掩码动态生成位置ID。
# 这允许模型跟踪每个生成步骤的位置信息,即使在批量生成中处理不同长度的序列时也能正确处理。
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# 基于注意力掩码动态生成位置ID
"""attention_mask用于指示哪些位置是有效的输入,哪些位置是填充(padding)位置。在文本生成任务中,有效的输入位置通常设为1,
而填充位置设为0。.cumsum(-1)计算沿最后一个维度(即每个序列的长度维度)的累积和。这样做的目的是为每个位置生成一个累积计数
,基于其之前所有有效输入的数量。例如,对于一个序列的attention_mask [1, 1, 1, 0, 0](表示有三个有效输入,后面跟着两个
填充),最后,从累积和的结果中减去1是为了将位置索引从1开始的计数转换为从0开始,因为在大多数深度学习框架中,索引是从0开始的。
继续上面的例子,减去1后得到[0, 1, 2, 2, 2],这样就为每个有效输入生成了正确的位置索引。"""
position_ids = attention_mask.long().cumsum(-1) - 1
# 将注意力掩码为0的位置的位置ID设置为1,以避免位置ID为负数。
position_ids.masked_fill_(attention_mask == 0, 1)
# 如果存在过去的键值对,我们只关心最近一步的位置ID。
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# 如果提供了输入嵌入,并且没有过去的键值对,意味着我们处于生成的第一步,
# 此时应使用输入嵌入作为模型的输入。
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# 否则,使用输入ID作为模型的输入。
model_inputs = {"input_ids": input_ids}
# 更新模型输入字典,添加位置ID、过去的键值对、是否使用缓存和注意力掩码等参数。
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache", False),
"attention_mask": attention_mask,
}
)
# 返回准备好的模型输入字典,供生成过程使用。
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
"""用于在文本生成过程中,特别是使用带有缓存机制的模型(如Transformer模型)进行束搜索(beam search)时,
根据beam_idx重新排序模型的过去键值对缓存。下面是对这个方法的详细注释:
"""
# 初始化一个空的元组,用于存储重新排序后的过去键值对
reordered_past = ()
# 遍历传入的past_key_values元组,每个元素代表模型的一层中的过去键值对
for layer_past in past_key_values:
# 对于每一层的过去键值对,重新排序每个状态张量。
# 这里使用了列表推导式和tuple函数,对每一层的每个状态张量(past_state)进行操作。
# 使用index_select方法按照beam_idx指定的顺序在第0维(通常是批次或束索引)上进行选择和重排序。
# 这是为了确保在进行束搜索时,每个束的缓存状态与当前的候选序列保持一致。
# beam_idx.to(past_state.device)确保索引张量在与状态张量相同的设备上。
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
# 返回重新排序后的过去键值对,以供模型在后续的生成步骤中使用。
return reordered_past
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
@add_start_docstrings(
"""
The LLaMa Model transformer with a sequence classification head on top (linear layer).
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
LLAMA_START_DOCSTRING,
)
# LlamaForSequenceClassification类定义:这个类是用于序列分类任务的LLaMa模型变体。
class LlamaForSequenceClassification(LlamaPreTrainedModel):
# 构造函数:初始化模型
def __init__(self, config):
super().__init__(config) # 调用父类构造函数
self.num_labels = config.num_labels # 从配置中获取类别数量
self.model = LlamaModel(config) # 初始化LLaMa模型
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # 定义一个线性层作为分类头
self.post_init() # 调用post_init方法进行权重初始化和最终处理
# 获取输入嵌入
def get_input_embeddings(self):
return self.model.embed_tokens
# 设置输入嵌入
def set_input_embeddings(self, value):
self.model.embed_tokens = value
# 前向传播方法:执行模型的前向计算
def forward(
self,
input_ids: torch.LongTensor = None, # 输入的token ID
attention_mask: Optional[torch.Tensor] = None, # 注意力掩码
position_ids: Optional[torch.LongTensor] = None, # 位置ID
past_key_values: Optional[List[torch.FloatTensor]] = None, # 过去的键值对,用于递增式生成
inputs_embeds: Optional[torch.FloatTensor] = None, # 输入嵌入,直接提供,而非通过token ID计算得到
labels: Optional[torch.LongTensor] = None, # 用于计算损失的标签
use_cache: Optional[bool] = None, # 是否使用缓存,提高生成效率
output_attentions: Optional[bool] = None, # 是否输出注意力权重
output_hidden_states: Optional[bool] = None, # 是否输出隐藏状态
return_dict: Optional[bool] = None, # 是否以字典形式返回输出
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 调用LLaMa模型的前向传播,获取输出
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0] # 获取最后一层的隐藏状态
logits = self.score(hidden_states) # 通过分类头计算logits
# 根据输入确定批量大小
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
# 检查配置中是否定义了pad_token_id,以及批量大小是否大于1
if self.config.pad_token_id is None and batch_size != 1:
# 如果没有定义pad_token_id且批量大小不为1,则无法正确处理批量数据,因此抛出错误
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
# 初始化序列长度
if self.config.pad_token_id is None:
# 如果没有定义pad_token_id,则默认序列长度为-1(即使用序列的最后一个token进行分类)
sequence_lengths = -1
else:
# 如果定义了pad_token_id,则计算每个序列的实际长度(最后一个非填充token的位置)
if input_ids is not None:
# 使用argmax找到第一个pad_token_id出现的位置,减去1得到最后一个非填充token的位置
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(logits.device)
else:
# 如果没有提供input_ids,序列长度设置为-1
sequence_lengths = -1
# 根据计算出的序列长度,从logits中提取用于分类的logits
# torch.arange(batch_size, device=logits.device)生成一个连续的序列,用于索引batch中的每个样本
# sequence_lengths指示每个样本中用于分类的token的位置
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
# 初始化损失为None,仅当提供了标签时计算损失
loss = None
if labels is not None:
# 将标签转移到logits所在的设备上,以确保计算损失时不会因设备不匹配而出错
labels = labels.to(logits.device)
# 根据模型配置和标签的类型,确定问题类型(回归、单标签分类或多标签分类)
if self.config.problem_type is None:
if self.num_labels == 1:
# 如果num_labels为1,视为回归问题
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
# 如果num_labels大于1且标签类型为long或int,视为单标签分类问题
self.config.problem_type = "single_label_classification"
else:
# 否则视为多标签分类问题
self.config.problem_type = "multi_label_classification"
# 根据确定的问题类型计算损失
if self.config.problem_type == "regression":
# 对于回归问题,使用均方误差损失
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
elif self.config.problem_type == "single_label_classification":
# 对于单标签分类问题,使用交叉熵损失
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
# 对于多标签分类问题,使用带logits的二进制交叉熵损失
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
# 根据是否要求以字典形式返回输出,构造最终的输出
if not return_dict:
# 如果不使用字典形式,将损失和其他输出构造成元组返回
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
# 如果使用字典形式返回输出,构造并返回一个包含损失、logits、过去的键值对、隐藏状态和注意力权重的对象
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
代码注释较长,请耐心查看。