大模型训练关键技术参数总结浓缩版(基于minimind)

查看各个训练脚本,整理训练方式、数据集、初始模型、训练参数和损失函数的总结。

详细说明:

1. 初始模型加载说明:

  • Pretrain: 无需初始模型,随机初始化
  • Full SFT: 加载预训练模型
  • DPO: 加载SFT模型作为actor和ref(ref冻结)
  • LoRA: 加载SFT模型,添加LoRA层,冻结原参数
  • Distillation: 学生模型和教师模型分别加载不同大小的SFT模型
  • Reason Distill: 加载RLHF后的模型

2. 损失函数说明:

  • CE Loss: CrossEntropyLoss,用于预测下一个token
  • DPO Loss: 直接偏好优化损失,基于chosen/rejected样本的概率差
  • KL Divergence: 知识蒸馏的软标签损失
  • 标签惩罚: 推理模型训练中对特定标签位置增强损失权重

3. 学习率调度:

所有训练方式均使用余弦退火调度

lr = base_lr / 10 + 0.5 * base_lr * (1 + cos(π * step / total_steps))

该表格覆盖了MiniMind项目中的主要训练方式及其配置。

训练方式脚本文件数据集初始模型数量使用的初始模型默认训练参数损失函数损失计算说明
预训练 (Pretrain)train_pretrain.pypretrain_hq.jsonl0个(从零开始)无(随机初始化)epochs=1
batch_size=32
lr=5e-4
max_seq_len=512
accumulation_steps=8
grad_clip=1.0
CrossEntropyLoss
+ MoE aux_loss
1. 计算CE Loss:对每个token位置的logits和label计算交叉熵
2. 应用loss_mask:只计算有效token位置的损失
3. 平均:loss = (loss * loss_mask).sum() / loss_mask.sum()
4. 添加MoE负载均衡loss:loss += res.aux_loss
5. 梯度累积:loss = loss / accumulation_steps
全参微调 (Full SFT)train_full_sft.pysft_mini_512.jsonl
(可配置其他sft数据集)
1个pretrain_{hidden_size}.pthepochs=2
batch_size=16
lr=5e-7
max_seq_len=512
accumulation_steps=1
grad_clip=1.0
CrossEntropyLoss
+ MoE aux_loss
1. 计算CE Loss:对每个token位置计算交叉熵
2. 应用loss_mask:loss = (loss * loss_mask).sum() / loss_mask.sum()
3. 添加MoE负载均衡loss:loss += res.aux_loss
4. 梯度累积:loss = loss / accumulation_steps
注:与预训练计算方式相同,但学习率更小
DPO强化学习train_dpo.pydpo.jsonl2个full_sft_{hidden_size}.pth
(actor model + ref model,两者相同,ref冻结)
epochs=2
batch_size=4
lr=1e-8
max_seq_len=1024
accumulation_steps=1
grad_clip=1.0
beta=0.1
DPO Loss1. 计算log概率:log_probs = log_softmax(logits, dim=2),然后提取对应标签的概率
2. 平均序列概率:probs = (probs * mask).sum(dim=1) / seq_lengths
3. 分离chosen/rejected:batch前一半为chosen,后一半为rejected
4. 计算logratio:pi_logratios = chosen_probs - reject_probs
5. 计算参考模型logratio:ref_logratios = chosen_ref_probs - reject_ref_probs
6. DPO损失:loss = -logsigmoid(beta * (pi_logratios - ref_logratios))
7. 取平均:loss.mean()
LoRA微调train_lora.pylora_medical.jsonl
(可配置其他lora数据集)
1个full_sft_{hidden_size}.pthepochs=10
batch_size=32
lr=1e-4
max_seq_len=512
accumulation_steps=1
grad_clip=1.0
CrossEntropyLoss
+ MoE aux_loss
(仅更新LoRA参数)
1. 计算CE Loss:标准交叉熵损失
2. 应用loss_mask:loss = (loss * loss_mask).sum() / loss_mask.sum()
3. 添加MoE负载均衡loss:loss += res.aux_loss
4. 关键:只对LoRA参数计算梯度,基础模型参数冻结
5. 梯度累积:loss = loss / accumulation_steps
知识蒸馏train_distillation.pysft_xxx.jsonl
(需要手动指定)
2个学生模型:full_sft_512.pth
教师模型:full_sft_768.pth
epochs=6
batch_size=32
lr=5e-6
max_seq_len=512
accumulation_steps=1
grad_clip=1.0
alpha=0.0
temperature=1.0
α * CE Loss
+ (1-α) * KL Divergence Loss
1. CE Loss(可选):ce_loss = cross_entropy(student_logits, labels),应用loss_mask后平均
2. 蒸馏Loss
- 教师模型softmax:teacher_probs = softmax(teacher_logits / T)
- 学生模型log_softmax:student_log_probs = log_softmax(student_logits / T)
- KL散度:kl = kl_div(student_log_probs, teacher_probs)
- 温度缩放:distill_loss = T² * kl
3. 总损失loss = α * ce_loss + (1-α) * distill_loss
4. 添加MoE loss:ce_loss += res.aux_loss(如果使用MoE)
推理模型蒸馏train_distill_reason.pyr1_mix_1024.jsonl1个rlhf_{hidden_size}.pthepochs=1
batch_size=8
lr=1e-6
max_seq_len=1024
accumulation_steps=1
grad_clip=1.0
CrossEntropyLoss
+ 标签位置惩罚
1. 计算CE Loss:标准交叉熵损失
2. 特殊处理:识别特殊标签位置(<think>, </think>, <answer>, </answer>
3. 标签惩罚:对特殊标签位置的loss权重×10
- loss_mask[sp_ids] = 10(sp_ids为特殊标签位置)
4. 加权平均:loss = (loss * loss_mask).sum() / loss_mask_sum
5. 添加MoE loss:loss += res.aux_loss
6. 目的:增强模型对思考和回答标签的识别能力

关键要点总结:

1. 通用损失计算流程

  • 计算交叉熵损失(对每个token位置)
  • 应用loss_mask(只计算有效token位置)
  • 平均损失(除以有效token数量)
  • 添加辅助损失(如果使用MoE)
  • 梯度累积(除以accumulation_steps)

2. DPO特殊处理

  • 需要同时计算actor模型和参考模型的概率
  • 使用chosen/rejected样本的概率差计算偏好

3. 知识蒸馏特殊处理

  • 使用温度缩放(temperature)进行软标签蒸馏
  • 可混合使用硬标签(CE Loss)和软标签(KL Loss)

4. 推理模型特殊处理

  • 对特定标签位置进行10倍权重惩罚
  • 提高模型对格式标签的识别准确性

这些说明涵盖了各训练方式的具体损失计算细节。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值