yolov11剪枝

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

思路:yolov11中的C3k2与yolov8的c2f的不同,所以与之前yolov8剪枝有稍许不同;

后续:会将剪枝流程写全,以及增加蒸馏、注意力、改loss;

注意:

1.在代码105行修改pruning.get_threshold(yolo.model, 0.65),可以获得不同的剪枝率;

2.改代码放在训练代码同一页面下即可;

3.在最后修改文件夹地址来获得剪枝后的模型;

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os


# os.environ["CUDA_VISIBLE_DEVICES"] = "2"


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
                print()
        # keep
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## Normal Pruning
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()

        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n

        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## Regular Pruning
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None: continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                conv1 = item[0]
                conv = item[1].conv
                conv1.conv.in_channels = n
                conv1.conv.out_channels = n
                conv1.conv.groups = n
                conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs, :]
                conv1.bn.bias.data = conv1.bn.bias.data[keep_idxs]
                conv1.bn.weight.data = conv1.bn.weight.data[keep_idxs]
                conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
                conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
                conv1.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

    def prune(self, m1, m2):
        if isinstance(m1, C3k2):  # C3k2 as a top conv
            m1 = m1.cv2
        if isinstance(m1, Sequential):
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):
                m2[i] = item.cv1

        self.prune_conv(m1, m2)


def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. 加载模型
    yolo = YOLO(modelpath)  # build a new model from scratch
    pruning.get_threshold(yolo.model, 0.65)  # 这里的0.8为剪枝率。

    ### 1. 剪枝C3k2 中的Bottleneck
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. 指定剪枝不同模块之间的卷积核
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])

    ### 3. 对检测头进行剪枝
    # 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)
    # 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1]
    # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2]
    detect: Detect = seq[-1]
    proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for idx, (last_input, colast, cv2, cv3, cv4) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3, detect.cv4)):
        if idx == 0:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0], proto])
        else:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])
        pruning.prune(cv4[0], cv4[1])
        pruning.prune(cv4[1], cv4[2])

    ### 4. 模型梯度设置与保存
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    yolo.val(data='data.yaml', batch=2, device=0, workers=0)
    torch.save(yolo.ckpt, savepath)



if __name__ == "__main__":
    modelpath = "runs/segment/Constraint/weights/best.pt"
    savepath = "runs/segment/Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)

您可能感兴趣的与本文相关的镜像

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

目前YOLO系列的最新版本为YOLOv8,尚未存在官方或广泛认可的YOLOv11版本。然而,可以基于YOLOv5和YOLOv8的剪枝方法推测,若未来出现YOLOv11模型,其剪枝技术路径将继承并优化现有模型的结构化剪枝策略。以下为基于现有YOLO剪枝方法总结的通用剪枝方法及实现指南。 ### 剪枝方法概述 模型剪枝是一种常见的模型压缩技术,旨在通过移除神经网络中冗余或不重要的权重或结构,从而减小模型规模并提升推理效率,同时尽可能保持模型性能。对于YOLO系列模型,剪枝通常分为**非结构化剪枝**与**结构化剪枝**两种形式: - **非结构化剪枝**:主要针对单个权重进行剪枝,通常通过设定阈值将接近于零的权重置为零,从而形成稀疏矩阵。虽然可以显著减少参数数量,但对计算效率提升有限。 - **结构化剪枝**:以通道(channel)或层(layer)为单位进行剪枝,能够有效减少推理时的计算量和内存占用。例如,通过分析模型中各通道对最终输出的影响,移除影响较小的通道[^2]。 ### YOLO模型剪枝的关键步骤 1. **依赖图分析(DepGraph)**:在结构化剪枝中,使用依赖图分析可以识别出哪些通道或层可以被安全剪枝而不破坏模型的整体结构。例如,在YOLOv8中应用基于DepGraph的剪枝方法,能够更安全、有效地进行剪枝操作[^1]。 2. **剪枝策略设计**:选择合适的剪枝策略是关键。可以基于L1范数、BN层的gamma系数、通道重要性评分等指标来评估通道的重要性。通常,重要性评分越低的通道优先被剪除。 3. **剪枝后微调(Fine-tuning)**:剪枝操作会破坏模型原有的参数分布,因此剪枝后需要进行微调以恢复模型性能。这一过程可以采用稀疏训练(Sparse Training)技术,如SparseML等工具提供的训练流程,实现剪枝与训练的无缝集成[^3]。 4. **评估与迭代**:剪枝是一个迭代过程,通常需要多次剪枝-微调-评估的循环,直到达到预期的模型压缩率和性能平衡。 ### 实现指南 以下为实现YOLO模型剪枝的基本步骤: #### 1. 准备环境 - 安装必要的库,如PyTorch、SparseML、torchvision等。 - 获取YOLO模型源码,如YOLOv5或YOLOv8的官方实现。 #### 2. 加载预训练模型 ```python import torch from models.yolo import Model # 加载预训练模型 model = torch.load('yolov5s.pt')['model'].float().eval() ``` #### 3. 构建剪枝配置(Pruning Recipe) 以YOLOv5为例,可以使用SparseML提供的剪枝配置文件(如`yolov5s.pruned.md`)定义剪枝策略: ```yaml # yolov5s.pruned.md version: 1 pruners: - name: channel_pruner class: ChannelPruner args: sparsity: 0.5 ignore: - model.24.m.0.weight - model.24.m.1.weight - model.24.m.2.weight ``` #### 4. 执行剪枝与微调 利用SparseML与YOLOv5训练流程集成的模块(如`sparse.py`),可以实现剪枝训练: ```python from yolov5.utils.sparse import SparseTraining # 初始化剪枝流程 sparse_trainer = SparseTraining(model, recipe='yolov5s.pruned.md') # 开始剪枝训练 sparse_trainer.train(dataloader, epochs=50) ``` #### 5. 模型导出与部署优化 剪枝完成后,可将模型导出为ONNX格式,并使用NCNN等推理框架进行部署优化: ```bash # 导出ONNX模型 python export.py --weights yolov5s_pruned.pt --img 640 --batch 1 --include onnx # 转换为NCNN格式 ./onnx2ncnn yolov5s_pruned.onnx yolov5s.param yolov5s.bin ./ncnnoptimize yolov5s.param yolov5s.bin yolov5s-opt.param yolov5s-opt.bin 1 ``` ### 注意事项 - 剪枝比例不宜过高,否则可能导致模型性能显著下降。 - 剪枝后必须进行微调,以恢复模型精度。 - 对于YOLOv11(若未来发布),建议关注其官方文档与社区支持的剪枝工具链。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值