前言:知识储备
-
- 首先, 选择剪枝的颗粒度:规律 or 不规则
- 然后, 选择在哪里剪枝:权重 or 结构
- 其次,选择剪枝程度:计算量减少5倍?
-
yolov7代码中的
train.py
,test.py
要了解因为我们剪枝进行 finetune 的时候需要
train()
这个函数,prune 的时候需要test()
这个函数
1. 剪枝流程
剪枝是循序渐进的过程,有step的一点一点的剪枝,而不是上来就剪掉50%。
剪枝 -> finetune -> 剪枝 -> finetune … 直到 满足你的需求(剪到模型足够小了,计算量足够低了),就可以停下while循环了。
如下图所示:
2. 剪枝工具 Torch-Pruning
大家可以看一下 Torch-Pruning 的作者,对工具底层的解释:Torch-Pruning | 轻松实现结构化剪枝算法
Torch-Pruning的ResNet18 简单示例:
import torch
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)
# 1. 选择合适的重要性评估指标,这里使用权值大小
imp = tp.importance.MagnitudeImportance(p=2)
# 2. 忽略无需剪枝的层,例如最后的分类层(总不能剪完类别都变少了叭?)
ignored_layers = []
for m in model.modules():
if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
ignored_layers.append(m) # DO NOT prune the final classifier!
# 3. 初始化剪枝器
iterative_steps = 5 # 迭代式剪枝,重复5次Pruning-Finetuning的循环完成剪枝。
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs, # 用于分析依赖的伪输入
importance=imp, # 重要性评估指标
iterative_steps=iterative_steps, # 迭代剪枝,设为1则一次性完成剪枝
ch_sparsity=0.5, # 目标稀疏性,这里我们移除50%的通道 ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
ignored_layers=ignored_layers, # 忽略掉最后的分类层
)
# 4. Pruning-F