知乎:车中草同学(已授权)
链接:https://zhuanlan.zhihu.com/p/1485465898
范围:该问题影响所有使用梯度累计的库,包括hf的等。(hf的人在修复中了)
10.18日更新:
感谢评论区的大佬赐教,补充下他们的观点。
@Quokka 提供了一种对描述的实验现象(梯度累积越大,最终 loss 就越大的猜测)的解释:
短序列因为上下文短,信息不足,所以不容易预测,loss 偏大(梯度累积再给他加权),于是观测到的现象就是“梯度累积 loss 偏大”。其实它是偏向了短序列的 loss,而不是往大的方向偏。
@Ethan Yan
提到之前写过一篇文章:SFT loss 计算的那些坑(多轮合并/packing)
https://zhuanlan.zhihu.com/p/721652210
解释了这种情况不只在梯度累积中发生,而且还发生在 SFT 阶段:1. 多轮对话合并 2. 不同样本的 Packing。(具体看他的文章)
一般情况下,loss 计算会经历三次平均:
micro batch 维度,分母是这个 micro batch 中的所有 label 不是 -100 的 token 数(不同 token 之间 loss 的平均)
DP 维度,分母是 DP size (和 GPU 数量相关,不同机器之间 loss 的平均)
梯度累加维度,分母是梯度累加数。(不同 batch 之间的 loss 的平均)
我们要做的就是,不要让 DP 以及梯度累积维度影响原本 token 级别等权的 loss。
因此计算完,所有的目标 token 的 loss,加和之后,再除以实际目标 token 总数,这样,在 token 维度都是等权的,不受长度影响。
10.17日更新:
Hugging Face 在 10.16 日写了篇博客介绍修复问题。
https://huggingface.co/blog/gradient_accumulation
他们的方法是:交叉熵改为 reduction=sum
,再除总的实际目标 token。(他们的 loss 都在模型里面,要改无数个模型文件。。。)
def ForCausalLMLoss(logits, labels, vocab_size, **kwargs):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
num_items = kwargs.pop("num_items", None)
+ loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
+ loss = loss / num_items
- loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
return loss
Qwen2 的计算 loss 的方式:
https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2/modeling_qwen2.py#L855
他们的 PR:
https://github.com/huggingface/transformers/pull/34191 https://github.com/huggingface/transformers/pull/34198
重新安装后,即可正常使用:pip install git+
https://github.com/huggingface/transformers
整体结论(太长不看版):
理论上,梯度累计在数学上应该等同于全批量训练,但实际发现 loss 并不匹配。
https://github.com/huggingface/trl/issues/2175
研究者通过公式和实验证明,罪魁祸首是基于平均(mean)交叉熵损失和梯度累计,结合后会比全批量的平均交叉熵损失更大。
loss大,会对什么有影响呢?
模型泛化性:梯度累计加上基于平均(mean)交叉熵损失,会导致 bsz=1,ga=16 之间的 L2 范数比 bsz=16 的 10 倍。可以理解为:导致模型泛化性不足。
梯度累积后,过度重视短的序列长度,而忽略长的序列长度。
问题描述:
实验发现1:之前的研究者发现,在总的
batch_size * gradient accumulation
相同的情况下,梯度累积越大,最终 loss 就越大。研究者对此提了一个 issue:
https://github.com/huggingface/trl/issues/2175
实验发现2:最近的研究者通过实验发现,在相同的
batch_size * gradient accumulation
下,除了 loss 会更大,梯度累积越大,最终会导致 L2 范数越大。L2 范数越大,说明权重越大,进而影响模型的泛化性。
下面我们推导下 full batch 和 gradient accumulation 在计算 loss 上到底是否相等?
实例推导:
假设我们是在 SFT 任务中,输入的维度是 [batch_size, seq_len]
,假设这个 batch 内:实际需要计算的 target token 的总数为 I:
那么,交叉熵 loss 会先为 batch 内每一个 target token 计算一个 loss,然后相加,最后除以这个 batch 内所有 target token 的总数 I。
交叉熵 loss 的默认设定是 reduction=mean
。我们举例说明:
第一种情况(full batch):
batch_size = 2
,gradient accumulation = 1
;第二种情况(gradient accumulation):
batch_size = 1
,gradient accumulation = 2
。
序列等长的情况(每个 batch 的 target token 数量一样):
首先,假设两个序列等长,都是 m;
full batch 情况:
gradient accumulation 情况:
两种 loss 是相等的。
序列不等长的情况(每个 batch 的 target token 数量不一样):
假设两个序列不等长(训练过程中十分常见),一个是 一个是 ;
full batch 情况:
gradient accumulation 情况:
其实本质上就是比较 和 谁大谁小的情况了。
四个变量,已知 和 不相等,可以对 loss 做些假设,数学证明下两者的大小。
从目前的公式,我们能得出的结论是,相比 full batch 下不同序列的 target token 等权,在 gradient accumulation 情况下:
如果一个序列 target token 越短,那么其 loss 权重越大;(因为分母小)
如果一个序列 target token 越长,那么其 loss 权重越小。(因为分母大)
参考文章:
Bug Fixes in LLM Training - Gradient Accumulation
https://unsloth.ai/blog/gradient
CrossEntropyLoss - PyTorch 2.4 documentation
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
THUDM/LongAlign
https://github.com/THUDM/LongAlign/issues/3
备注:昵称-学校/公司-方向/会议(eg.ACL),进入技术/投稿群
id:DLNLPer,记得备注呦