TensorFlow Object Detection API 源码分析之 model_lib.py

# Called from model_main.py: this module builds the model and wires together
# the other components. Ultimately, create_train_and_eval_specs returns
# train_spec and eval_spec (tf.estimator specs).
r"""Constructs model, inputs, and training environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# functools (higher-order function utilities): only functools.partial is used,
# e.g. detection_model_fn = functools.partial(
#          model_builder.build, model_config=model_config)
import functools
import os

import tensorflow as tf

from object_detection import eval_util
from object_detection import inputs
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import shape_utils
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils

# Name -> helper lookup table used when building the model and its inputs.
# The entries come from two modules: config_util (pipeline-config helpers)
# and inputs (input_fn factories).
# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = dict(
    # Pipeline-config helpers.
    get_configs_from_pipeline_file=(
        config_util.get_configs_from_pipeline_file),
    create_pipeline_proto_from_configs=(
        config_util.create_pipeline_proto_from_configs),
    merge_external_params_with_configs=(
        config_util.merge_external_params_with_configs),
    # Input-function factories.
    create_train_input_fn=inputs.create_train_input_fn,
    create_eval_input_fn=inputs.create_eval_input_fn,
    create_predict_input_fn=inputs.create_predict_input_fn,
)

# Pulls the groundtruth (i.e. the labels) out of `detection_model` and packs it
# into a dict keyed by `fields.InputDataFields` for evaluation.
def _prepare_groundtruth_for_eval(detection_model, class_agnostic):
  """Extracts groundtruth data from detection_model and prepares it for eval.

  Only element `[0]` of each groundtruth list is read. Presumably this is the
  first (or only) image of the batch; confirm against the callers.

  Args:
    detection_model: A `DetectionModel` object.
    class_agnostic: Whether the detections are class agnostic. When True, the
      stored groundtruth classes are ignored and every box gets class 1.

  Returns:
    groundtruth: A dictionary (not a tuple) keyed by `fields.InputDataFields`
      attributes:
      groundtruth_boxes: [num_boxes, 4] float32 tensor of boxes, in
        normalized coordinates.
      groundtruth_classes: [num_boxes] int64 tensor of 1-indexed classes.
      groundtruth_instance_masks: instance masks tensor; present only if the
        model's groundtruth has the `masks` field.
      groundtruth_is_crowd: [num_boxes] bool tensor indicating is_crowd
        annotations; present only if the model's groundtruth has the
        `is_crowd` field.
  """
  input_data_fields = fields.InputDataFields()
  groundtruth_boxes = detection_model.groundtruth_lists(
      fields.BoxListFields.boxes)[0]
  # For class-agnostic models, groundtruth one-hot encodings collapse to all
  # ones: a single [num_boxes, 1] column, so the argmax below is 0 for every
  # box and every box ends up as class 1.
  if class_agnostic:
    groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
    groundtruth_classes_one_hot = tf.ones([groundtruth_boxes_shape[0], 1])
  else:
    groundtruth_classes_one_hot = detection_model.groundtruth_lists(
        fields.BoxListFields.classes)[0]
  # argmax yields 0-indexed class ids; shift them so the output is 1-indexed.
  label_id_offset = 1  # Applying label id offset (b/63711816)
  groundtruth_classes = (
      tf.argmax(groundtruth_classes_one_hot, axis=1) + label_id_offset)
  groundtruth = {
      input_data_fields.groundtruth_boxes: groundtruth_boxes,
      input_data_fields.groundtruth_classes: groundtruth_classes
  }
  # Optional fields are copied over only when the model carries them.
  if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
    groundtruth[input_data_fields.groundtruth_instance_masks] = (
        detection_model.groundtruth_lists(fields.BoxListFields.masks)[0])
  if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
    groundtruth[input_data_fields.groundtruth_is_crowd] = (
        detection_model.groundtruth_lists(fields.BoxListFields.is_crowd)[0])
  return groundtruth


def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
  """Unstacks all tensors in `tensor_dict` along 0th dimension.

  Unstacks tensor from the tensor dict along 0th dimension and returns a
  tensor_dict containing values that are lists of unstacked, unpadded tensors.

  Tensors in the `tensor_dict` are expected to be of one of the three shapes:
  1. [batch_size]
  2. [batch_size, height, width, channels]
  3. [batch_size, num_boxes, d1, d2, ... dn]

  When unpad_groundtruth_tensors is set to true, unstacked tensors of form 3
  above are sliced along the `num_boxes` dimension using the value in tensor
  field.InputDataFields.num_groundtruth_boxes.

  Note that this function has a static list of input data fields and has to be
  kept in sync with the InputDataFields defined in core/standard_fields.py

  Args:
    tensor_dict: A dictionary of batched groundtruth tensors.
    unpad_groundtruth_tensors: Whether to remove padding along `num_boxes`
      dimension of the groundtruth tensors.

  Returns:
    A dictionary where the keys are from fields.InputDataFields and values are
    a list of unstacked (optionally unpadded) tensors.

  Raises:
    ValueError: If unpad_tensors is True and `tensor_dict` does not contain
      `num_groundtruth_boxes` tensor.
  """
  unbatched_tensor_dict = {key: tf.unstack(tensor)
                           for key, tensor in tensor_dict.items()}
  if unpad_groundtruth_tensors:
    if (fields.InputDataFields.num_groundtruth_boxes not in
        unbatched_tensor_dict):
      raise ValueError('`num_groundtruth_boxes` not found in tensor_dict. '
                       'Keys available: {}'.format(
                           unbatched_tensor_dict.keys()))
    unbatched_unpadded_tensor_dict = {}
    unpad_keys = set([
        # List of input data fields that are padded a
root@ubuntu:~# python -c "from datasets import exceptions; print('成功导入')" 成功导入 root@ubuntu:~# python3 image_to_text.py Traceback (most recent call last): File "/root/image_to_text.py", line 1, in <module> from modelscope.pipelines import pipeline File "/usr/local/lib/python3.10/dist-packages/modelscope/pipelines/__init__.py", line 4, in <module> from .base import Pipeline File "/usr/local/lib/python3.10/dist-packages/modelscope/pipelines/base.py", line 15, in <module> from modelscope.models.base import Model File "/usr/local/lib/python3.10/dist-packages/modelscope/models/__init__.py", line 18, in <module> fix_transformers_upgrade() File "/usr/local/lib/python3.10/dist-packages/modelscope/utils/automodel_utils.py", line 48, in fix_transformers_upgrade from transformers import PreTrainedModel File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 2154, in __getattr__ module = self._get_module(self._class_to_module[name]) File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 2184, in _get_module raise e File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 2182, in _get_module return importlib.import_module("." 
+ module_name, self.__name__) File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module return _bootstrap._gcd_import(name[level:], package, level) File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 73, in <module> from .loss.loss_utils import LOSS_MAPPING File "/usr/local/lib/python3.10/dist-packages/transformers/loss/loss_utils.py", line 21, in <module> from .loss_d_fine import DFineForObjectDetectionLoss File "/usr/local/lib/python3.10/dist-packages/transformers/loss/loss_d_fine.py", line 21, in <module> from .loss_for_object_detection import ( File "/usr/local/lib/python3.10/dist-packages/transformers/loss/loss_for_object_detection.py", line 32, in <module> from transformers.image_transforms import center_to_corners_format File "/usr/local/lib/python3.10/dist-packages/transformers/image_transforms.py", line 22, in <module> from .image_utils import ( File "/usr/local/lib/python3.10/dist-packages/transformers/image_utils.py", line 59, in <module> from torchvision.transforms import InterpolationMode File "/usr/local/lib/python3.10/dist-packages/torchvision/__init__.py", line 10, in <module> from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip ImportError: cannot import name 'datasets' from partially initialized module 'torchvision' (most likely due to a circular import) (/usr/local/lib/python3.10/dist-packages/torchvision/__init__.py)
最新发布
07-13
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值