常见的蒸馏技术

什么是知识蒸馏?

根据维基百科的定义,知识蒸馏(knowledge distillation) 是人工智能领域的一项模型训练技术。该技术透过类似于教师—学生的方式,令规模较小、结构较为简单的人工智能模型从已经经过充足训练的大型、复杂模型身上学习其掌握的知识。该技术可以让小型简单模型快速有效学习到大型复杂模型透过漫长训练才能得到的结果,从而改善模型的效率、减少运算开销,因此亦被称为模型蒸馏(model distillation)。

知识蒸馏并不是什么新技术,早在 2006 年, Bucilua 等人最先提出将大模型的知识迁移到小模型的想法。2015 年,Hinton 正式提出广为人知的知识蒸馏的概念。其主要的想法是:学生模型通过模仿教师模型来获得和教师模型相当的精度,关键问题是如何将教师模型的知识迁移到学生模型。

在这里插入图片描述

目前最常用的蒸馏方法有三种:数据蒸馏、Logits蒸馏、特征蒸馏。下面分别介绍下这三种蒸馏方法。

数据蒸馏

在数据蒸馏过程中,教师模型首先生成<问题, 答案> pair,这些 pair 随后被用来训练学生模型。例如,DeepSeekR1-Distill-Qwen-32B 模型就是通过使用 DeepSeek-r1 生成的80万条数据,在 Qwen2.5-32B 基

计及光伏电站快速无功响应特性的分布式电源优化配置方法(Matlab代码实现)内容概要:本文提出了一种计及光伏电站快速无功响应特性的分布式电源优化配置方法,并提供了基于Matlab的代码实现。该方法在传统分布式电源配置基础上,充分考虑了光伏电站通过逆变器实现的快速无功调节能力,以提升配电网的电压稳定性与运行效率。通过建立包含有功、无功协调优化的数学模型,结合智能算法求解最优电源配置方案,有效降低了网络损耗,改善了节点电压质量,增强了系统对可再生能源的接纳能力。研究案例验证了所提方法在典型配电系统中的有效性与实用性。; 适合人群:具备电力系统基础知识和Matlab编程能力的电气工程专业研究生、科研人员及从事新能源并网、配电网规划的相关技术人员。; 使用场景及目标:①用于分布式光伏等新能源接入配电网的规划与优化设计;②提升配电网电压稳定性与电能质量;③研究光伏逆变器无功补偿能力在系统优化中的应用价值;④为含高比例可再生能源的主动配电网提供技术支持。; 阅读建议:建议读者结合Matlab代码与算法原理同步学习,重点理解目标函数构建、约束条件设定及优化算法实现过程,可通过修改系统参数和场景设置进行仿真对比,深入掌握方法的核心思想与工程应用潜力。
### 知识蒸馏中的损失函数及其技术公式详解 #### 1. 损失函数概述 知识蒸馏的核心在于通过教师模型指导学生模型的学习过程,使学生模型能够逼近教师模型的行为。为了实现这一目标,在训练过程中通常会引入一种特殊的损失函数——软目标交叉熵损失(Soft Target Cross Entropy Loss),并将其与传统的硬标签交叉熵损失(Hard Label Cross Entropy Loss)相结合。 总的知识蒸馏损失函数可以表示为以下形式[^2]: \[ L_{total} = \alpha L_{hard} + (1-\alpha)L_{soft} \] 其中: - \(L_{hard}\): 基于真实标签的硬标签交叉熵损失; - \(L_{soft}\): 基于教师模型输出的概率分布的软目标交叉熵损失; - \(\alpha\): 控制两种损失权重的超参数,取值范围为\[0, 1\]。 --- #### 2. 硬标签交叉熵损失 (\(L_{hard}\)) 硬标签交叉熵损失是最常见的监督学习损失函数之一,它基于真实的类别标签和学生模型预测的概率分布来计算误差。具体表达式如下[^3]: \[ L_{hard}(y, p_s) = -\sum_i y_i \log(p_{s,i}) \] 其中: - \(y\) 是真实标签向量(one-hot 编码); - \(p_s\) 是学生模型经过 softmax 函数后的概率分布。 --- #### 3. 软目标交叉熵损失 (\(L_{soft}\)) 软目标交叉熵损失利用了教师模型的输出作为“软标签”,并通过 KL 散度衡量学生模型与教师模型之间输出分布的差异。由于原始 softmax 输出可能过于尖锐化,因此在知识蒸馏中常采用升温操作(temperature scaling)来平滑概率分布[^4]。 设温度参数为\(T\),则教师模型和学生模型的输出分别记作: \[ q_t(i) = \frac{\exp(z_t(i)/T)}{\sum_j \exp(z_t(j)/T)} \] \[ q_s(i) = \frac{\exp(z_s(i)/T)}{\sum_j \exp(z_s(j)/T)} \] 其中: - \(z_t\) 和 \(z_s\) 分别代表教师模型和学生模型未经过 softmax 的 logits 向量; - \(i\) 表示第 i 类别的索引。 软目标交叉熵损失可由 KL 散度定义为: \[ L_{soft}(q_t, q_s) = T^2 D_{KL}(q_t || q_s) \] 进一步展开得到: \[ D_{KL}(q_t || q_s) = \sum_i q_t(i) \log\left(\frac{q_t(i)}{q_s(i)}\right) \] 最终的软目标损失写成显式形式为: \[ L_{soft}(q_t, q_s) = T^2 \sum_i q_t(i) \log(q_t(i)) - T^2 \sum_i q_t(i) \log(q_s(i)) \] --- #### 4. 总体损失函数分析 总体损失函数综合考虑了硬标签和软标签的信息,平衡两者的贡献比例取决于超参数 \(\alpha\)。当 \(\alpha=1\) 时,仅依赖硬标签;而当 \(\alpha=0\) 时,则完全依赖软标签。实践中一般设置较小的 \(\alpha\) 来保留部分硬标签信息,从而提升学生的泛化能力。 此外,温度参数 \(T\) 对软目标的影响也至关重要。较大的 \(T\) 可以让概率分布更加平滑,使得学生模型更容易捕捉到教师模型隐含的知识。 --- ```python import torch import torch.nn.functional as F def knowledge_distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.5): """ 计算知识蒸馏的总损失 参数: student_logits: 学生模型的 logit 输出 teacher_logits: 教师模型的 logit 输出 labels: 真实标签 temperature: 温度参数,默认为 4.0 alpha: 平衡因子,默认为 0.5 返回: total_loss: 总损失 """ # Hard loss: 使用真实标签计算交叉熵 hard_loss = F.cross_entropy(student_logits, labels) # Soft loss: 使用教师模型的 soft target 进行 KL 散度计算 student_probs = F.log_softmax(student_logits / temperature, dim=-1) teacher_probs = F.softmax(teacher_logits / temperature, dim=-1) soft_loss = F.kl_div(student_probs, teacher_probs, reduction='batchmean') * (temperature ** 2) # Total loss total_loss = alpha * hard_loss + (1 - alpha) * soft_loss return total_loss ``` --- ####
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

comli_cn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值