代码: https://github.com/tensorflow/models/tree/master/research/deeplab
train.py主要函数及注释如下
main()
# 配置GPU
config = slim.deployment.model_deploy.DeploymentConfig(xxx) # Create a DeploymentConfig for multi-gpu
# 获取slim数据集实例
dataset = deeplab.datasets.segmentation_dataset.get_dataset(xxx) # Gets an instance of slim dataset
# 得到数据
samples = input_generator.get(dataset, xxx)
# Creates a queue to prefetch tensors from `tensors`
inputs_queue = prefetch_queue.prefetch_queue(samples, capacity=128 * config.num_clones)
# 创建clone:调用_build_deeplab构建模型,并记录对应的scope和device
clones = Clone(_build_deeplab(inputs_queue, xxx), scope, device)
learning_rate = train_utils.get_model_learning_rate(xxx)
slim.learning.train(xxx)
deeplab.datasets.segmentation_dataset.get_dataset(dataset_name, split_name, dataset_dir):
# 将example反序列化成存储之前的格式。由tf完成
keys_to_features
# 将反序列化的数据组装成更高级的格式。由slim完成
items_to_handlers
# 解码器:根据keys_to_features和items_to_handlers对TFExample进行解码
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
return dataset.Dataset(xxx)
deeplab.utils.input_generator.get(dataset, xxx)
# provider对象根据dataset信息读取数据
data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset, xxx)
# 获取数据:此处得到的是单个样本,还需要先做预处理,再组合成batch
image, height, width = data_provider.get([common.IMAGE, common.HEIGHT, common.WIDTH])
original_image, image, label = input_preprocess.preprocess_image_and_label(xxx)
return tf.train.batch(xxx)
_build_deeplab(inputs_queue, outputs_to_num_classes, ignore_label)
# 获取数据
samples = inputs_queue.dequeue()
model_options = common.ModelOptions