文章目录
下面代码不包括DDPG强化学习参数优化器和Distill蒸馏训练
PocketFlow框架安装
conda create --name PocketFlow python=3.6
source activate PocketFlow
pip install tensorflow-gpu=1.10.0
pip install numpy=1.14.5
conda install panda
conda install scikit-learn
cifar10数据集准备
cifar-10:使用binary版本
wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
下载完成后,解压到data目录
并在path.conf中设置路径:data_dir_local_cifar10 = /home/mars/hewu/tensorflow/PocketFlow/data/cifar-10-batches-bin
执行代码
./scripts/run_local.sh nets/resnet_at_cifar10_run.py
官方教程需要在工程根目录配置path.conf文件,然后执行上述脚本。个人觉得不太方便调试,直接启动py,单步跟踪,更加便于理解程序逻辑。
nets/resnet_at_cifar10_run.py --model_http_url https://api.ai.tencent.com/pocketflow --data_dir_local /home/mars/hewu/tensorflow/PocketFlow/data/cifar-10-batches-bin --learner channel --cp_prune_option uniform
程序入口-main.py
#/home/mars/hewu/tensorflow/PocketFlow/main.py
from nets.resnet_at_cifar10 import ModelHelper
from learners.learner_utils import create_learner
#1创建模型helper和learner
model_helper = ModelHelper()#网络和数据集的类
learner = create_learner(sw_writer,model_helper)#跳转到不同的压缩算法learner
#2进入训练,或者评估
learner.train()
learner.evaluate()
通道裁剪学习器-channel_pruning/learner.py
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/learner.py
from learners.distillation_helper import DistillationHelper #蒸馏相关
from learners.abstract_learner import AbstractLearner
from learners.channel_pruning.model_wrapper import Model #模型相关
from learners.channel_pruning.channel_pruner import ChannelPruner #裁剪相关
from rl_agents.ddpg.agent import Agent as DdpgAgent #强化学习代理DDPG
#继承自AbstractLearner
class ChannelPrunedLearner(AbstractLearner):
#继承初始化

本文详细解析了PocketFlow框架中的通道裁剪学习器和模型封装器,通过cifar10数据集展示了通道裁剪的过程,探讨了深度学习模型压缩技术,特别指出在处理depthwise卷积时的策略,并提到最后的卷积层不进行裁剪。
最低0.47元/天 解锁文章





