pytorch模型剪枝学习笔记

本文介绍了PyTorch中模型剪枝的实现,包括官方支持的API、剪枝策略及其工作原理。详细讲解了如何使用BasePruningMethod抽象基类创建自定义剪枝方法,并探讨了剪枝过程,包括前向回调函数、参数处理和mask计算。最后,提到了剪枝结果的持久化以及如何编写自己的剪枝策略。

pytorch代码仓库

pytorch在19年11月份的时候合入了这部分剪枝的代码。pytorch提供一些直接可用的api,用户只需要传入需要剪枝的module实例和需要剪枝的参数名字,系统自动帮助完成剪枝操作,看起来接口挺简单。比如 def random_structured(module, name, amount, dim)

pytorch支持的几种类型的剪枝策略:

详细分析

  • pytorch提供了一个剪枝的抽象基类‘‘class BasePruningMethod(ABC)’,所有剪枝策略都需要继承该基类,并重载部分函数就可以了

  • 一般情况下需要重载init和compute_mask方法,call, apply_mask, apply, prune和remove不需要重载,例如官方提供的RandomUnstructured剪枝方法 file

  • 基类实现的6个方法: file

  • 剪枝的API接口,可以看到支持用户自定义的剪枝mask,接口为custom_from_mask

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值