工具介绍
最近在尝试了一些pytorch模型转caffe模型的方法,推荐一个特别好用的工具,github上的开源项目
https://github.com/xxradon/PytorchToCaffe.git
本项目最大的优点是可以支持pytorch各最新版本,对于0.3, 0.3.1, 0.4, 0.4.1 ,1.0, 1.2都有支持
同时项目还提供对caffe中prototxt每层形状和ops的分析工具
工具使用
请参考项目
example/
中的各例子
经测试该工具对YOLO v3, Deeplab v3+以及Enet等网络都可以做到比较准确的转换
代码分析
项目的主要实现在pytorch_to_caffe.py模块中
模块初始化时会创建Rp类的对象,并用这个对象覆盖pytorch中的层实现,例如卷积层的实现
F.conv2d=Rp(F.conv2d,_conv2d)
在工具使用中会调用pytorch网络的forward()方法,此时在调用到F.conv2d层是就会调用刚才覆盖的Rp(F.conv2d,_conv2d)这个对象中的__call__方法,并在此方法中调用_conv2d
_conv2d是工具内部定义的方法,作用是计算pytorch中的conv,将该层的名字以及计算得到的blob加入到之前创建的Translog中,并创建caffe中的conv实现,将pytorch中的相关权重写入caffe层中
自定义层
由之前的代码分析可知,如果需要用该工具将某个pytorch层转化为caffe层,只需要首先定义一个转换层函数如_conv2d, 用Rp类的对象覆盖原pytorch中的实现,调用pytorch的forward方法后新创建的caffe网络和权

最低0.47元/天 解锁文章
404

被折叠的 条评论
为什么被折叠?



