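This project builds on the official YOLOv8 pruning example from torch-pruning and turns it into a GroupNormPruner structured pruning + fine-tuning pipeline for YOLOv10. The original YOLOv10 source code is not modified. Pruning is integrated with the subsequent training by adding only prune.py and a few module-adaptation files.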
I. Modifications to source modules
1. C2f ------ C2f_v2 (identical to the version in the official torch-pruning example, so it is not repeated here)
Note: a check for C2f can also match C2fCIB, because in YOLOv10 C2fCIB is a subclass of C2f. The condition therefore needs an extra
not isinstance(child_module, C2fCIB)
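The extra guard is needed because isinstance also returns True for subclasses. Without it, a plain C2f check would send C2fCIB blocks into the C2f branch. The sketch below illustrates this with empty stand-in classes. They are hypothetical placeholders, not the real YOLOv10 modules.

```python
# Hypothetical stand-ins for the YOLOv10 classes. As in YOLOv10,
# C2fCIB is declared as a subclass of C2f.
class C2f:
    pass


class C2fCIB(C2f):
    pass


def pick_replacement(child):
    """Return the name of the *_v2 class a module should be swapped for."""
    # Guarded check, as in the text: exclude the subclass explicitly.
    if isinstance(child, C2f) and not isinstance(child, C2fCIB):
        return "C2f_v2"
    elif isinstance(child, C2fCIB):
        return "C2fCIB_v2"
    return None


print(isinstance(C2fCIB(), C2f))  # True: a naive C2f check also matches C2fCIB
print(pick_replacement(C2f()))     # C2f_v2
print(pick_replacement(C2fCIB()))  # C2fCIB_v2
```

Testing the more specific class first, i.e. `isinstance(child, C2fCIB)` before `isinstance(child, C2f)`, would work just as well. The explicit `not isinstance(...)` guard simply makes the branches independent of their order.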
def infer_shortcut(bottleneck):
    # A residual shortcut exists only if the block keeps its channel count
    # and was built with shortcut=True (stored in the .add flag).
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add

# Inside the recursive loop: for name, child_module in module.named_children():
if isinstance(child_module, C2f) and not isinstance(child_module, C2fCIB):
    # Replace C2f with C2f_v2 while preserving its parameters
    shortcut = infer_shortcut(child_module.m[0])
    c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels,
                    child_module.cv2.conv.out_channels,
                    n=len(child_module.m), shortcut=shortcut,
                    g=child_module.m[0].cv2.conv.groups,
                    e=child_module.c / child_module.cv2.conv.out_channels)
    transfer_weights(child_module, c2f_v2)
    setattr(module, name, c2f_v2)
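The replacement recovers the expansion ratio as `e = child_module.c / child_module.cv2.conv.out_channels`. Suppose C2f_v2 computes its hidden width the same way ultralytics' C2f does, `self.c = int(c2 * e)`; that is an assumption about the v2 class. Then the recovered ratio has to round-trip to exactly the original `c`, or transfer_weights would see mismatched shapes. A small arithmetic check:

```python
def hidden_width(c2, e):
    # Assumed to match C2f / C2f_v2: self.c = int(c2 * e)
    return int(c2 * e)


def recovered_ratio(c, c2):
    # What the replacement code passes as e=...
    return c / c2


# Typical widths with the default e = 0.5: the division is exact,
# so the v2 block gets exactly the original hidden width back.
for c2 in (64, 128, 256, 512, 1024):
    c = hidden_width(c2, 0.5)
    assert hidden_width(c2, recovered_ratio(c, c2)) == c

# Float division is not always exactly invertible: 49 * (1 / 49) is
# 0.9999999999999999, so int() truncation yields 0 instead of 1.
print(hidden_width(49, recovered_ratio(1, 49)))  # 0
```

With the default ratio this is harmless. If you ever build the v2 block from an unusual width, it may be worth comparing the new block's hidden width with `child_module.c` before calling transfer_weights. This assumes the v2 class exposes `.c` like C2f does.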
2. C2fCIB ------ C2fCIB_v2
As with C2f, the chunk operation is replaced by two convolutions (cv0 + cv1).
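Why two convolutions can replace one convolution followed by chunk: a 1x1 convolution is a per-pixel matrix-vector product, so splitting its weight rows into two halves reproduces the two chunks exactly. The BatchNorm inside ultralytics' Conv wrapper is per-channel, so its parameters can be split the same way. Presumably, as in the torch-pruning C2f_v2 example, the point is that each half becomes its own layer the pruner can handle directly, without having to track a split operation. Below is a pure-Python sketch for a single pixel; the weights are made up.

```python
def conv1x1(weight, x):
    """1x1 conv at one pixel: weight is out_ch x in_ch, x is the input channel vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def chunk2(y):
    """Equivalent of torch's y.chunk(2, dim=1), restricted to one pixel."""
    half = len(y) // 2
    return y[:half], y[half:]


# Original cv1: 3 input channels -> 4 output channels (2 * c with c = 2).
cv1_weight = [[1, 0, 2], [0, 1, 1], [3, -1, 0], [2, 2, 2]]
x = [1, 2, 3]

# C2f / C2fCIB style: one conv, then chunk its output.
a, b = chunk2(conv1x1(cv1_weight, x))

# v2 style: split the weight rows into two separate convs (cv0, cv1).
c = len(cv1_weight) // 2
cv0_w, cv1_w = cv1_weight[:c], cv1_weight[c:]
a2, b2 = conv1x1(cv0_w, x), conv1x1(cv1_w, x)

print(a, b)  # [7, 5] [1, 12]
assert (a, b) == (a2, b2)
```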
elif isinstance(child_module, C2fCIB):
    # CIB blocks need their own shortcut inference (infer_shortcut2, defined below)
    shortcut = infer_shortcut2(child_module.m[0])
    c2fcib_v2 = C2fCIB_v2(child_module.cv1.conv.in_channels,
                          child_module.cv2.conv.out_channels,
                          n=len(child_module.m), shortcut=shortcut,
                          g=child_module.m[0].cv1[4].conv.groups,
                          e=child_module.c / child_module.cv2.conv.out_channels)
    transfer_weights(child_module, c2fcib_v2)
    setattr(module, name, c2fcib_v2)
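def infer_shortcut2(cib):
    # CIB.cv1 is a 5-layer nn.Sequential: cv1[0] is the first depthwise conv
    # (it sees the block input), cv1[4] the last one (it produces the output).
    c1 = cib.cv1[0].conv.in_channels
    c2 = cib.cv1[4].conv.out_channels
    return c1 == c2 and hasattr(cib, 'add') and cib.add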
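Both shortcut helpers only read channel counts and the `.add` flag, so they can be checked without torch by using stand-in objects. In the sketch below, SimpleNamespace objects replace the real Conv modules, and the channel counts are made up. The list mirrors the layer layout of YOLOv10's CIB.cv1 as we read it: depthwise 3x3, 1x1 expand, depthwise 3x3, 1x1 project, depthwise 3x3.

```python
from types import SimpleNamespace as NS


def infer_shortcut(bottleneck):
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add


def infer_shortcut2(cib):
    c1 = cib.cv1[0].conv.in_channels
    c2 = cib.cv1[4].conv.out_channels
    return c1 == c2 and hasattr(cib, 'add') and cib.add


def conv(cin, cout):
    # Stand-in for ultralytics' Conv wrapper, which keeps its nn.Conv2d in .conv
    return NS(conv=NS(in_channels=cin, out_channels=cout))


# Bottlenecks (64 -> 64 channels) with and without a residual connection.
bottleneck = NS(cv1=conv(64, 64), cv2=conv(64, 64), add=True)
plain = NS(cv1=conv(64, 64), cv2=conv(64, 64), add=False)

# CIB with c = 64: the list mirrors the five layers of CIB.cv1.
cib = NS(cv1=[conv(64, 64),    # [0] depthwise 3x3 on the input
              conv(64, 128),   # [1] 1x1 expand to 2c
              conv(128, 128),  # [2] depthwise 3x3 (or RepVGGDW)
              conv(128, 64),   # [3] 1x1 project back to c
              conv(64, 64)],   # [4] depthwise 3x3 on the output
         add=True)

print(infer_shortcut(bottleneck), infer_shortcut(plain), infer_shortcut2(cib))
# True False True
```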
3. G2L_CRM ------ G2L_CRM_v2 (G2L_CRM is a module in DocLayout-YOLO)
https://github.com/opendatalab/DocLayout-YOLO/tree/5fbc138175099c98a42381d0753ff50ebc5fe494
Its network structure is as follows:
G2L_CRM_v2:
elif isinstance(child_module, G2L_CRM):
    # count_g2lcrm records which G2L_CRM occurrence is being replaced. It must be
    # initialised to 0 outside the recursion (declare it global/nonlocal here).
    count_g2lcrm += 1
    if count_g2lcrm == 3:
        # The third occurrence is handled with infer_shortcut2, i.e. its inner
        # blocks expose cv1 as an indexable Sequential like CIB
        shortcut = infer_shortcut2(child_module.m[0])
        block_k = 3
        fuse = "sum"
    else:
        shortcut = infer_shortcut(child_module.m[0])
        block_k = child_module.m[0].dilated_block.dcv.conv.kernel_size[0]
        fuse = ("glu" if hasattr(child_module.m[0].dilated_block, 'conv_gating')
                else "sum")
    use_dilated = isinstance(child_module.m[0], DilatedBottleneck)
    g2l_crm_v2 = G2L_CRM_v2(child_module.cv1.conv.in_channels,
                            child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            use_dilated=use_dilated,
                            dilation=[1, 3, 5] if count_g2lcrm == 2 else [1, 2, 3],
                            block_k=block_k, fuse=fuse)
    transfer_weights(child_module, g2l_crm_v2)
    setattr(module, name, g2l_crm_v2)
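The G2L_CRM branch relies on count_g2lcrm, so its value depends on the order in which the recursive replacement visits modules. nn.Module.named_children() yields children in registration order, so the walk is depth-first in definition order. Python treats any name assigned inside a function as local. The counter therefore has to live outside the recursive function and be reset for every model. It can be a `global`, or, as in this sketch, a `nonlocal` in a closure. `Node`, `number_g2lcrm` and the tree below are hypothetical stand-ins, not DocLayout-YOLO code.

```python
class Node:
    """Hypothetical stand-in mimicking nn.Module.named_children()."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def named_children(self):
        return [(child.name, child) for child in self.children]


class G2L_CRM(Node):
    """Empty marker class, not the real DocLayout-YOLO module."""


def number_g2lcrm(root):
    """Return (occurrence index, name) for every G2L_CRM, in visit order."""
    count_g2lcrm = 0  # one counter per model, shared across recursion levels
    found = []

    def visit(module):
        nonlocal count_g2lcrm
        for name, child in module.named_children():
            if isinstance(child, G2L_CRM):
                count_g2lcrm += 1
                found.append((count_g2lcrm, name))
            else:
                visit(child)  # only non-matching modules are descended into

    visit(root)
    return found


model = Node("model", [
    Node("backbone", [G2L_CRM("stage2"), G2L_CRM("stage3")]),
    Node("neck", [G2L_CRM("fuse")]),
])
print(number_g2lcrm(model))  # [(1, 'stage2'), (2, 'stage3'), (3, 'fuse')]
```

Calling the function again gives the same numbering, because the counter starts from 0 on every call. A module-level global would instead keep counting across models unless it is reset by hand.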
4. Create a new .py file to hold C2f_v2, C2fCIB_v2 and G2L_CRM_v2
Why put them in a separate .py file instead of defining them in the current pruning script?
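A likely explanation, based on how pickle works rather than on the full traceback (which is cut off here): torch.load unpickles the checkpoint, and ultralytics-style checkpoints store the whole model object. Pickle stores classes by reference, as module name plus qualified name, not as code. Classes defined in a script that is run directly are recorded under `__main__`. Any other entry point (a console, a later training or export script) then cannot resolve them. The sketch uses only the standard library, and `fake_prune_script` is a hypothetical module standing in for the pruning script.

```python
import pickle
import sys
import types

# Hypothetical module standing in for "the pruning script"; the class is a stub.
script = types.ModuleType("fake_prune_script")
script.C2f_v2 = type("C2f_v2", (), {"__module__": "fake_prune_script"})
sys.modules["fake_prune_script"] = script

# pickle records the class as module name + qualified name, not its code.
payload = pickle.dumps(script.C2f_v2())
assert b"fake_prune_script" in payload and b"C2f_v2" in payload

# While that module defines the class, loading works.
obj = pickle.loads(payload)
print(type(obj).__module__, type(obj).__qualname__)  # fake_prune_script C2f_v2

# Simulate loading where the recorded module no longer defines the class,
# as happens when __main__ is a different script or console.
del script.C2f_v2
try:
    pickle.loads(payload)
except AttributeError as err:
    print("load failed:", err)
```

Putting the classes in an importable file means the recorded module name points to that file. Any process that can import it can then load the checkpoint.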
If they are defined in the pruning script itself, loading the saved checkpoint fails with the following error:
f=WindowsPath('runs/detect/prune/2025-04-02_16-24-13/step_0_finetune/weights/best.pt')
torch.load(f, map_location=torch.device("cpu"))
PyDev console: starting.
Traceback (most recent call last):
File "C:\Program Files\JetBrains\PyCharm 2024.1.7\plugins\python\helpers-pro\pydevd_asyncio\pydevd_asyncio_utils.py", line 117, in _exec_async_code
result = func()
File "<input>", line 1, in <module>
File "D:\Anaconda\envs\cpupytorch\lib\site-packages\torch\serialization.py", line 1360, in load
return _load(
File "D:\Anaconda\envs\cpupytorch\lib\site-packages\torch\serialization.py", line 1848, in _load
result = unpickler.load()
File "D:\A