知识蒸馏|使用一个大模型训练另一个大模型的三种方法

大语言模型(LLM)不仅能从海量文本中学习,它们也能「互相学习」:

·Llama 4 的 Scout 和 Maverick 模型就是在 Llama 4 Behemoth 的辅助下训练出来的。

·Google 的 Gemma 2 和 3 是在内部模型 Gemini 的指导下训练完成的。

这种「互相学习」的过程,主要依赖于知识蒸馏(Distillation)技术。下面这张图展示了目前主流的三种知识蒸馏方式。

通俗地说,知识蒸馏的目标就是把一个模型中的“知识”迁移给另一个模型。这在传统深度学习中早就很常见。

在 LLM 的训练中,知识蒸馏可以发生在两个阶段:

1.预训练阶段

·将较大的教师模型(Teacher LLM)和较小的学生模型(Student LLM)一同训练。

·Llama 4 就采用了这种方法来训练 Scout 和 Maverick。

2.后训练阶段(微调后再蒸馏)

·先单独训练好教师模型,然后将其知识迁移给学生模型。

·DeepSeek 就采用了这种方式,把 DeepSeek-R1 蒸馏到了 Qwen 和 Llama 3.1。

你也可以在两个阶段都使用蒸馏,Gemma 3 就是这样做的。

三种主流知识蒸馏方法:

1. Soft-label 蒸馏(软标签蒸馏)

·使用一个已经训练好的教师模型,对整个语料生成 softmax 概率分布(即每个词的概率而不是最终选哪个词)。

·再让一个未训练的学生模型跑同样的语料,也生成自己的 softmax 概率分布。

·然后训练学生模型,让它的输出概率尽可能接近教师模型的输出。

这种方式可以最大限度保留教师模型的推理逻辑与知识,但也有硬性要求:必须能访问到教师模型的权重,才能获取它输出的概率分布。

即便可以访问,还有个大问题:假设词表大小是 10 万,语料是 5 万亿 token。对于每一个 token,都需要生成一个长度为 10 万的概率向量。即便使用 float8 精度,也需要约 5 亿 GB 内存 来存储这些 soft labels!

所以我们来看第二种方法,它解决了这个问题。

2. Hard-label 蒸馏(硬标签蒸馏)

·同样用一个已经训练好的教师模型,但这次只输出最终选择的单个 token(即 one-hot 编码)。

·学生模型依旧输出 softmax 概率分布。

·然后训练学生模型,使其概率分布尽量接近教师模型最终选择的 token。

DeepSeek 就是采用这种方式,把 DeepSeek-R1 蒸馏到 Qwen 和 Llama 3.1 的。这种方法相比软标签,资源占用低很多,但信息量也少一些。

3. Co-distillation(协同蒸馏)

·从一个未训练的教师模型和一个未训练的学生模型同时开始。

·两个模型都跑当前的训练 batch,生成各自的 softmax 概率。

·教师模型正常训练,使用 ground-truth 的硬标签来学习。

·学生模型则训练它的 softmax 输出尽可能接近教师模型当前的输出。

Llama 4 就是用这种方法,从 Llama 4 Behemoth 训练出了 Scout 和 Maverick。

当然,刚开始训练时,教师模型的输出概率也不靠谱,所以学生模型的训练同时使用教师模型的 soft labels + 真实的 ground-truth 硬标签。

### 大模型蒸馏方法概述 大模型蒸馏是一种通过压缩大型复杂模型来获得更高效小型模型的技术。这种方法的核心在于利用教师模型(Teacher Model)的知识指导学生模型(Student Model),从而使得学生模型能够继承教师模型的关键特性并保持较高的性能。 #### 知识蒸馏技术的主要形式 知识蒸馏可以分为多种不同的模式,这些模式涵盖了不同场景下的应用需求: - **离线蒸馏**:在这种模式下,教师模型已经训练完成,在固定的状态下对学生模型进行指导[^1]。 - **在线蒸馏**:与离线蒸馏相反,这种模式允许教师和学生模型同时更新权重参数,通常用于动态环境中的联合优化过程[^1]。 - **自蒸馏**:此过程中,单个模型既作为自己的老师也作为自己的学生来进行迭代改进。 - **无数据蒸馏**:当无法访问原始训练数据时采用的一种特殊形式,它依赖于合成样本或其他替代方案来传递知识[^1]。 - **多模型蒸馏**:多个教师模型共同教导单一或几个学生模型,旨在融合来自不同源的信息以提高泛化能力。 - **特权蒸馏**:引入额外的辅助信息(即所谓的“特权”信息),帮助提升最终的学生表现水平[^1]。 #### 常见的大规模语言模型(Large Language Models, LLMs) 蒸馏算法及其特点 一种先进的大规模语言模型蒸馏框架被提出,该框架基于迭代剪枝与蒸馏相结合的方式,用来构建一系列较小版本的语言模型家族[^2]。具体来说: - 这种方法首先通过对预定义架构执行多次循环操作实现逐步减少冗余节点数量; - 同时在整个流程期间持续运用知识转移机制确保质量损失最小化; 以下是几种典型的算法实例以及它们各自的优势所在: | 名称 | 描述 | |--|--| | DistilBERT | 由Hugging Face开发的一个轻量化版BERT变体,主要针对NLP任务进行了显著提速而不牺牲太多准确性 。 它采用了软标签加权平均策略加上隐藏状态匹配技巧达成目标效果 [^未提供].| | TinyBERT | 另外一款专为资源受限设备设计的小型BERT实现方案 , 不仅体积小巧而且推理速度快得多 . 此外还特别强调了跨层交互的重要性以便更好地捕捉上下文关系特征 [^未提供].| #### 应用领域分析 随着人工智能技术的发展,特别是在自然语言处理(Natural Language Processing,NLP),计算机视觉(Computer Vision,CV),语音识别(Speech Recognition)等领域内的广泛应用,促使人们不断探索更加高效的解决方案。下面列举了一些典型的应用案例说明如何有效利用上述提到的各种类型的蒸馏技术和相应工具包解决实际问题: - 在搜索引擎服务中部署经过高度优化后的文本嵌入向量计算模块,可极大降低延迟时间成本的同时维持良好的检索精度标准 ; - 面向移动终端应用程序内部集成简化过的图像分类器组件,则能够在保障用户体验流畅性的前提条件下节省电量消耗 ; 总之,无论是为了满足特定硬件条件限制还是追求极致效率的目标导向考虑因素驱动之下,合理选用合适种类别的知识迁移途径都将发挥不可估量的作用价值. ```python def knowledge_distillation_loss(student_output, teacher_output, temperature=2.0): """ 计算知识蒸馏损失函数 参数: student_output (Tensor): 学生模型预测概率分布. teacher_output (Tensor): 教师模型预测概率分布. temperature (float): 温度超参调节平滑程度,默认值设为2.0 返回: Tensor: KL散度衡量两者差异大小的结果表示数值 """ import torch.nn.functional as F soft_student = F.log_softmax(student_output / temperature, dim=-1) soft_teacher = F.softmax(teacher_output / temperature, dim=-1) return F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature**2) # 使用示例代码片段展示调用方式如下所示: student_logits = model Student(input_data) teacher_logits = model Teacher(input_data) loss_value = knowledge_distillation_loss(student_logits , teacher_logits ) ``` 相关问题
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值