YOLOv8 Distillation (Code Included, Free)

First of all, what is distillation?

Model distillation is a technique used in computer vision to improve model performance and efficiency. It usually involves two models, a "teacher model" and a "student model": a large, accurate teacher guides a smaller student to mimic its outputs, so the student approaches the teacher's accuracy at a much lower cost.

Why do we need distillation?

  1. It can raise accuracy without adding any parameters or computation to the deployed model, i.e., the accuracy gain is essentially free at inference time.
  2. Combined with pruning, it lets you cut parameters and FLOPs and raise FPS while keeping accuracy flat or even improving it, something that is hard to reach by only redesigning the network structure.
  3. It is a reliable fallback for papers: like pruning, it adds no parameters or computation, and since distillation itself requires many experiments, it can contribute a large amount of experimental work at the final stage of a project.

Contents

I. Prerequisites

(1) The teacher model is YOLOv8s; the student is a pruned YOLOv8s

(2) The distillation methods used are MGD and CWD

(3) Before use, install the required packages, put the dataset in the datasets folder, and update the class list in data.yaml

II. Distillation steps

(1) Train the teacher model

(2) Train the student model

(3) Distillation training

III. Pruning + distillation

(1) Constrained (sparsity) training is covered in my previous article (link: yolov8剪枝)

(2) After constrained training, prune first using prune.py: update the model path and run it directly

(3) The pruned model may lose accuracy, so continue training with it afterwards


I. Prerequisites

(1) The teacher model in this article is YOLOv8s; the student is a pruned YOLOv8s.

(2) The distillation methods used are MGD and CWD.

(3) Before use, install the required packages, put the dataset in the datasets folder, and update the class list in data.yaml.
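For reference, a typical Ultralytics data.yaml has the following shape; the paths and class names below are placeholders for your own dataset:

path: datasets/your_dataset   # dataset root (placeholder)
train: images/train           # training images, relative to path
val: images/val               # validation images, relative to path
nc: 2                         # number of classes
names: ['class0', 'class1']   # class names, in index order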

The code for this article has been uploaded to GitHub (link: yolov8_蒸馏).

If you find it useful, consider following me; I will later add ViT (Vision Transformer), loss replacements, and other accuracy-improving methods.

II. Distillation steps

(1) Train the teacher model

Open train.py in the project, update the model path, and start training; stop once the teacher reaches a satisfactory accuracy.

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def main():
    # Plain training: in this modified repo, Distillation=None and loss_type='None' disable distillation
    model = YOLO("yolov8s.pt")
    model.train(data="data.yaml", Distillation=None, loss_type='None', amp=False, imgsz=640,
                epochs=50, batch=20, device=0, workers=0)


if __name__ == '__main__':
    main()
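The distillation script later loads the teacher from runs/detect/yolov8s/weights/best.pt. If your run is saved under a different directory, either adjust that path or pin the run directory with the standard Ultralytics name argument; the directory name used below is only an assumption chosen to match that path.

    # Optional (assumption): pin the run directory so the teacher ends up in runs/detect/yolov8s/
    model.train(data="data.yaml", Distillation=None, loss_type='None', amp=False, imgsz=640,
                epochs=50, batch=20, device=0, workers=0, name="yolov8s")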

(2) Train the student model

Open train.py and update the model path. Here I use the pruned YOLOv8s model; the pruning steps are described at the end of this article.

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def main():
    # Student: the pruned checkpoint produced by prune.py (see section III)
    model_s = YOLO("./runs/detect/prune/weights/prune.pt")
    model_s.train(data="data.yaml", Distillation=None, loss_type='None', amp=False, imgsz=640,
                  epochs=50, batch=20, device=0, workers=0)


if __name__ == '__main__':
    main()

(3) Distillation training

Open train_distillation.py and set the teacher and student model paths. Two distillation methods are available: cwd (Channel-Wise Distillation, which matches per-channel spatial distributions between teacher and student) and mgd (Masked Generative Distillation, which masks the student's features and trains a small block to regenerate the teacher's). A minimal CWD sketch follows the script below.

import os
from ultralytics import YOLO
import torch

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def main():
    model_t = YOLO('runs/detect/yolov8s/weights/best.pt')  # the teacher model
    model_s = YOLO('runs/detect/prune/weights/best.pt')  # the student model
    """
    Attributes:
        Distillation: the distillation model
        loss_type: mgd, cwd
        amp: Automatic Mixed Precision
    """
    model_s.train(data="data.yaml", Distillation=model_t.model, loss_type='mgd', amp=False, imgsz=640, epochs=100,
                  batch=20, device=0, workers=0, lr0=0.001)


if __name__ == '__main__':
    main()
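For intuition, CWD treats each channel's H×W activation map as a probability distribution: a softmax is taken over the spatial positions of every channel, and the student and teacher distributions are matched with a KL divergence. The sketch below only illustrates the idea on a single pair of already shape-aligned feature maps; it is not the repo's exact FeatureLoss implementation, which additionally aligns the student's channels first.

import torch
import torch.nn.functional as F


def cwd_loss(feat_s, feat_t, tau=1.0):
    """Channel-wise distillation for one pair of (N, C, H, W) feature maps.
    Assumes feat_s is already aligned to the teacher's channel count."""
    n, c, h, w = feat_s.shape
    # One spatial distribution per channel: softmax over the H*W positions
    log_p_s = F.log_softmax(feat_s.reshape(n * c, h * w) / tau, dim=1)
    p_t = F.softmax(feat_t.reshape(n * c, h * w) / tau, dim=1)
    # KL(teacher || student), averaged over the N*C channel maps, scaled by tau^2
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * (tau ** 2)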

Do not start training yet. First open yolo_project_distillation\ultralytics\engine\trainer.py.

In the FeatureLoss class, set a breakpoint in the forward method (around line 162) and run in debug mode. The relevant code:

    def forward(self, y_s, y_t):
        assert len(y_s) == len(y_t)
        tea_feats = []
        stu_feats = []

        for idx, (s, t) in enumerate(zip(y_s, y_t)):
            # change ---
            if self.distiller == 'cwd':
                s = self.align_module[idx](s)
                s = self.norm[idx](s)
            else:
                s = self.norm1[idx](s)
            t = self.norm[idx](t)
            tea_feats.append(t)
            stu_feats.append(s)

        loss = self.feature_loss(stu_feats, tea_feats)
        return self.loss_weight * loss

Run under the debugger and inspect the tensor sizes of the student features y_s and the teacher features y_t. Note down the channel counts and write them into the Distillation_loss class:

        channels_s = [256, 480, 256, 64, 143, 229][-le:]
        channels_t = [256, 512, 256, 128, 256, 512][-le:]

There are six values in total, one for each of the six model layers whose features are distilled.
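If you would rather not use the debugger, the same information can be read off by temporarily adding a print at the top of FeatureLoss.forward; this is just a convenience suggestion, not part of the repo, so remove it once you have the numbers.

    # Temporary probe inside FeatureLoss.forward (delete after noting the channel counts)
    for idx, (s, t) in enumerate(zip(y_s, y_t)):
        print(f"layer {idx}: student {tuple(s.shape)}, teacher {tuple(t.shape)}")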

Once the channel lists are updated, training should run normally. If it does not train well, come find me in the comments.

III. Pruning + distillation

(1) Constrained (sparsity) training is covered in my previous article (link: yolov8剪枝).

(2) After constrained training, prune first using prune.py: update the model path and run it directly.

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
from copy import deepcopy

# Load the constraint-trained model (last.pt from the sparsity-constrained training run)
yolo = YOLO("./runs/detect/yolov8s/weights/last.pt")
# Path where the pruned model will be saved
res_dir = "./runs/detect/prune/weights/prune.pt"
# Keep ratio: the pruning threshold is taken at this position in the descending-sorted BN |gamma|,
# so factor = 0.75 keeps roughly 75% of the BN channels
factor = 0.75

yolo.info()
model = yolo.model
ws = []
bs = []

for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        w = m.weight.abs().detach()
        b = m.bias.abs().detach()
        ws.append(w)
        bs.append(b)
        # print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

# Determine a global pruning threshold over all BN gamma values

ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)


def prune_conv(conv1: Conv, conv2: Conv):
    gamma = conv1.bn.weight.data.detach()
    beta = conv1.bn.bias.data.detach()
    keep_idxs = []
    local_threshold = threshold
    # Keep at least 8 channels: if too few gammas pass the threshold, halve it and retry
    while len(keep_idxs) < 8:
        keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
        local_threshold = local_threshold * 0.5
    n = len(keep_idxs)
    # n = max(int(len(idxs) * 0.8), p)
    # print(n / len(gamma) * 100)
    # scale = len(idxs) / n
    conv1.bn.weight.data = gamma[keep_idxs]
    conv1.bn.bias.data = beta[keep_idxs]
    conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
    conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
    conv1.bn.num_features = n
    conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
    conv1.conv.out_channels = n

    if conv1.conv.bias is not None:
        conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

    if not isinstance(conv2, list):
        conv2 = [conv2]

    for item in conv2:
        if item is not None:
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]


def prune(m1, m2):
    # Prune the output channels of m1 and propagate the change to the input channels of m2
    if isinstance(m1, C2f):  # C2f as a top conv
        m1 = m1.cv2

    if not isinstance(m2, list):  # m2 is just one module
        m2 = [m2]

    for i, item in enumerate(m2):
        if isinstance(item, C2f) or isinstance(item, SPPF):
            m2[i] = item.cv1

    prune_conv(m1, m2)


# Prune the paired convs inside every Bottleneck
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune_conv(m.cv1, m.cv2)

# Prune consecutive backbone modules; layers 4 and 6 are skipped because their
# outputs are reused elsewhere through Concat and cannot be pruned independently here
seq = model.model
for i in range(3, 9):
    if i in [6, 4, 9]: continue
    prune(seq[i], seq[i + 1])

# Prune the three feature maps feeding the Detect head, and the head's own conv stacks
detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
    prune(last_input, [colast, cv2[0], cv3[0]])
    prune(cv2[0], cv2[1])
    prune(cv2[1], cv2[2])
    prune(cv3[0], cv3[1])
    prune(cv3[1], cv3[2])

# Make sure every parameter is trainable for the later fine-tuning / distillation stage
for name, p in yolo.model.named_parameters():
    p.requires_grad = True

# yolo.val(workers=0)  # validate the pruned model
yolo.info()
# yolo.export(format="onnx")  # export to ONNX
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100)  # fine-tune directly after pruning
# Wrap the pruned model in a minimal Ultralytics-style checkpoint and save it
ckpt = {
    'epoch': -1,
    'best_fitness': None,
    'model': yolo.ckpt['ema'],  # the pruned model (the same object referenced by yolo.model)
    'ema': None,
    'updates': None,
    'optimizer': None,
    'train_args': yolo.ckpt["train_args"],  # save as dict
    'date': None,
    'version': '8.0.142'}

torch.save(ckpt, res_dir)
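As an optional sanity check (just a suggestion), reload the saved checkpoint and confirm the parameter count has dropped:

from ultralytics import YOLO

pruned = YOLO("./runs/detect/prune/weights/prune.pt")
pruned.info()  # prints the layer / parameter / GFLOPs summary of the pruned model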

(3) Pruning alone can hurt accuracy, so continue training (fine-tune) with the pruned model:

import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def main():
    # model = YOLO(r'ultralytics/cfg/models/v8/yolov8s.yaml').load('runs/detect/yolov8s/weights/best.pt')
    model_s = YOLO("./runs/detect/prune/weights/prune.pt")  # fine-tune the pruned checkpoint
    model_s.train(data="data.yaml", Distillation=None, loss_type='None', amp=False, imgsz=640,
                  epochs=50, batch=20, device=0, workers=0)


if __name__ == '__main__':
    main()

------------------------------------------over!!!!!!!!!!!!!!!!!------------------------------

### YOLOv8 Knowledge Distillation: Code and Explanation

#### 1. Environment setup

To run the YOLOv8 knowledge-distillation code smoothly, the environment configuration matters. It is recommended to manage dependencies with a Python virtual environment and install the required libraries and tools.

```bash
conda create -n yolov8_distillation python=3.9
conda activate yolov8_distillation
pip install ultralytics torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
```

#### 2. Data preparation

The dataset is critical for training. Before distillation, prepare the image classification or object detection dataset and organize the files in the expected directory structure.

#### 3. Logits-based distillation

Logits-based distillation is the simplest form of knowledge distillation: the small student network learns by imitating the output distribution of the large teacher network, minimizing a divergence between the two:

\[
L_{distill} = \frac{1}{N}\sum_{i=1}^{N} T^{2} \cdot KL\big(\text{softmax}(t(x_i)/T)\,\|\,\text{softmax}(s(x_i)/T)\big)
\]

where \(s(x)\) is the student's prediction, \(t(x)\) is the teacher's prediction, \(KL\) denotes the Kullback-Leibler divergence, and the temperature \(T>0\) controls how soft the probability distributions are.

```python
import torch.nn.functional as F


def logits_based_loss(student_logits, teacher_logits, temperature=4):
    """Logits-based knowledge-distillation loss (KL between softened distributions)."""
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(
        soft_student,
        soft_teacher,
        reduction="batchmean"
    ) * (temperature ** 2)
```

#### 4. Feature-based distillation

Feature-based distillation instead uses intermediate feature maps as the supervision signal passed to the student network, pulling its internal representations closer to the teacher's. The gap is usually measured with mean squared error (MSE) or a similar metric:

\[
L_{feat\_distill} = \left\| f_s(X) - f_t(X) \right\|_{F}^{2}
\]

where \(f_s(\cdot)\) and \(f_t(\cdot)\) are the activation maps of the chosen student and teacher layers, and \(\|\cdot\|_F\) denotes the Frobenius norm.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureDistiller(nn.Module):
    """Collect matching intermediate features from a student and a teacher via forward hooks
    and return their mean MSE. The layer names are placeholders and depend on the backbone."""

    def __init__(self, student_model, teacher_model, layers=('layer2', 'layer3')):
        super().__init__()
        self.student = student_model
        self.teacher = teacher_model
        self.student_features = []
        self.teacher_features = []

        # Register hooks that capture the feature maps of the chosen layers
        for layer_name in layers:
            dict(student_model.named_modules())[layer_name].register_forward_hook(
                lambda module, inp, output: self.student_features.append(output))
            dict(teacher_model.named_modules())[layer_name].register_forward_hook(
                lambda module, inp, output: self.teacher_features.append(output.detach()))

    def forward(self, inputs):
        self.student_features.clear()
        self.teacher_features.clear()

        with torch.no_grad():
            _ = self.teacher(inputs)   # teacher runs without gradients
        _ = self.student(inputs)       # student keeps gradients for backprop

        feature_losses = [
            F.mse_loss(s_feat, t_feat)
            for s_feat, t_feat in zip(self.student_features, self.teacher_features)
        ]
        total_feature_loss = sum(feature_losses) / len(feature_losses)
        return {'total_feature_loss': total_feature_loss}
```