TF学习之DeepLabv3+代码阅读2(dataset)

本文深入解析DeepLabv3+的data_generator.py模块,阐述了数据集处理流程,包括初始化设置、迭代器生成、文件读取、解析与预处理。适用于理解DeepLabv3+图像分割模型的数据准备过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

DeepLabv3+代码阅读之data_generator.py

一、Dataset

1. __init__

class Dataset(object):
  """Represents input dataset for deeplab model."""

  def __init__(self,
               dataset_name,# Dataset name
               split_name,# A train/val Split name
               dataset_dir,# The directory of the dataset sources
               batch_size,# Batch size
               crop_size,# The size used to crop the image and label
               min_resize_value=None,# Desired size of the smaller image side
               max_resize_value=None,# Maximum allowed size of the larger image side
               resize_factor=None,# Resized dimensions are multiple of factor plus one
               min_scale_factor=1.,# Minimum scale factor value
               max_scale_factor=1.,# Maximum scale factor value
               scale_factor_step_size=0,
               model_variant=None,# Model variant (string) for choosing how to mean-subtract the images.
        						  # See feature_extractor.network_map for supported model variants.
               num_readers=1,# Number of readers for data provider
               is_training=False,# Boolean, if dataset is for training or not
               should_shuffle=False,# Boolean, if should shuffle the input data
               should_repeat=False):# Boolean, if should repeat the input data

    if dataset_name not in _DATASETS_INFORMATION:
      raise ValueError('The specified dataset is not supported yet.')
    self.dataset_name = dataset_name

    splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes# 训练测试集对应的图片数量

    if split_name not in splits_to_sizes:
      raise ValueError('data split name %s not recognized' % split_name)

    if model_variant is None:# 使用的模型,如xception_65
      tf.logging.warning('Please specify a model_variant. See '
                         'feature_extractor.network_map for supported model '
                         'variants.')

    self.split_name = split_name
    self.dataset_dir = dataset_dir
    self.batch_size = batch_size
    self.crop_size = crop_size
    self.min_resize_value = min_resize_value
    self.max_resize_value = max_resize_value
    self.resize_factor = resize_factor
    self.min_scale_factor = min_scale_factor
    self.max_scale_factor = max_scale_factor
    self.scale_factor_step_size = scale_factor_step_size
    self.model_variant = model_variant
    self.num_readers = num_readers
    self.is_training = is_training
    self.should_shuffle = should_shuffle
    self.should_repeat = should_repeat

    self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes# 类别数
    self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label# 忽略的标记

2. get_one_shot_iterator

得到一个迭代数据集一次的迭代器
返回:
	tf.data.Iterator类型迭代器
def get_one_shot_iterator(self):
    """Gets an iterator that iterates across the dataset once.

    Returns:
      An iterator of type tf.data.Iterator.
    """

    files = self._get_all_files()# 得到数据集文件(tfreord)的文件路径列表
	# tf.data.TFRecordDataset, 从一个或多个TFRecord文件生成数据集,map()对数据做自定义的transform
    dataset = (
        tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
        .map(self._parse_function, num_parallel_calls=self.num_readers)# 解析得到图片、标记等信息
        .map(self._preprocess_image, num_parallel_calls=self.num_readers))

    if self.should_shuffle:# 是否需要打乱顺序
      dataset = dataset.shuffle(buffer_size=100)# 从数据集里选100个组成buffer,每次读数据从buffer里随机取一个,
      											# 再从剩下的里边补一个到buffer。最完美的shuffle则是buffer中的个数
      											# 大于等于整个数据集所含的图片个数。

    if self.should_repeat:
      dataset = dataset.repeat()  # Repeat forever for training.
    else:
      dataset = dataset.repeat(1) # 对于测试则只需要重复一遍数据集即可

    dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)
    # .batch()形成batch,.prefetch(),将生成数据的时间和使用数据的时间分离开,减少空闲时间
    return dataset.make_one_shot_iterator()# 创建一个列举数据集中元素的迭代器,最新版本tf中已被废弃

3. _get_all_files

得到所有文件读取的位置
返回:
	输入文件的list
def _get_all_files(self):
  file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(self.dataset_dir,
                              file_pattern % self.split_name)
  return tf.gfile.Glob(file_pattern)# tf.gfile.Glob()返回符合模式的文件路径列表
  # 例如:file_pattern='tensorflow/models/research/deeplab/datasets/pascal_voc_seg/tfrecord/train-*'

其中

_FILE_PATTERN = '%s-*'

4. _parse_function

解析example proto
参数:
	tf.Example格式的Proto
返回:
	字典,包含解析后的图片、标签、尺寸、文件名
备注:
	当前仅支持jpeg和png格式的图片
def _parse_function(self, example_proto):
    # Currently only supports jpeg and png.
    # Need to use this logic because the shape is not known for tf.image.decode_image and we rely on 
    # this info to extend label if necessary.
    def _decode_image(content, channels):
      return tf.cond(
          tf.image.is_jpeg(content),
          lambda: tf.image.decode_jpeg(content, channels),
          lambda: tf.image.decode_png(content, channels))

    features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/segmentation/class/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/format':
            tf.FixedLenFeature((), tf.string, default_value='png'),
    }

    parsed_features = tf.parse_single_example(example_proto, features)

    image = _decode_image(parsed_features['image/encoded'], channels=3)

    label = None
    if self.split_name != common.TEST_SET:
      label = _decode_image(
          parsed_features['image/segmentation/class/encoded'], channels=1)

    image_name = parsed_features['image/filename']
    if image_name is None:
      image_name = tf.constant('')

    sample = {
        common.IMAGE: image,
        common.IMAGE_NAME: image_name,
        common.HEIGHT: parsed_features['image/height'],
        common.WIDTH: parsed_features['image/width'],
    }

    if label is not None:
      if label.get_shape().ndims == 2:
        label = tf.expand_dims(label, 2)
      elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
        pass
      else:
        raise ValueError('Input label shape must be [height, width], or '
                         '[height, width, 1].')

      label.set_shape([None, None, 1])

      sample[common.LABELS_CLASS] = label

    return sample

5. _preprocess_image

对图片和标记做处理
参数:
	包含图片和标签的sample
返回:
	处理后的sample
def _preprocess_image(self, sample):

    image = sample[common.IMAGE]
    label = sample[common.LABELS_CLASS]

    original_image, image, label = input_preprocess.preprocess_image_and_label(
        image=image,
        label=label,
        crop_height=self.crop_size[0],
        crop_width=self.crop_size[1],
        min_resize_value=self.min_resize_value,
        max_resize_value=self.max_resize_value,
        resize_factor=self.resize_factor,
        min_scale_factor=self.min_scale_factor,
        max_scale_factor=self.max_scale_factor,
        scale_factor_step_size=self.scale_factor_step_size,
        ignore_label=self.ignore_label,
        is_training=self.is_training,
        model_variant=self.model_variant)

    sample[common.IMAGE] = image# 将sample里的image更新为处理过的

    if not self.is_training:
      # Original image is only used during visualization.
      sample[common.ORIGINAL_IMAGE] = original_image

    if label is not None:
      sample[common.LABEL] = label

    # Remove common.LABEL_CLASS key in the sample since it is only used to
    # derive label and not used in training and evaluation.
    sample.pop(common.LABELS_CLASS, None)

    return sample

二、_DATASETS_INFORMATION

数据集的信息,包含有cityscapes等,使用自己的数据集也要补充在这里。

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
}
pascal_voc_seg数据集对应的INFORMATION,其中包含了训练、测试集图片数量,类别数等信息。由DatasetDescriptor类定义。
_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1464,
        'train_aug': 10582,
        'trainval': 2913,
        'val': 1449,
    },
    num_classes=21,
    ignore_label=255,
)

三、DatasetDescriptor

DatasetDescriptor是用来描述数据集性质的元组。使用collections.namedtuple定义,包含数据集的各种信息。
# Named tuple to describe the dataset properties.
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor',
    [
        'splits_to_sizes',  # Splits of the dataset into training, val and test.
        'num_classes',  # Number of semantic classes, including the
                        # background class (if exists). For example, there
                        # are 20 foreground classes + 1 background class in
                        # the PASCAL VOC 2012 dataset. Thus, we set
                        # num_classes=21.
        'ignore_label',  # Ignore label value.
    ])
<think>好的,我需要帮助用户找到关于Deeplabv3+模型在TensorFlow和PyTorch中的代码实现示例教程。首先,我需要回忆一下用户提供的引用内容,看看里面有没有相关的信息。 根据用户给出的引用[1],里面提到了使用TensorFlow和Keras库创建DeepLabV3Plus模型,并给出了代码示例。这段代码导入了必要的库,调用了`DeeplabV3Plus`类,指定了预训练权重、输入形状和类别数。这应该是一个TensorFlow的实现示例。 然后看引用[2],这里提到了一个更高性能的PyTorch实现的Deeplab-v3plus项目,项目地址是https://gitcode.com/gh_mirrors/de/Deeplab-v3plus。这可能是一个PyTorch的代码库,用户可以参考这个项目的README和文档来获取实现细节。 引用[3]提到KerasCV集成了DeepLabv3+,用户可以使用预训练的主干网络进行微调,比如ResNet50_V2和EfficientNetv2_small。这可能是在TensorFlow环境下使用KerasCV库的实现方法。 引用[4]和[5]主要是关于训练自己的数据集和使用指南,可能包含一些脚本和配置说明,但具体的代码实现可能不如前几个引用直接。 接下来,我需要整理这些信息,分TensorFlow和PyTorch两部分给出代码示例,并附上相关引用。对于TensorFlow部分,引用[1]和[3]提供了不同的方法,一个是直接使用Keras的DeeplabV3Plus类,另一个是通过KerasCV库。PyTorch部分则引用[2]的项目地址。 需要注意用户提到的代码实现的正确性,比如在TensorFlow中可能需要安装特定的版本或额外的库,比如`keras_cv`。同时,PyTorch的实现可能需要从Git仓库克隆代码,并按照项目文档进行配置。 另外,用户可能需要知道如何加载预训练权重、调整输入尺寸和类别数,以及如何进行训练和推理。这些步骤在引用中都有所涉及,需要简明扼要地说明。 最后,生成相关问题时要围绕Deeplabv3+的实现、训练、微调、应用场景等方面,确保问题具有针对性,帮助用户深入理解。</think>### Deeplabv3+模型代码实现示例 以下是TensorFlow和PyTorch框架下的Deeplabv3+模型实现方法及教程参考: --- #### **TensorFlow/Keras 实现** 1. **使用Keras内置模块** ```python import tensorflow as tf from tensorflow.keras.applications import DeeplabV3Plus # 创建模型(Pascal VOC预训练权重) model = DeeplabV3Plus( weights=&#39;pascal_voc&#39;, # 可选 &#39;cityscapes&#39; 或其他预训练权重 input_shape=(512, 512, 3), # 输入图像尺寸 classes=21 # 类别数(根据任务调整) ) ``` **说明**: - 直接调用`DeeplabV3Plus`类,需指定输入尺寸和类别数[^1]。 - 支持迁移学习,通过`weights`参数加载预训练模型。 2. **通过KerasCV库实现** ```python import keras_cv model = keras_cv.models.DeepLabV3Plus( backbone="resnet50_v2", # 主干网络选择 num_classes=21, # 类别数 input_shape=(512, 512, 3) # 输入尺寸 ) ``` **说明**: - KerasCV提供更灵活的主干网络(如EfficientNet、ResNet)[^3]。 - 需安装`keras_cv`库:`pip install keras-cv`。 --- #### **PyTorch 实现** 参考开源项目 **[Deeplab-v3plus](https://gitcode.com/gh_mirrors/de/Deeplab-v3plus)**: 1. **克隆仓库并安装依赖** ```bash git clone https://gitcode.com/gh_mirrors/de/Deeplab-v3plus cd Deeplab-v3plus pip install -r requirements.txt ``` 2. **模型定义示例** ```python from model.deeplab import DeepLab model = DeepLab( backbone=&#39;resnet&#39;, # 可选 &#39;xception&#39; 或 &#39;mobilenet&#39; output_stride=16, # 输出步长(控制特征图分辨率) num_classes=21 # 类别数 ) ``` **说明**: - 该项目提供完整的训练、验证脚本及预训练模型加载功能[^2][^5]。 - 支持多GPU训练和自定义数据集配置[^4]。 --- #### **训练自定义数据集** 1. **数据准备** - 需提供图像和对应的语义分割掩码(PNG格式)。 - 参考[引用4]的脚本进行数据预处理和加载。 2. **微调模型** ```python # TensorFlow示例(加载预训练模型后) model.compile(optimizer=&#39;adam&#39;, loss=&#39;sparse_categorical_crossentropy&#39;) model.fit(train_dataset, epochs=50, validation_data=val_dataset) ``` ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值