本节的内容主要是将数处理成tfrecord的格式,然后送进网络进行训练。
准备数据
首先使用代码data_convert.py将图片转化为tfrecord的格式
代码如下:我们需要将数据放在同一目录下的pic文件里面,包含训练集,验证集,文件结构如下:
# coding:utf-8
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
parser.add_argument('--train-shards', default=2, type=int)
parser.add_argument('--validation-shards', default=2, type=int)
parser.add_argument('--num-threads', default=2, type=int)
parser.add_argument('--dataset-name', default='satellite', type=str)
return parser.parse_args()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
args = parse_args()
args.tensorflow_dir = args.tensorflow_data_dir
args.train_directory = os.path.join(args.tensorflow_dir, 'train')
args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
args.output_directory = args.tensorflow_dir
args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
if os.path.exists(args.labels_file) is False:
logging.warning('Can\'t find label.txt. Now create it.')
all_entries = os.listdir(args.train_directory)
dirnames = []
for entry in all_entries:
if os.path.isdir(os.path.join(args.train_directory, entry)):
dirnames.append(entry)
with open(args.labels_file, 'w') as f:
for dirname in dirnames:
f.write(dirname + '\n')
main(args)
然后运行一下代码:
python data_convert.py -t pic/ \
--train-shards 2 \#将数据集分成两块
--validation-shards 2 \
--num-threads 2 \#采用两个线程产生数据
--dataset-name satellite#给生成的数据集取一个名字
下面进行训练准备
如果需要使用 Slim 微调模型,首先要下载 Slim 的源代码 。 Slim 的源代码保存在 tensorflow/models 项目中,可以使用下面的 git 命 令下载tensorflow/models ·
,git clone https://github.corn/tensorflow/models.git找到 models/research/目录中的 slim 文件夹 , 这就是要用到的 TensorF lowSlim 的源代码 。
- 定义新的dataset文件
在slim/dataset文件夹下面创建一个文件satellite.py,将flower.py文件的内容复制到satellite.py 文件里面,接下来修改代码
第一处是 FILE PATTERN 、 SPLITS_TO_ SIZES 、 NUM CLASSES , 将
真进行以下修改:
修改完 satellite.py 后,还需要在同目录的 dataset_factory. py 文件中注册satellite 数据库
。
- 定义完数据;集后,在 slim 文件夹下再新建一个 satellite 目录,在这个目
录中,完成最后的几项准备工作:
- inceptionv3的文件直接在Linux里面运行程序就可以了
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz#下载模型
tar -xvf inception_v3_2016_08_28.tar.gz #解压模型
- 开始训练
python train_image_classifier.py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--dataset_dir=satellite/data \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--max_number_of_steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed \
--save_interval_secs=300 \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004 \
我前面弄了很久都不对,都是这个命令写错了,后来复制了一片博客的就好很多,所以复制我的没有错
后记:
这个是slim的通用框架,以后自己加什么数据都可以拿来用的,是可以训练自己的数据集的,学知识都是从一个不熟悉到熟悉的过程,所以我们熟能生巧,多加练习吧~