前言
引用参考了以下网站:
https://github.com/tensorflow/models/tree/master/research/slim
https://blog.youkuaiyun.com/stesha_chen/article/details/81976415
https://blog.youkuaiyun.com/rookie_wei/article/details/80796009
"""
contrib模块在tensorflow2.0将会移除,slim模块位于contrib中
tensorflow slim模块适用于图像分类,使用的数据格式为tfrecord
"""
备注:如果使用docker方式安装tensorflow的话
docker run --runtime=nvidia --name=yuankun_tfslim -it -p 15003:8000 -p 15004:22 -p 15005:5000 -p 15006:6006 -v /home/yk:/yk tensorflow/tensorflow:latest-gpu
注意映射tensorboard 6006端口到外部。
1.在某些版本中tensorflow不存在TF-slim模块,执行下列语句确保slim存在
如果存在则不会报错(新版本中有slim模块,当前使用1.13.1)
python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"
2.安装TF models模块(该模块tensorflow并不自带,需自行安装)
cd $HOME/workspace
git clone https://github.com/tensorflow/models/ #安装tensorflow包根目录
运行以下命令,不报错则安装成功
cd $HOME/workspace/models/research/slim
python -c "from nets import cifarnet; mynet = cifarnet.cifarnet"
3.下载flower数据集转换为tfrecord格式
DATA_DIR=/tmp/data/flowers
python download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir="${DATA_DIR}"
4.开始训练
DATASET_DIR=/tmp/data/flowers
TRAIN_DIR=/tmp/train_logs
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--model_name=inception_resnet_v2 \
--max_number_of_steps=500 \
--batch_size=32 \
--learning_rate=0.0001 \
--learning_rate_decay_type=fixed \
--save_interval_secs=60 \
--save_summaries_secs=60 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004
备注:#model_name:定义所使用的模型
#model_name:定义所使用的模型
可供选择的model:
inception_resnet_v2
inception_v1
inception_v2
inception_v3
inception_v4
vgg_16
vgg_19
resnet_v1_50,resnet_v1_101,resnet_v1_152,resnet_v1_200
resnet_v2_50,resnet_v2_101,resnet_v2_152,resnet_v2_200
等等,具体参考https://github.com/tensorflow/models/tree/master/research/slim/nets
5.查看tensorboard日志
tensorboard --logdir=/tmp/train_logs
6.微调模型的方法
使用inception_resnet_v2模型机型微调
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
MODEL_NAME=inception_resnet_v2
TRAIN_DIR=/tmp/flowers-models/${MODEL_NAME}
DATASET_DIR=/tmp/data/flowers
mkdir ${PRETRAINED_CHECKPOINT_DIR}
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar -xvf inception_resnet_v2_2016_08_30.tar.gz
#下载预训练模型
mv inception_resnet_v2.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--model_name=${MODEL_NAME} \
--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \
--checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
--trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
--max_number_of_steps=1000 \
--batch_size=32 \
--learning_rate=0.01 \
--learning_rate_decay_type=fixed \
--save_interval_secs=60 \
--save_summaries_secs=60 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004
#微调说明:
--checkpoint_exclude_scopes # 第一次不加载这些参数
--trainable_scopes # 重新训练这部分参数