pytorch在19年11月份的时候合入了这部分剪枝的代码。pytorch提供一些直接可用的api,用户只需要传入需要剪枝的module实例和需要剪枝的参数名字,系统自动帮助完成剪枝操作,看起来接口挺简单。比如 def random_structured(module, name, amount, dim)
pytorch支持的几种类型的剪枝策略:

详细分析
pytorch提供了一个剪枝的抽象基类‘‘class BasePruningMethod(ABC)’,所有剪枝策略都需要继承该基类,并重载部分函数就可以了
一般情况下需要重载init和compute_mask方法,call, apply_mask, apply, prune和remove不需要重载,例如官方提供的RandomUnstructured剪枝方法

基类实现的6个方法:

剪枝的API接口,可以看到支持用户自定义的剪枝mask,接口为custom_from_mask

本文介绍了PyTorch中模型剪枝的实现,包括官方支持的API、剪枝策略及其工作原理。详细讲解了如何使用BasePruningMethod抽象基类创建自定义剪枝方法,并探讨了剪枝过程,包括前向回调函数、参数处理和mask计算。最后,提到了剪枝结果的持久化以及如何编写自己的剪枝策略。
最低0.47元/天 解锁文章
1502





