YOLOv10结构化通道剪枝+微调

本项目基于 torch-pruning 官方提供的 YOLOv8 剪枝方案,构建了适用于 YOLOv10 的 GroupNormPruner 结构化剪枝 + 微调流程。在不修改 YOLOv10 原始源码的前提下,仅通过新增 prune.py 以及少量模块适配文件,实现剪枝与后续训练的无缝集成。

一. 源码模块的修改

1.C2f ------ C2f_v2 (与torch-pruning官方给出的一样,不加以赘述)

需要注意:在判断是否存在 C2f 时,可能会判断出C2fCIB,因此需要添加

not isinstance(child_module, C2fCIB)
def infer_shortcut(bottleneck):
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add


        if (isinstance(child_module, C2f)) and (not isinstance(child_module, C2fCIB)):
            # Replace C2f with C2f_v2 while preserving its parameters
            shortcut = infer_shortcut(child_module.m[0])
            c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels,         
                            child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            g=child_module.m[0].cv2.conv.groups,
                            e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2f_v2)
            setattr(module, name, c2f_v2)

2.C2fCIB ------ C2fCIB_v2

与C2f相似,将chuck转化为两个卷积(cv0 + cv1)

        elif isinstance(child_module, C2fCIB):
            shortcut = infer_shortcut2(child_module.m[0])
            c2fcib_v2 = C2fCIB_v2(child_module.cv1.conv.in_channels, 
                                  child_module.cv2.conv.out_channels,
                                  n=len(child_module.m), shortcut=shortcut,
                                  g=child_module.m[0].cv1[4].conv.groups,
                                  e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2fcib_v2)
            setattr(module, name, c2fcib_v2)



def infer_shortcut2(cib):
    c1 = cib.cv1[0].conv.in_channels
    c2 = cib.cv1[4].conv.out_channels
    return c1 == c2 and hasattr(cib, 'add') and cib.add

3.在DocLayout存在G2L_CRM ------ G2L_CRM_v2DocLayout-YOLOhttps://github.com/opendatalab/DocLayout-YOLO/tree/5fbc138175099c98a42381d0753ff50ebc5fe494

 网络结构如下:

G2L_CRM_v2:

        elif isinstance(child_module, G2L_CRM):
            count_g2lcrm += 1
            if count_g2lcrm == 3:
                shortcut = infer_shortcut2(child_module.m[0])
                block_k = 3
                fuse = "sum"
            else:
                shortcut = infer_shortcut(child_module.m[0])
                block_k = child_module.m[0].dilated_block.dcv.conv.kernel_size[0]
                fuse = "glu" if hasattr(child_module.m[0].dilated_block, 'conv_gating') 
                        else "sum"
            use_dilated = isinstance(child_module.m[0], DilatedBottleneck)
            g2l_crm_v2 = G2L_CRM_v2(child_module.cv1.conv.in_channels, 
                                    child_module.cv2.conv.out_channels,
                                    n=len(child_module.m), shortcut=shortcut, 
                                    use_dilated=use_dilated,
                                    dilation=[1, 3, 5] if count_g2lcrm == 2 else [1, 2, 3],
                                    block_k=block_k, fuse=fuse
                                    )
            transfer_weights(child_module, g2l_crm_v2)
            setattr(module, name, g2l_crm_v2)

4.构建新的py文件,存储C2f_v2、C2fCIB_v2、G2L_CRM_v2

为什么需要构建新的py文件存储,而不是在当前的剪枝.py下存储?

如果在当前的剪枝.py文件下存储,会报错

f=WindowsPath('runs/detect/prune/2025-04-02_16-24-13/step_0_finetune/weights/best.pt')
torch.load(f, map_location=torch.device("cpu"))
PyDev console: starting.
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2024.1.7\plugins\python\helpers-pro\pydevd_asyncio\pydevd_asyncio_utils.py", line 117, in _exec_async_code
    result = func()
  File "<input>", line 1, in <module>
  File "D:\Anaconda\envs\cpupytorch\lib\site-packages\torch\serialization.py", line 1360, in load
    return _load(
  File "D:\Anaconda\envs\cpupytorch\lib\site-packages\torch\serialization.py", line 1848, in _load
    result = unpickler.load()
  File "D:\A
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值