# TensorFlow Object Detection API source walkthrough: model_lib.py
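# Called from model_main.py. This is the key module that builds the model and
# combines the functionality of the individual modules.
# Ultimately, create_train_and_eval_specs returns train_spec and eval_spec
# (tf.estimator).
r"""Constructs model, inputs, and training environment."""
from __future__ import absolute_import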
from __future__ import division
from __future__ import print_function
# Python's higher-order function package; only functools.partial is used here:
# detection_model_fn = functools.partial(
#     model_builder.build, model_config=model_config)
import functools
import os
import tensorflow as tf
from object_detection import eval_util
from object_detection import inputs
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import shape_utils
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils
# Is this only to avoid spelling out the module names (config_util, inputs)?
# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = {
    # Entries backed by `config_util`.
    'get_configs_from_pipeline_file': (
        config_util.get_configs_from_pipeline_file),
    'create_pipeline_proto_from_configs': (
        config_util.create_pipeline_proto_from_configs),
    'merge_external_params_with_configs': (
        config_util.merge_external_params_with_configs),
    # Entries backed by `inputs`.
    'create_train_input_fn': inputs.create_train_input_fn,
    'create_eval_input_fn': inputs.create_eval_input_fn,
    'create_predict_input_fn': inputs.create_predict_input_fn,
}
# As the name suggests, "prepare groundtruth": pulls the groundtruth data
# (i.e. the labels) out of detection_model.
def _prepare_groundtruth_for_eval(detection_model, class_agnostic):
    """Extracts groundtruth data from detection_model and prepares it for eval.

    Only the first entry (index 0) of each groundtruth list held by
    `detection_model` is read.

    Args:
        detection_model: A `DetectionModel` object.
        class_agnostic: Whether the detections are class_agnostic.

    Returns:
        groundtruth: Dictionary keyed by `fields.InputDataFields` attributes:
            groundtruth_boxes: [num_boxes, 4] float32 tensor of boxes, in
                normalized coordinates.
            groundtruth_classes: [num_boxes] int64 tensor of 1-indexed
                classes.
            groundtruth_instance_masks: 3D float32 tensor of instance masks
                (only if masks are provided in groundtruth).
            groundtruth_is_crowd: [num_boxes] bool tensor indicating is_crowd
                annotations (only if provided in groundtruth).
    """
    input_data_fields = fields.InputDataFields()
    groundtruth_boxes = detection_model.groundtruth_lists(
        fields.BoxListFields.boxes)[0]

    # For class-agnostic models, groundtruth one-hot encodings collapse to all
    # ones: a single column, so every box ends up with the same class id.
    if class_agnostic:
        groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
        groundtruth_classes_one_hot = tf.ones([groundtruth_boxes_shape[0], 1])
    else:
        groundtruth_classes_one_hot = detection_model.groundtruth_lists(
            fields.BoxListFields.classes)[0]

    label_id_offset = 1  # Applying label id offset (b/63711816)
    # argmax yields 0-based indices; the offset makes the class ids 1-indexed.
    groundtruth_classes = (
        tf.argmax(groundtruth_classes_one_hot, axis=1) + label_id_offset)

    groundtruth = {
        input_data_fields.groundtruth_boxes: groundtruth_boxes,
        input_data_fields.groundtruth_classes: groundtruth_classes
    }

    # Optional fields are copied over only when the model actually has them.
    if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
        groundtruth[input_data_fields.groundtruth_instance_masks] = (
            detection_model.groundtruth_lists(fields.BoxListFields.masks)[0])
    if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
        groundtruth[input_data_fields.groundtruth_is_crowd] = (
            detection_model.groundtruth_lists(
                fields.BoxListFields.is_crowd)[0])
    return groundtruth
defunstack_batch(tensor_dict, unpad_groundtruth_tensors=True):"""Unstacks all tensors in `tensor_dict` along 0th dimension.
Unstacks tensor from the tensor dict along 0th dimension and returns a
tensor_dict containing values that are lists of unstacked, unpadded tensors.
Tensors in the `tensor_dict` are expected to be of one of the three shapes:
1. [batch_size]
2. [batch_size, height, width, channels]
3. [batch_size, num_boxes, d1, d2, ... dn]
When unpad_groundtruth_tensors is set to true, unstacked tensors of form 3
above are sliced along the `num_boxes` dimension using the value in tensor
field.InputDataFields.num_groundtruth_boxes.
Note that this function has a static list of input data fields and has to be
kept in sync with the InputDataFields defined in core/standard_fields.py
Args:
tensor_dict: A dictionary of batched groundtruth tensors.
unpad_groundtruth_tensors: Whether to remove padding along `num_boxes`
dimension of the groundtruth tensors.
Returns:
A dictionary where the keys are from fields.InputDataFields and values are
a list of unstacked (optionally unpadded) tensors.
Raises:
ValueError: If unpad_tensors is True and `tensor_dict` does not contain
`num_groundtruth_boxes` tensor.
"""
unbatched_tensor_dict = {key: tf.unstack(tensor)
for key, tensor in tensor_dict.items()}
if unpad_groundtruth_tensors:
if (fields.InputDataFields.num_groundtruth_boxes notin
unbatched_tensor_dict):
raise ValueError('`num_groundtruth_boxes` not found in tensor_dict. ''Keys available: {}'.format(
unbatched_tensor_dict.keys()))
unbatched_unpadded_tensor_dict = {}
unpad_keys = set([
# List of input data fields that are padded a