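In this post we use VGG16 as the base model and fine-tune it so that it learns to classify our own data. The task is to classify satellite images of the Earth's surface into six categories: forest, water, rock, farmland, glacier, and urban areas. The dataset has been uploaded to 优快云: https://download.youkuaiyun.com/download/viafcccy/11791071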
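I. Dataset
The conversion script below relies on Python's command-line argument parsing. These two posts cover the background:
1. Understanding from __future__ import absolute_import: https://blog.youkuaiyun.com/viafcccy/article/details/101061413
2. The argparse library: https://blog.youkuaiyun.com/viafcccy/article/details/101061661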
# coding:utf-8
# data_convert.py: creates label.txt if needed, then calls src.tfrecord.main()
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main


def parse_args():
    parser = argparse.ArgumentParser()
    # Root data directory; it must contain train/ and validation/ sub-directories
    parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
    parser.add_argument('--train-shards', default=2, type=int)
    parser.add_argument('--validation-shards', default=2, type=int)
    parser.add_argument('--num-threads', default=2, type=int)
    parser.add_argument('--dataset-name', default='satellite', type=str)
    return parser.parse_args()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    # Derive all paths from the root data directory
    args.tensorflow_dir = args.tensorflow_data_dir
    args.train_directory = os.path.join(args.tensorflow_dir, 'train')
    args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
    args.output_directory = args.tensorflow_dir
    args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
    # If label.txt is missing, build it from the sub-directory names of train/,
    # writing one class name per line
    if not os.path.exists(args.labels_file):
        logging.warning('Can\'t find label.txt. Now create it.')
        all_entries = os.listdir(args.train_directory)
        dirnames = []
        for entry in all_entries:
            if os.path.isdir(os.path.join(args.train_directory, entry)):
                dirnames.append(entry)
        with open(args.labels_file, 'w') as f:
            for dirname in dirnames:
                f.write(dirname + '\n')
    main(args)  # pass args to main() in src.tfrecord, which runs the conversion
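The label.txt branch above can be tried on its own. This sketch wraps the same logic in a hypothetical helper, write_labels_file(). It runs the helper on a temporary directory that mimics the layout the script assumes (pic/train/<class name>/...).

Some details of the demo are assumptions:

- The folder names water, forest and glacier are made up; the real dataset's folder names may differ.
- Unlike the original, the sketch sorts the directory names. os.listdir() returns entries in arbitrary order, and a fixed order keeps the line order of label.txt (and presumably the label indices derived from it) stable across machines.

```python
import os
import tempfile


def write_labels_file(train_directory, labels_file):
    """Write one class name per line, taken from the sub-directories of train_directory."""
    dirnames = [entry for entry in os.listdir(train_directory)
                if os.path.isdir(os.path.join(train_directory, entry))]
    dirnames.sort()  # assumption: sorted so the order does not depend on the file system
    with open(labels_file, 'w') as f:
        for dirname in dirnames:
            f.write(dirname + '\n')
    return dirnames


# Throw-away copy of the assumed layout: <root>/train/<class name>/
root = tempfile.mkdtemp()
train_dir = os.path.join(root, 'train')
for name in ['water', 'forest', 'glacier']:  # hypothetical class folder names
    os.makedirs(os.path.join(train_dir, name))
# A stray file next to the class folders is skipped by the isdir() check
open(os.path.join(train_dir, 'readme.txt'), 'w').close()

labels_path = os.path.join(root, 'label.txt')
write_labels_file(train_dir, labels_path)
with open(labels_path) as f:
    print(f.read().split())  # ['forest', 'glacier', 'water']
```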
Open a terminal, change to the project directory (the one containing data_convert.py), and run:
python data_convert.py -t