segdec-net paper reproduction

segdec-net (https://paperswithcode.com/paper/segmentation-based-deep-learning-approach-for): an explanation and reproduction of the paper's model (personal notes).

The code is explained in detail in the comments.

Contents

Environment setup:

Model architecture:

Code walkthrough:

segdec_model.py:

__init__

get_inference

get_decision_net_simple

get_decision_net

get_loss

restore

_activation_summary

_activation_summaries

segdec_train.py

__init__

_tower_loss

_average_gradients

train

_eval_once

evaluate

if __name__=='__main__'

segdec_print_eval.py

calc_confusion_mat

get_performance_eval

evaluate_decision


Environment: python=3.6, tensorflow=1.4.0

Environment setup:

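The original code targets a very old TensorFlow version, so Intel provides no suitable CPU/GPU optimizations for it.

For TensorFlow 2.x and later, or for PyTorch, Intel hardware can be optimized with oneAPI.

I set up the virtual environment through Jupyter. The old TensorFlow release is available on the Alibaba Cloud mirror. I searched the Tsinghua and Douban mirrors and could not find it there. Do not use a very old Python version in the virtual environment, because pip then warns about outdated packages. Running python -m pip ... works around part of the problem, but it is still better to avoid very old versions. Python 3.5/3.6 works.

Under Python 3.x some print calls fail, presumably because they are Python 2-style print statements. A small edit to the code fixes this.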
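Below is a minimal sketch of that kind of edit. It assumes the failures come from Python 2 print statements. The log format and the format_log helper are hypothetical. If the file also needs to run on Python 2, put from __future__ import print_function at the top of the file so that the function form works there too.

```python
def format_log(step, loss_value):
    # hypothetical log line, only used to show the function form of print
    return 'step %d, loss = %.3f' % (step, loss_value)


step, loss_value = 10, 0.123
# Python 2 statement form, a SyntaxError under Python 3:
#     print 'step %d, loss = %.3f' % (step, loss_value)
# function form, valid under Python 3 (and under Python 2 with print_function):
print(format_log(step, loss_value))
```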

Model architecture:

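The paper's decision network comes in three options, 0, 1 and 2, in order of increasing complexity. In the code these are DECISION_NET_NONE, DECISION_NET_LOGISTIC and DECISION_NET_FULL.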

Code walkthrough:

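The segmentation network and the decision network are trained separately. The segmentation network is trained first and the decision network second, controlled by the train_segmentation_net and train_decision_net flags in segdec_train.py. The decision network in turn has ...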

segdec_model.py:

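Defines the class SegDecModel. It encapsulates model construction, inference, loss computation and weight loading from a pretrained model.

Core features: the segmentation network, the decision network, and flexible configuration (decision-network complexity 0/1/2).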

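    BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997
    # decay of the moving averages used in batch normalization

    MOVING_AVERAGE_DECAY = 0.9999
    # decay of the global moving average of the model parameters

    DECISION_NET_NONE = 0  # no decision network
    DECISION_NET_LOGISTIC = 1  # logistic regression as the decision network (get_decision_net_simple)
    DECISION_NET_FULL = 2  # full decision network (get_decision_net)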

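Batch normalization is a widely used technique. It standardizes a layer's inputs with the mean and variance computed over each batch, which speeds up training and makes it more stable.

Batch normalization also keeps moving averages of the batch statistics, which serve as estimates of the global mean and variance.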
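The update rule can be sketched in plain Python. This is an illustration, not the TensorFlow implementation. It assumes the usual form moving = decay * moving + (1 - decay) * batch_value, with decay = 0.9997 (BATCHNORM_MOVING_AVERAGE_DECAY) and epsilon = 0.001 taken from the normalizer_params in get_inference. The helper names are hypothetical.

```python
import math


def update_moving_average(moving, batch_value, decay=0.9997):
    # one update: moving <- decay * moving + (1 - decay) * batch_value
    return decay * moving + (1.0 - decay) * batch_value


def batch_mean_var(values):
    # mean and (biased) variance of one batch of activations
    n = float(len(values))
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return mean, var


def batch_norm(x, mean, var, gamma=1.0, beta=0.0, epsilon=0.001):
    # normalize with the given statistics, then scale (gamma) and shift (beta)
    return gamma * (x - mean) / math.sqrt(var + epsilon) + beta


# moving statistics start at mean 0 and variance 1
moving_mean, moving_var = 0.0, 1.0
for batch in ([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]):
    m, v = batch_mean_var(batch)
    moving_mean = update_moving_average(moving_mean, m)
    moving_var = update_moving_average(moving_var, v)
```

With decay = 0.9997 the estimates change very slowly, roughly averaging over the last 1 / (1 - decay), or about 3333, updates. get_inference leaves is_training at its default (see the commented-out line), so each image is normalized with its own batch statistics. This is why batch_size must be 1.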

__init__

Parameter initialization.

    
    def __init__(self,
                 use_corss_entropy_seg_net=True,
                 positive_weight=1,
                 decision_net=DECISION_NET_NONE,
                 decision_positive_weight=1,
                 load_from_seg_only_net=False):
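        # model settings: whether to use cross-entropy, the weight of positive samples,
        # the decision-network type, and whether to load weights from a pretrained segmentation-only model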
        # weight for positive samples in segmentation net
        self.positive_weight = positive_weight

        # weight for positive samples in decision net
        self.decision_positive_weight = decision_positive_weight

        if decision_net == SegDecModel.DECISION_NET_NONE:
            self.decision_net_fn = lambda net, net_prob_mat: None
        elif decision_net == SegDecModel.DECISION_NET_LOGISTIC:
            self.decision_net_fn = self.get_decision_net_simple
        elif decision_net == SegDecModel.DECISION_NET_FULL:
            self.decision_net_fn = self.get_decision_net

        self.use_corss_entropy_seg_net = use_corss_entropy_seg_net

        # this is only when loading from pre-trained network of segmentation that did not have decision net layers
        # present at the same time
        self.load_from_seg_only_net = load_from_seg_only_net
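Below is a hypothetical alternative to the if/elif chain above, sketched with a dictionary. In the original code an unsupported decision_net value leaves self.decision_net_fn unset, so the error only surfaces later as an AttributeError in get_inference. select_decision_net_fn is an illustrative name and is not part of the repository.

```python
DECISION_NET_NONE = 0
DECISION_NET_LOGISTIC = 1
DECISION_NET_FULL = 2


def select_decision_net_fn(decision_net, simple_fn, full_fn):
    # hypothetical helper: map a DECISION_NET_* constant to a builder function
    table = {
        DECISION_NET_NONE: lambda net, net_prob_mat: None,  # no decision net
        DECISION_NET_LOGISTIC: simple_fn,   # e.g. self.get_decision_net_simple
        DECISION_NET_FULL: full_fn,         # e.g. self.get_decision_net
    }
    if decision_net not in table:
        # fail early instead of leaving decision_net_fn undefined
        raise ValueError('unknown decision_net value: %r' % (decision_net,))
    return table[decision_net]
```

In __init__ this would be called as self.decision_net_fn = select_decision_net_fn(decision_net, self.get_decision_net_simple, self.get_decision_net).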

get_inference

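Builds the inference part of the model. It returns the outputs of the segmentation network and of the decision network.

The segmentation network is built from convolution and pooling layers. decision_net_fn is then called to produce the decision-network output.

Returns the segmentation output, the decision output and the intermediate endpoints, which are useful for visualization.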

    # builds the inference model (convolutions, pooling, ...) and calls the decision network
    # inputs: input images
    # num_classes: number of classes
    # for_training: whether to build the model for training
    # restore_logits: whether to restore the logits
    # scope: optional name prefix
    # note: num_classes, for_training and restore_logits are not used in the body below
    def get_inference(self, inputs, num_classes, for_training=False, restore_logits=True, scope=None):
      """ Build model

      Args:
        inputs: Images returned from inputs() or distorted_inputs().
        num_classes: number of classes
        for_training: If set to `True`, build the inference model for training.
          Kernels that operate differently for inference during training
          e.g. dropout, are appropriately configured.
        restore_logits: whether or not the logits layers should be restored.
          Useful for fine-tuning a model with different num_classes.
        scope: optional prefix string identifying the tower.

      Returns:
        net_prob_mat: segmentation output (logits), 4-D float Tensor with 1 channel.
        decision_net: decision-network output (logits), or None without a decision net.
        endpoints: dict of intermediate layer outputs.
      """


      with variable_scope.variable_scope(scope, 'SegDecNet', [inputs]) as sc:
          end_points_collection = sc.original_name_scope + '_end_points'
          # Collect outputs for conv2d, max_pool2d
          with arg_scope(
                  [layers.conv2d, layers.fully_connected, layers_lib.max_pool2d, layers.batch_norm],
                  outputs_collections=end_points_collection):

              # Apply specific parameters to all conv2d layers (to use batch norm and relu - relu is by default)
              with arg_scope([layers.conv2d, layers.fully_connected],
                             weights_initializer= lambda shape,dtype=tf.float32, partition_info=None: tf.random_normal(shape, mean=0,stddev=0.01, dtype=dtype),
                             biases_initializer=None,
                             normalizer_fn=layers.batch_norm,
                             normalizer_params={'center': True,
                                                'scale': True,
                                                #'is_training': for_training, # we disable this to do feature normalization (but requires batch size=1)
                                                'decay': self.BATCHNORM_MOVING_AVERAGE_DECAY, # Decay for the moving averages.
                                                'epsilon': 0.001, # epsilon to prevent 0s in variance.
                                                }):

                  net = layers_lib.repeat(inputs, 2, layers.conv2d, 32, [5, 5], scope='conv1')

                  net = layers_lib.max_pool2d(net, [2, 2], scope='pool1')

                  net = layers_lib.repeat(net, 3, layers.conv2d, 64, [5, 5], scope='conv2')

                  net = layers_lib.max_pool2d(net, [2, 2], scope='pool2')

                  net = layers_lib.repeat(net, 4, layers.conv2d, 64, [5, 5], scope='conv3')

                  net = layers_lib.max_pool2d(net, [2, 2], scope='pool3')

                  net = layers.conv2d(net, 1024, [15, 15], padding='SAME', scope='conv4')

                  net_prob_mat = layers.conv2d(net, 1, [1, 1], scope='conv5',
                                               activation_fn=None)

                  decision_net = self.decision_net_fn(net, tf.nn.relu(net_prob_mat))

                  # Convert end_points_collection into a end_point dict.
                  endpoints = utils.convert_collection_to_dict(end_points_collection)



      # Add summaries for viewing model statistics on TensorBoard.
      self._activation_summaries(endpoints)

      return net_prob_mat, decision_net, endpoints
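Below is a plain-Python sketch that walks through the shapes of the segmentation network. It assumes the tf.contrib.layers defaults: conv2d with padding='SAME', and max_pool2d with stride 2 and padding='VALID'. 1408x512 is only an example input size. The three 2x2 poolings reduce the resolution by 8, which matches output_resolution_reduction=8 in get_loss.

```python
# (scope, number of stacked conv layers, output channels, kernel size) in get_inference
SEG_CONV_STAGES = [
    ('conv1', 2, 32, 5),    # followed by pool1
    ('conv2', 3, 64, 5),    # followed by pool2
    ('conv3', 4, 64, 5),    # followed by pool3
    ('conv4', 1, 1024, 15),
    ('conv5', 1, 1, 1),     # net_prob_mat, no activation
]


def seg_output_size(height, width, num_pools=3, pool_size=2, stride=2):
    """Spatial size of net_prob_mat for a height x width input.

    'SAME' convolutions keep the size; each 'VALID' max pooling maps
    n -> (n - pool_size) // stride + 1, i.e. n // 2 for 2x2 pooling with stride 2."""
    for _ in range(num_pools):
        height = (height - pool_size) // stride + 1
        width = (width - pool_size) // stride + 1
    return height, width


print(seg_output_size(1408, 512))  # (176, 64)
```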

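get_decision_net_simple

Implements a simple logistic regression as the decision network.

Global average pooling and global max pooling extract features from the segmentation output.

The pooled results are concatenated and passed through a 1x1 convolution layer to produce the final decision output.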

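    # simple decision net (logistic regression): global average and global max pooling
    # of the segmentation output, then a 1x1 convolution (i.e. a linear layer)
    def get_decision_net_simple(self, net, net_prob_mat):

        avg_output = keras.layers.GlobalAveragePooling2D()(net_prob_mat)
        max_output = keras.layers.GlobalMaxPooling2D()(net_prob_mat)

        # the global pooling layers return 2-D tensors [batch, channels] (channels = 1 here),
        # so concatenate along axis 1 and reshape to [batch, 1, 1, 2] for the 1x1 convolution
        decision_net = tf.concat([avg_output, max_output], axis=1)
        decision_net = tf.reshape(decision_net, [-1, 1, 1, 2])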

        decision_net = layers.conv2d(decision_net, 1, [1, 1], scope='decision6',
                                     normalizer_fn=None,
                                     weights_initializer=initializers.xavier_initializer_conv2d(False),
                                     biases_initializer=tf.constant_initializer(0),
                                     activation_fn=None)

        return decision_net
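Below is a plain-Python version of what the simple decision network computes for a single image. A 1x1 convolution over the two pooled values is just a weighted sum plus a bias, which is logistic regression on [average, max] of the ReLU'd segmentation output. The weights used here are hypothetical; in the model they are learned in decision6.

```python
import math


def global_avg_max(prob_map):
    # global average and global max pooling of a single-channel map (list of rows)
    values = [v for row in prob_map for v in row]
    return sum(values) / float(len(values)), max(values)


def simple_decision(prob_map, w_avg, w_max, bias=0.0):
    """Logistic regression on [avg, max]; returns (logit, probability).
    In the model only the logit is output; the sigmoid is applied in the loss."""
    avg, mx = global_avg_max(prob_map)
    logit = w_avg * avg + w_max * mx + bias
    return logit, 1.0 / (1.0 + math.exp(-logit))


# hypothetical weights on a ReLU'd segmentation output
print(simple_decision([[0.0, 0.5], [0.0, 2.0]], w_avg=1.0, w_max=1.5, bias=-2.0))
```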

get_decision_net

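The full decision network.

It contains several convolution and pooling layers that progressively extract higher-level features.

Finally, global average pooling and global max pooling extract the final features, and a fully connected layer produces the decision output.

    # full decision net: convolutions and pooling, then global pooling and a fully connected layer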
    def get_decision_net(self, net, net_prob_mat):

        with tf.name_scope('decision'):

            decision_net = tf.concat([net, net_prob_mat],axis=3)

            decision_net = layers_lib.max_pool2d(decision_net, [2, 2], scope='decision/pool4')

            decision_net = layers.conv2d(decision_net, 8, [5, 5], padding='SAME', scope='decision/conv6')

            decision_net = layers_lib.max_pool2d(decision_net, [2, 2], scope='decision/pool5')

            decision_net = layers.conv2d(decision_net, 16, [5, 5], padding='SAME', scope='decision/conv7')

            decision_net = layers_lib.max_pool2d(decision_net, [2, 2], scope='decision/pool6')

            decision_net = layers.conv2d(decision_net, 32, [5, 5], scope='decision/conv8')

            with tf.name_scope('decision/global_avg_pool'):
                avg_decision_net = keras.layers.GlobalAveragePooling2D()(decision_net)

            with tf.name_scope('decision/global_max_pool'):
                max_decision_net = keras.layers.GlobalMaxPooling2D()(decision_net)

            with tf.name_scope('decision/global_avg_pool'):
                avg_prob_net = keras.layers.GlobalAveragePooling2D()(net_prob_mat)

            with tf.name_scope('decision/global_max_pool'):
                max_prob_net = keras.layers.GlobalMaxPooling2D()(net_prob_mat)

            # adding avg_prob_net and max_prob_net may not be needed, but it doesn't hurt
            decision_net = tf.concat([avg_decision_net, max_decision_net, avg_prob_net, max_prob_net], axis=1)

            decision_net = layers.fully_connected(decision_net, 1, scope='decision/FC9',
                                                  normalizer_fn=None,
                                                  biases_initializer=tf.constant_initializer(0),
                                                  activation_fn=None)
        return decision_net
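Below is a plain-Python sketch that tracks the tensor shapes through get_decision_net. It assumes the tf.contrib.layers defaults: 2x2 max pooling with stride 2 and 'VALID' padding, and 'SAME' convolutions (conv8 does not set padding, so it also uses 'SAME'). The example uses a 176x64 segmentation output. The fully connected layer FC9 receives 32 + 32 + 1 + 1 = 66 features.

```python
def decision_net_shapes(height, width, seg_channels=1024):
    """(name, height, width, channels) after each stage of get_decision_net.

    height, width: spatial size of net / net_prob_mat."""
    # concat([net, net_prob_mat], axis=3): 1024 feature channels + 1 probability channel
    shapes = [('concat', height, width, seg_channels + 1)]
    for pool, conv, channels in (('pool4', 'conv6', 8),
                                 ('pool5', 'conv7', 16),
                                 ('pool6', 'conv8', 32)):
        # 2x2 'VALID' max pooling with stride 2 halves the size; the 'SAME' conv keeps it
        height, width = height // 2, width // 2
        shapes.append((pool + '+' + conv, height, width, channels))
    # global avg + global max of conv8 (32 + 32) and of net_prob_mat (1 + 1)
    fc_inputs = 2 * shapes[-1][3] + 2 * 1
    return shapes, fc_inputs


shapes, fc_inputs = decision_net_shapes(176, 64)
for name, h, w, c in shapes:
    print(name, h, w, c)
print('FC9 inputs:', fc_inputs)
```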

get_loss

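Computes the losses of the segmentation network and of the decision network.

The use_corss_entropy_seg_net parameter (spelled this way in the code) selects the segmentation loss: MSE or sigmoid cross-entropy.

The decision network uses a sigmoid cross-entropy loss.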
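The pixel weighting that get_loss uses when positive_weight > 1 can be sketched in plain Python for one image. The sketch mirrors the graph code with two simplifications: the mask is a flat list, and pixel counts are taken per image. The TF code counts over the whole batch, which gives the same result with batch_size=1. pixel_weights is a hypothetical helper name.

```python
def pixel_weights(mask, positive_weight):
    """Per-pixel loss weights for one image (flat list of mask values),
    following get_loss when positive_weight > 1."""
    pos = [1.0 if m > 0.0 else 0.0 for m in mask]   # tf.less(0.0, masks)
    neg = [1.0 - p for p in pos]                     # tf.greater_equal(0.0, masks)
    num_pos, num_neg = sum(pos), sum(neg)
    # positive pixels get (num_neg / num_pos) * positive_weight, negatives keep 1
    if num_pos > 0:
        pw = num_neg / num_pos * positive_weight
    else:
        pw = float(positive_weight)
    weights = [n + p * pw for p, n in zip(pos, neg)]
    # normalize so that the weights of the image sum to the number of pixels;
    # a mask without negative pixels gives all-zero weights (ZeroDivisionError
    # here, NaN weights in the TF graph)
    factor = len(mask) / sum(weights)
    return [w * factor for w in weights]


print(pixel_weights([1.0, 0.0, 0.0, 0.0], positive_weight=2))
```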

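    # compute the losses
    def get_loss(self, net_model, masks, batch_size=None, return_segmentation_net=True, return_decision_net=True, output_resolution_reduction=8):
      """Adds all losses for the model.

      tf.losses.* adds each loss to the tf.GraphKeys.LOSSES collection; _tower_loss()
      collects them from there and sums them into the total loss. The individual
      losses are also returned as [segmentation_loss, decision_loss]
      (None for a part that is not requested).

      Args:
        net_model: tuple (net, decision_net, endpoints) returned by get_inference().
        masks: ground-truth masks, 4-D float Tensor [batch, height, width, 1].
        batch_size: integer
        return_segmentation_net: whether to add the segmentation loss.
        return_decision_net: whether to add the decision loss.
        output_resolution_reduction: downsampling factor of the segmentation
          output relative to the input (8 = three 2x2 poolings).
      """

      # check the batch size argument
      if not batch_size:
        raise Exception("Missing batch_size")

      # unpack net_model into the segmentation output, the decision output and the endpoints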
      net, decision_net, endpoints = net_model

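      # downsample the masks to the resolution of the segmentation output (1/output_resolution_reduction
      # of the input); average pooling with a (2r+1)x(2r+1) kernel and stride r also blurs them slightly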
      if output_resolution_reduction > 1:
        mask_blur_kernel = [output_resolution_reduction*2+1, output_resolution_reduction*2+1]
        masks = layers_lib.avg_pool2d(masks, mask_blur_kernel, stride=output_resolution_reduction, padding='SAME', scope='pool_mask',outputs_collections='tower_0/_end_points')

      # when MSE is used instead of cross-entropy, binarize the masks: True where a pixel is greater than 0.5
      if self.use_corss_entropy_seg_net is False:
          masks = tf.greater(masks, tf.constant(0.5))

      # the raw segmentation output net is used as the prediction (also logged as an image summary)
      predictions = net

      tf.summary.image('prediction', predictions)

      # segmentation loss (l1) and decision loss (l2); None if not computed
      l1 = None
      l2 = None

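      # segmentation loss
      # if positive_weight > 1, positive pixels are weighted by (num_neg / num_pos) * positive_weight
      # to counter the class imbalance (negative pixels keep weight 1)
      # the weights of each image are then normalized so that they sum to the number of pixels
      # the loss is MSE or sigmoid cross-entropy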
      if return_segmentation_net:
        if self.positive_weight > 1:
            pos_pixels = tf.less(tf.constant(0.0), masks)
            neg_pixels = tf.greater_equal(tf.constant(0.0), masks)

            num_pos_pixels = tf.cast(tf.count_nonzero(pos_pixels), dtype=tf.float32)
            num_neg_pixels = tf.cast(tf.count_nonzero(neg_pixels), dtype=tf.float32)

            pos_pixels = tf.cast(pos_pixels, dtype=tf.float32)
            neg_pixels = tf.cast(neg_pixels, dtype=tf.float32)

            positive_weight = tf.cond(num_pos_pixels > tf.constant(0,dtype=tf.float32),
                                      lambda: tf.multiply(tf.div(num_neg_pixels, num_pos_pixels),
                                                          tf.constant(self.positive_weight,dtype=tf.float32)),
                                      lambda: tf.constant(self.positive_weight, dtype=tf.float32))

            positive_weight = tf.reshape(positive_weight, [1])

            # weight positive samples more !!
            weights = tf.add(neg_pixels, tf.multiply(pos_pixels, positive_weight))

            # noramlize weights so that the sum of weights is always equal to the num of elements
            N = tf.constant(weights.shape[1]._value * weights.shape[2]._value, dtype=tf.float32)

            factor = tf.reduce_sum(weights,axis=[1,2])
            factor = tf.divide(N, factor)

            weights = tf.multiply(weights, tf.reshape(factor,[-1,1,1,1]))

            if self.use_corss_entropy_seg_net is False:
                l1 = tf.losses.mean_squared_error(masks, predictions, weights=weights)
            else:
                l1 = tf.losses.sigmoid_cross_entropy(logits=predictions, multi_class_labels=masks, weights=weights) # NOTE: weights were added but not tested yet !!
        else:
            if self.use_corss_entropy_seg_net is False:
                l1 = tf.losses.mean_squared_error(masks, predictions)
            else:
                l1 = tf.losses.sigmoid_cross_entropy(logits=predictions,multi_class_labels=masks)


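      # decision loss
      # masks are cast to float; the image-level label is min(sum of mask values, 1),
      # i.e. 0 for a defect-free image and (typically) 1 for an image with a defect
      # decision_net is squeezed according to its rank (2-D or 4-D) and reshaped to [batch, 1]
      # sigmoid cross-entropy; the scalar decision_positive_weight scales the loss of
      # every image, not only of the positive ones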
      if return_decision_net:
          with tf.name_scope('decision'):
            masks = tf.cast(masks, tf.float32)
            label = tf.minimum(tf.reduce_sum(masks, [1, 2, 3]), tf.constant(1.0))

            if len(decision_net.shape) == 2:
                decision_net = tf.squeeze(decision_net, [1])
            elif len(decision_net.shape) == 4:
                decision_net = tf.squeeze(decision_net, [1, 2, 3])
            else:
                raise Exception("Only 2 or 4 dimensional output expected for decision_net")

            decision_net = tf.reshape(decision_net,[-1,1])
            label = tf.reshape(label, [-1, 1])

            l2 = tf.losses.sigmoid_cross_entropy(logits=decision_net,multi_class_labels=label, weights=self.decision_positive_weight)

      return [l1,l2]
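Below is a plain-Python sketch of the image-level label and the decision loss for a list of images. sigmoid_cross_entropy uses the numerically stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits. decision_loss assumes the default reduction of tf.losses, which is the mean over samples with non-zero weight. Under that reduction, a scalar decision_positive_weight simply multiplies the mean loss. The helper names are hypothetical.

```python
import math


def decision_label(mask):
    # image-level label as in get_loss: min(sum of mask values, 1)
    return min(sum(mask), 1.0)


def sigmoid_cross_entropy(logit, label):
    # stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def decision_loss(logits, labels, decision_positive_weight=1.0):
    # a scalar weight multiplies the loss of every image, positive or negative,
    # so it rescales the loss rather than re-balancing the two classes
    losses = [sigmoid_cross_entropy(x, z) for x, z in zip(logits, labels)]
    return decision_positive_weight * sum(losses) / len(losses)


labels = [decision_label([0.0, 0.3, 0.9]), decision_label([0.0, 0.0])]
print(labels, decision_loss([2.0, -1.0], labels))
```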

restore

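Weight loading. This corresponds to the load_from_seg_only_net option in __init__, which covers loading from a pretrained segmentation-only network. Note that restore() takes its own load_from_seg_only_net argument (default False) instead of reading self.load_from_seg_only_net.

If load_from_seg_only_net is True, the decision-network variables are filtered out before restoring. If restoring fails, the except branch retries without the decision-network variables.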

    def restore(self, session, model_checkpoint_path, variables_to_restore = None, load_from_seg_only_net=False):

        if variables_to_restore is None:
            variables_to_restore = tf.trainable_variables()# + tf.moving_average_variables() # tf.moving_average_variables is required only in TF r1.1

        # this is only when loading from pre-trained network of segmentation that did not have decision net layers
        # present at the same time
        if load_from_seg_only_net:
            variables_to_restore = [v for v in variables_to_restore if v.name.count('decision') == 0]

        saver = tf.train.Saver(variables_to_restore)
        try:
            saver.restore(session, model_checkpoint_path)

        except Exception:
            # remove decision variables if cannot load them
            if type(variables_to_restore) is dict:
                variables_to_restore = [variables_to_restore[v] for v in variables_to_restore.keys() if v.find('decision') < 0]
            else:
                variables_to_restore = [v for v in variables_to_restore if v.name.find('decision') < 0]

            saver = tf.train.Saver(variables_to_restore)

            saver.restore(session, model_checkpoint_path)
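The name filtering in restore() can be imitated with plain strings to see which variables end up restored. This is only a simulation. A real tf.train.Saver raises an error (for example NotFoundError) when a variable is missing from the checkpoint. Here a membership check and a KeyError stand in for that error. restore_names and the variable names are hypothetical.

```python
def restore_names(requested, checkpoint_names, load_from_seg_only_net=False):
    """Simulate which variable names SegDecModel.restore() ends up restoring.

    requested: names of the variables to restore;
    checkpoint_names: names stored in the checkpoint."""
    if load_from_seg_only_net:
        # drop decision-net variables up front
        requested = [n for n in requested if 'decision' not in n]
    if any(n not in checkpoint_names for n in requested):
        # the first restore would fail: retry without any decision-net variable,
        # as the except branch of restore() does
        requested = [n for n in requested if 'decision' not in n]
        missing = [n for n in requested if n not in checkpoint_names]
        if missing:
            # the second saver.restore() is not guarded, so the error propagates
            raise KeyError('not in checkpoint: %s' % ', '.join(missing))
    return requested


req = ['SegDecNet/conv1/conv1_1/weights:0', 'SegDecNet/decision/conv6/weights:0']
print(restore_names(req, ['SegDecNet/conv1/conv1_1/weights:0']))
```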

_activation_summary

Generates activation statistics (histogram and sparsity) for TensorBoard.

   
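    # creates two kinds of summaries for tensor x:
    # histogram summary: distribution of the activation values
    # sparsity summary: fraction of zeros among the activations
    def _activation_summary(self, x):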
      """Helper to create summaries for activations.
    
      Creates a summary that provides a histogram of activations.
      Creates a summary that measure the sparsity of activations.
    
      Args:
        x: Tensor
      """
      # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
      # session. This helps the clarity of presentation on tensorboard.

      # strip the 'tower_N/' prefix that multi-GPU training adds to op names
      tensor_name = re.sub('%s_[0-9]*/' % self.TOWER_NAME, '', x.op.name)

      # histogram summary
      tf.summary.histogram(tensor_name + '/activations', x)

      # sparsity summary (fraction of zeros)
      tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))

_activation_summaries

Iterates over all intermediate endpoints and calls _activation_summary on each of them.

    def _activation_summaries(self, endpoints):
      with tf.name_scope('summaries'):
        for act in endpoints.values():
          self._activation_summary(act)
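Below is a plain-Python sketch of the two parts of _activation_summary that are not plain TensorFlow calls: the tag cleanup and the sparsity value. TOWER_NAME = 'tower' follows SegDecTrain. The helper names are hypothetical.

```python
import re

TOWER_NAME = 'tower'  # as in SegDecTrain


def clean_tensor_name(op_name, tower_name=TOWER_NAME):
    # drop a 'tower_<N>/' prefix so every tower reports under the same tag
    return re.sub('%s_[0-9]*/' % tower_name, '', op_name)


def zero_fraction(values):
    # fraction of entries that are exactly zero (what tf.nn.zero_fraction reports)
    values = list(values)
    return sum(1 for v in values if v == 0) / float(len(values))


print(clean_tensor_name('tower_0/SegDecNet/conv1/conv1_1/Relu'))
print(zero_fraction([0.0, 1.5, 0.0, 2.0]))  # ReLU outputs: half are zero
```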

segdec_train.py

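Trains and evaluates the segmentation-decision network. It defines the class SegDecTrain.

The class encapsulates model training, evaluation, weight loading and visualization.

Core features:

Training: multi-GPU parallel training (I did not use it), learning-rate scheduling, gradient clipping and weight initialization

Evaluation: evaluates the model on the test set and produces precision-recall (PR) and ROC curves, among other outputs

Weight management: loads weights from a pretrained model and can optionally freeze the parameters of some layers

Visualization: TensorBoard views of activation distributions, weights, loss curves and more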

__init__

Parameter initialization and data preprocessing (NetInputProcessing handles random rotation and positive/negative sample balancing).

    def __init__(self, model, storage_dir, run_string, image_size, batch_size,
                 learning_rate = 0.01,
                 max_epochs = 1000,
                 max_steps = 10000000,
                 num_gpus = 1,
                 visible_device_list = None,
                 num_preprocess_threads = 1,
                 pretrained_model_checkpoint_path = None,
                 train_segmentation_net = True,
                 train_decision_net = False,
                 use_random_rotation=False,
                 ensure_posneg_balance=True):

        self.model = model

        run_train_string = run_string[0] if type(run_string) is tuple else run_string
        run_eval_string = run_string[1] if type(run_string) is tuple else run_string

        self.visible_device_list = visible_device_list
        self.batch_size = batch_size
        self.train_dir = os.path.join(storage_dir, 'segdec_train', run_train_string) # Directory where to write event logs and checkpoint.
        self.eval_dir = os.path.join(storage_dir, 'segdec_eval', run_eval_string)

        # Takes number of learning batch iterations based on min(self.max_steps, self.max_epoch * num_batches_per_epoch)
        self.max_steps = max_steps  # Number of batches to run.
        self.max_epochs = max_epochs  # Number of epochs to run

        # Flags governing the hardware employed for running TensorFlow.
        self.num_gpus = num_gpus  # How many GPUs to use.
        self.log_device_placement = False  # Whether to log device placement

        self.num_preprocess_threads = num_preprocess_threads
        # Flags governing the type of training.
        self.fine_tune = False  # If set, randomly initialize the final layer of weights in order to train the network on a new task.
        self.pretrained_model_checkpoint_path = pretrained_model_checkpoint_path  # If specified, restore this pretrained model before beginning any training.

        self.initial_learning_rate = learning_rate  # Initial learning rate.
        self.decay_steps = 0 # no decay by default
        self.learning_rate_decay_factor = 1

        self.TOWER_NAME = "tower"

        # Batch normalization. Constant governing the exponential moving average of
        # the 'global' mean and variance for all activations.
        self.BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997

        # The decay to use for the moving average.
        self.MOVING_AVERAGE_DECAY = 0.9999

        # Override the number of preprocessing threads to account for the increased
        # number of GPU towers.
        input_num_preprocess_threads = self.num_preprocess_threads * self.num_gpus

        self.input = NetInputProcessing(batch_size=self.batch_size,
                                        num_preprocess_threads=input_num_preprocess_threads,
                                        input_size=image_size,
                                        mask_size=(image_size[0],image_size[1],1),
                                        use_random_rotation=use_random_rotation,
                                        ensure_posneg_balance=ensure_posneg_balance)

        self.train_segmentation_net = train_segmentation_net
        self.train_decision_net = train_decision_net

        assert self.batch_size == 1, "Only batch_size=1 is allowed due to the way the batch_norm is used to normalize features in testing !!!"

        self.loss_print_step = 11
        self.summary_step = 110
        self.checkpoint_step = 10007

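model: the segmentation-decision model instance (e.g. SegDecModel).
storage_dir: directory for the training logs and evaluation results.
run_string: run identifier that separates experiments; a tuple (train, eval) gives training and evaluation different names.
image_size: size of the input images (height, width, channels).
batch_size: batch size; must be 1 (enforced by an assert in __init__).
learning_rate: initial learning rate, default 0.01.
max_epochs: maximum number of training epochs, default 1000.
max_steps: maximum number of training steps, default 10000000.
num_gpus: number of GPUs to use, default 1.
visible_device_list: list of visible GPU devices.
num_preprocess_threads: number of data-preprocessing threads, default 1.
pretrained_model_checkpoint_path: checkpoint path of the pretrained model.
train_segmentation_net: whether to train the segmentation network, default True.
train_decision_net: whether to train the decision network, default False.
use_random_rotation: whether to augment the inputs with random rotation, default False.
ensure_posneg_balance: whether to balance positive and negative samples, default True.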
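Below is a plain-Python sketch of how run_string and storage_dir become the two output directories, following the corresponding lines in __init__. resolve_dirs and the directory names in the example are hypothetical.

```python
import os


def resolve_dirs(storage_dir, run_string):
    """Training and evaluation directories as derived in SegDecTrain.__init__.
    A tuple run_string gives separate names for the training and evaluation runs."""
    run_train = run_string[0] if isinstance(run_string, tuple) else run_string
    run_eval = run_string[1] if isinstance(run_string, tuple) else run_string
    train_dir = os.path.join(storage_dir, 'segdec_train', run_train)
    eval_dir = os.path.join(storage_dir, 'segdec_eval', run_eval)
    return train_dir, eval_dir


print(resolve_dirs('out', ('seg_run', 'seg_run_eval')))
```

With a plain string, training and evaluation share the same run name but still write to different parent directories (segdec_train and segdec_eval).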

_tower_loss

Computes the total loss on a single GPU tower.

Calls get_inference to build the graph and get_loss to compute the losses.

    def _tower_loss(self, images, masks, num_classes, scope, reuse_variables=None):

      # When fine-tuning a model, we do not restore the logits but instead we
      # randomly initialize the logits. The number of classes in the output of the
      # logit is the number of classes in specified Dataset.
      restore_logits = not self.fine_tune

      # Build inference Graph.

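      # build the inference graph; restore_logits decides whether the logits (classification) layer is restored
      # reuse_variables controls variable reuse so that all towers share the same weights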
      with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        net_model = self.model.get_inference(images, num_classes, for_training=True,
                                     restore_logits=restore_logits,
                                     scope=scope)

      # split_batch_size: the part of the batch processed by this tower (the batch is split across towers)
      # compute the losses
      # Build the portion of the Graph calculating the losses. Note that we will
      # assemble the total_loss using a custom function below.
      split_batch_size = images.get_shape().as_list()[0]
      self.model.get_loss(net_model, masks,
                          batch_size=split_batch_size,
                          return_segmentation_net=self.train_segmentation_net,
                          return_decision_net=self.train_decision_net)

      # collect all losses created in the current tower's name scope
      # Assemble all of the losses for the current tower only.
      losses = tf.get_collection(tf.GraphKeys.LOSSES, scope)

      # add the regularization losses and sum everything into the total loss
      # Calculate the total loss for the current tower.
      regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
      total_loss = tf.add_n(losses + regularization_losses, name='total_loss')

      # smooth the losses with an exponential moving average, applied to every loss and to the total loss
      # Compute the moving average of all individual losses and the total loss.
      loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
      loss_averages_op = loss_averages.apply(losses + [total_loss])

      # add summaries
      # Attach a scalar summary to all individual losses and the total loss; do the
      # same for the averaged version of the losses.
      for l in losses + [total_loss]:
        # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
        # session. This helps the clarity of presentation on TensorBoard.
        loss_name = re.sub('%s_[0-9]*/' % self.TOWER_NAME, '', l.op.name)
        # Name each loss as '(raw)' and name the moving average version of the loss
        # as the original loss name.
        tf.summary.scalar(loss_name +' (raw)', l)
        tf.summary.scalar(loss_name, loss_averages.average(l))

      # make sure the EMA update has run before total_loss is returned
      with tf.control_dependencies([loss_averages_op]):
        total_loss = tf.identity(total_loss)
      return total_loss

_average_gradients

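Averages the gradient of each shared variable over all towers, that is, over the gradients of the same variable computed on the different towers.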

    def _average_gradients(self, tower_grads):
      """Calculate the average gradient for each shared variable across all towers.
    
      Note that this function provides a synchronization point across all towers.
    
      Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer list
          is over towers. The inner list is over the (gradient, variable) pairs
          computed on each tower.
      Returns:
         List of pairs of (gradient, variable) where the gradient has been averaged
         across all towers.
      """
      average_grads = []
      for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
          # Add 0 dimension to the gradients to represent the tower.
          expanded_g = tf.expand_dims(g, 0)

          # Append on a 'tower' dimension which we will average over below.
          grads.append(expanded_g)

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
      return average_grads
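Below is the same averaging logic without TensorFlow, as a plain-Python sketch in which each gradient is a list of floats and each variable is just a name. average_gradients is a hypothetical helper. zip(*tower_grads) regroups the per-tower lists so that each iteration sees one variable across all towers.

```python
def average_gradients(tower_grads):
    """tower_grads: one list per tower of (gradient, variable) pairs, with the
    variables in the same order on every tower; gradients are lists of floats."""
    averaged = []
    for grad_and_vars in zip(*tower_grads):  # one entry per variable
        grads = [g for g, _ in grad_and_vars]
        # element-wise mean over the tower dimension
        mean = [sum(vals) / len(grads) for vals in zip(*grads)]
        # the variable is shared, so keep the first tower's reference
        averaged.append((mean, grad_and_vars[0][1]))
    return averaged


tower0 = [([1.0, 2.0], 'w'), ([0.0], 'b')]
tower1 = [([3.0, 4.0], 'w'), ([2.0], 'b')]
print(average_gradients([tower0, tower1]))
```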

train

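Defines the learning-rate schedule, the optimizer and related settings.

Builds the graph: distributes the input data over the towers and computes the losses and gradients.

Training loop: runs the training op batch by batch to update the model parameters.

Weight loading: if a pretrained model path is given, its weights are loaded.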
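Below is a plain-Python sketch of the learning-rate schedule and the number of training iterations. learning_rate follows the formula documented for tf.train.exponential_decay with staircase=True. With the defaults in __init__ (decay_steps = 0, learning_rate_decay_factor = 1) the rate stays constant. The helper names and example numbers are hypothetical.

```python
import math


def learning_rate(step, initial_lr=0.01, decay_steps=0, decay_factor=1.0):
    """Learning rate at a given global step, following train(): constant when
    decay_steps == 0, otherwise initial_lr * decay_factor ** floor(step / decay_steps)."""
    if decay_steps <= 0:
        return initial_lr
    return initial_lr * decay_factor ** math.floor(step / decay_steps)


def num_training_steps(max_epochs, examples_per_epoch, batch_size, max_steps):
    # number of iterations of the training loop
    return min(int(max_epochs * examples_per_epoch / batch_size), max_steps)


print(learning_rate(250, initial_lr=0.01, decay_steps=100, decay_factor=0.5))
print(num_training_steps(1000, 300, 1, 10000000))
```

For example, with max_epochs = 1000, 300 examples per epoch and batch_size = 1, the loop runs min(300000, 10000000) = 300000 steps.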

    def train(self, dataset):
      """Train on input_data for a number of steps."""

      # create a new graph; ops are placed on the CPU by default
      with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create a variable to count the number of train() calls. This equals the
        # number of batches processed * FLAGS.num_gpus.
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)

        # Calculate the learning rate schedule.

        # learning-rate schedule: exponential decay, or constant when decay_steps == 0
        # Decay the learning rate exponentially based on the number of steps.
        if self.decay_steps > 0:
            lr = tf.train.exponential_decay(self.initial_learning_rate,
                                        global_step,
                                        self.decay_steps,
                                        self.learning_rate_decay_factor,
                                        staircase=True)
        else:
            lr = self.initial_learning_rate

        # optimizer: plain gradient descent
        # Create an optimizer that performs gradient descent.
        opt = tf.train.GradientDescentOptimizer(lr)

        # get the input data (images and masks)
        # the batch size must be a multiple of the number of GPUs
        # Get images and labels for ImageNet and split the batch across GPUs.
        assert self.batch_size % self.num_gpus == 0, (
            'Batch size must be divisible by number of GPUs')

        images, masks, _ = self.input.add_inputs_nodes(dataset, True)


        input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # number of classes + 1; class 0 is usually reserved for the background
        # Number of classes in the Dataset label set plus 1.
        # Label 0 is reserved for an (unused) background class.
        num_classes = dataset.num_classes() + 1

        # distribute the data over the towers
        # Split the batch of images and labels for towers.
        images_splits = tf.split(axis=0, num_or_size_splits=self.num_gpus, value=images)
        masks_splits = tf.split(axis=0, num_or_size_splits=self.num_gpus, value=masks)

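        # compute the gradients of each tower; for every GPU:
        # device: place the ops on that GPU
        # name scope: give each tower a unique name scope (e.g. 'tower_0')
        # variable scope: force all variables to be created on the CPU, to avoid GPU memory problems
        # loss: call _tower_loss to compute the loss of the current tower
        # variable reuse: set reuse_variables to True so that later towers reuse the variables
        # summaries: collect the summaries of the current tower
        # batch-norm updates: collect the batch-norm update ops
        # gradients: compute the gradients with the optimizer
        # store the gradients in the tower_grads list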
        # Calculate the gradients for each model tower.
        tower_grads = []
        reuse_variables = None
        for i in range(self.num_gpus):
          with tf.device('/gpu:%d' % i):
            with tf.name_scope('%s_%d' % (self.TOWER_NAME, i)) as scope:
              # Force all Variables to reside on the CPU.
              with slim.arg_scope([slim.variable], device='/cpu:0'):
                # Calculate the loss for one tower of the ImageNet model. This
                # function constructs the entire ImageNet model but shares the
                # variables across all towers.
                loss = self._tower_loss(images_splits[i], masks_splits[i], num_classes,
                                   scope, reuse_variables)

              # Reuse variables for the next tower.
              reuse_variables = True

              # Retain the summaries from the final tower.
              summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

              # Retain the Batch Normalization updates operations only from the
              # final tower. Ideally, we should grab the updates from all towers
              # but these stats accumulate extremely fast so we can ignore the
              # other stats from the other towers without significant detriment.
              batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)

              # Calculate the gradients for the batch of data on this ImageNet
              # tower.
              grads = opt.compute_gradients(loss)

              # Keep track of the gradients across all towers.
              tower_grads.append(grads)

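        # optionally drop the decision-net (or segmentation-net) gradients, so that the
        # segmentation network and the decision network can be trained one after the other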
        variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())

        # if decision_net is not trained then remove all gradients for decision
        if self.train_decision_net is False:
            tower_grads = [[g for g in tg if g[1].name.find('decision') < 0] for tg in tower_grads]

            variables_to_average = [v for v in variables_to_average if v.name.find('decision') < 0]

        # if segmentation_net is not trained then remove all gradients for segmentation net
        # i.e. we assume all variables NOT flaged as decision net are segmentation net
        if self.train_segmentation_net is False:
            tower_grads = [[g for g in tg if g[1].name.find('decision') >= 0] for tg in tower_grads]

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = self._average_gradients(tower_grads)

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Track the moving averages of all trainable variables.
        # Note that we maintain a "double-average" of the BatchNormalization
        # global statistics. This is more complicated then need be but we employ
        # this for backward-compatibility with our previous models.
        variable_averages = tf.train.ExponentialMovingAverage(self.MOVING_AVERAGE_DECAY, global_step)

        # Another possibility is to use tf.slim.get_variables().
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Group all updates to into a single train op.
        batchnorm_updates_op = tf.group(*batchnorm_updates)
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            batchnorm_updates_op)

        # Add summaries and visualization
        
        
        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
          summaries.append(tf.summary.histogram(var.op.name, var))

        # Add weight visualization
        weight_variables = [v for v in tf.global_variables() if v.name.find('/weights') >= 0]

        for c in ['conv1_1','conv1_2',
                  'conv2_1', 'conv2_2', 'conv2_3',
                  'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4']:
            with tf.name_scope(c):
                w = [v for v in weight_variables if v.name.find('/' + c + '/') >= 0]
                w = w[0]

                x_min = tf.reduce_min(w)
                x_max = tf.reduce_max(w)
                ww = (w - x_min) / (x_max - x_min)

                ww_t = tf.transpose(ww, [3, 0, 1, 2])
                ww_t = tf.reshape(ww_t[:,:,:,0], [int(ww_t.shape[0]), int(ww_t.shape[1]), int(ww_t.shape[2]), 1])
                tf.summary.image(c, ww_t, max_outputs=10)

                summaries.extend(tf.get_collection(tf.GraphKeys.SUMMARIES, c))

        # Add a summaries for the input processing and global_step.
        summaries.extend(input_summaries)

        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', lr))

        # Add histograms for gradients.
        for grad, var in grads:
          if grad is not None:
            summaries.append(
                tf.summary.histogram(var.op.name + '/gradients', grad))

        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)


        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        c = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=self.log_device_placement)
        if self.visible_device_list is not None:
            c.gpu_options.visible_device_list = self.visible_device_list
        c.gpu_options.allow_growth = True

        sess = tf.Session(config=c)
        sess.run(init)

        # restore weights from previous model
        if self.pretrained_model_checkpoint_path is not None:
            ckpt = tf.train.get_checkpoint_state(self.pretrained_model_checkpoint_path)
            if ckpt is None:
                raise Exception('No valid saved model found in ' + self.pretrained_model_checkpoint_path)

            self.model.restore(sess, ckpt.model_checkpoint_path)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(
            self.train_dir,
            graph=sess.graph)

        num_steps = min(int(self.max_epochs * dataset.num_examples_per_epoch() /  self.batch_size),
                        self.max_steps)

        prev_duration = None

        for step in range(num_steps):

          run_nodes = [train_op, loss]

          if step % self.summary_step == 0:
              run_nodes = [train_op, loss, summary_op]

          start_time = time.time()
          output_vals = sess.run(run_nodes)
          duration = time.time() - start_time

          if prev_duration is None:
              prev_duration = duration

          loss_value = output_vals[1]

          assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

          if step % self.loss_print_step == 0:
            examples_per_sec = self.batch_size / float(prev_duration)
            format_str = ('%s: step %d, loss = %.5f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), step, loss_value,
                                examples_per_sec, prev_duration))

          if step % self.summary_step == 0:
            summary_str = output_vals[2]
            summary_writer.add_summary(summary_str, step)

          # Save the model checkpoint periodically.
          if step % self.checkpoint_step == 0 or (step + 1) == num_steps:
            checkpoint_path = os.path.join(self.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

          prev_duration = duration
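The weight-visualization block above (min-max normalization, `tf.transpose(ww, [3, 0, 1, 2])`, then keeping input channel 0) can be restated in NumPy to make the tensor layouts explicit. This is a minimal sketch, assuming TF's conv kernel layout `[kh, kw, in_ch, out_ch]`; `filters_to_images` is a hypothetical helper name, and the small epsilon guard is an addition (the original divides by `x_max - x_min` directly).

```python
import numpy as np

def filters_to_images(w):
    """Turn a conv kernel [kh, kw, in_ch, out_ch] into an image batch [out_ch, kh, kw, 1]."""
    x_min, x_max = w.min(), w.max()
    # Min-max normalize the whole kernel to [0, 1] (epsilon guard is an assumption, not in the original).
    ww = (w - x_min) / max(x_max - x_min, 1e-12)
    # [kh, kw, in, out] -> [out, kh, kw, in]: one image per output filter.
    ww_t = np.transpose(ww, (3, 0, 1, 2))
    # Keep only input channel 0 and keep a trailing channel axis, as tf.summary.image expects.
    return ww_t[:, :, :, 0:1]

w = np.random.randn(3, 3, 1, 32).astype(np.float32)
imgs = filters_to_images(w)
print(imgs.shape)  # (32, 3, 3, 1)
```

So `max_outputs=10` in `tf.summary.image` shows the first 10 filters of each layer, each as a small grayscale image.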

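The main steps of the train method:

Build the computation graph: create the graph and place it on the device.
Define the global step: create the global-step variable.
Learning-rate schedule: set up the learning-rate decay policy.
Optimizer: choose the optimizer (gradient descent).
Get input data: fetch training batches from the dataset.
Collect input summaries: gather summaries of the input data.
Compute the number of classes: number of classes in the dataset + 1.
Split the batch: split the data across the GPUs.
Compute per-tower gradients: compute the gradients on each GPU.
Filter gradients: keep only the gradients required by the training configuration (train_segmentation_net / train_decision_net).
Average gradients: average the gradients over all towers.
Apply gradients: update the model parameters with the averaged gradients.
Moving average: maintain moving averages of the variables.
Batch-norm updates: update the batch normalization statistics.
Group the training ops: combine all updates into a single train_op.
Add summaries: for variables, weights, input processing, learning rate and gradients.
Create the saver and the summary op: build the checkpoint saver and merge the summaries.
Initialize the session: create and configure the session, then run the initializer.
Restore a pretrained model: if one is given, restore its weights.
Start the queue runners: start the input pipeline.
Create the summary writer: create the TensorBoard summary writer.
Training loop: run the ops, time each step, check the loss for NaN, print logs, write summaries and save checkpoints.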
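The "moving average" step uses `tf.train.ExponentialMovingAverage(self.MOVING_AVERAGE_DECAY, global_step)`. Below is a minimal NumPy restatement of its documented update rule, not TensorFlow's code: when `num_updates` (here the global step) is given, the decay is capped at `(1 + num_updates) / (10 + num_updates)`, so the shadow copy follows the weights quickly early in training. `ema_update` is a hypothetical helper name.

```python
import numpy as np

def ema_update(shadow, value, decay, num_updates=None):
    """One ExponentialMovingAverage step, restated in NumPy."""
    if num_updates is not None:
        # Documented TF behavior: cap the decay early in training.
        decay = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
    # shadow <- decay * shadow + (1 - decay) * value
    return decay * shadow + (1.0 - decay) * value

# A weight that jumps from 0.0 to 1.0, with MOVING_AVERAGE_DECAY = 0.9999
# and the global step passed as num_updates (as in train()).
shadow = 0.0
for step in range(3):
    shadow = ema_update(shadow, 1.0, 0.9999, num_updates=step)
    print(step, round(float(shadow), 4))
```

Without `num_updates`, a decay of 0.9999 would make the shadow move by only 0.01% per step. At evaluation time, `variables_to_restore()` in `evaluate` loads these shadow values instead of the raw weights.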

_eval_once

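Runs one evaluation pass over the test set, generates predictions and computes performance metrics.

It restores the model weights, iterates over the test set, saves the predictions (and, optionally, plots), and, when a decision net is present, computes the precision-recall and ROC curves, AUC and AP.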

    def _eval_once(self, eval_dir, variables_to_restore, net_op, decision_op, images_op, labels_op, img_names_op, num_examples, plot_results=True):
        """Runs Eval once.
  
        Args:
          saver: Saver.
          summary_writer: Summary writer.
          net_op: net operation with prediction          
          summary_op: Summary op.
        """
        c = tf.ConfigProto()
        if self.visible_device_list is not None:
            c.gpu_options.visible_device_list = self.visible_device_list
        c.gpu_options.allow_growth = True
        with tf.Session(config=c) as sess:
            ckpt = tf.train.get_checkpoint_state(self.train_dir)
            if ckpt and ckpt.model_checkpoint_path:

                model_checkpoint_path = ckpt.model_checkpoint_path

                # Restore from a checkpoint stored with a relative path.
                if not os.path.isabs(model_checkpoint_path):
                    model_checkpoint_path = os.path.join(self.train_dir, model_checkpoint_path)

                self.model.restore(sess, model_checkpoint_path, variables_to_restore)

                # Assuming model_checkpoint_path looks something like:
                #   /my-favorite-path/imagenet_train/model.ckpt-0,
                # extract global_step from it.
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                print('Successfully loaded model from %s at step=%s.' %
                      (ckpt.model_checkpoint_path, global_step))
            else:
                print('No checkpoint file found')
                return

            # Start the queue runners.
            coord = tf.train.Coordinator()
            try:
                threads = []
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                                     start=True))

                num_iter = int(math.ceil(num_examples / self.batch_size))

                # Per-sample (decision score, label) pairs, image names and inference times.
                samples_outcome = []
                samples_names = []
                samples_speed_eval = []

                total_sample_count = num_iter * self.batch_size
                step = 0

                print('%s: starting evaluation on (%s).' % (datetime.now(), ''))
                start_time = time.time()
                while step < num_iter and not coord.should_stop():
                    start_time_run = time.time()
                    if decision_op is None:
                        predictions, image, label, img_name = sess.run([net_op, images_op, labels_op, img_names_op])
                    else:
                        predictions, decision, image, label, img_name = sess.run([net_op, decision_op, images_op, labels_op, img_names_op])

                        decision = 1.0/(1+np.exp(-np.squeeze(decision)))

                    # if we use the sigmoid cross-entropy loss, we need to apply sigmoid to the predictions,
                    # since during training this is handled inside the loss, which is not used at inference
                    if self.model.use_corss_entropy_seg_net:
                        predictions = 1.0/(1+np.exp(-predictions))

                    end_time_run = time.time()

                    name = str(img_name[0]).replace("/", "_")
                    samples_names.append(name)

                    np.save(str.format("{0}/result_{2}.npy", eval_dir, step, name), predictions)
                    np.save(str.format("{0}/result_{2}_gt.npy", eval_dir, step, name), label)

                    if plot_results:
                        plt.figure(1)
                        plt.clf()
                        plt.subplot(1, 3, 1)
                        plt.title('Input image')
                        plt.imshow(image[0, :, :, 0], cmap="gray")

                        plt.subplot(1, 3, 2)
                        plt.title('Groundtruth')
                        plt.imshow(label[0, :, :, 0], cmap="gray")

                        plt.subplot(1, 3, 3)
                        if decision_op is None:
                            plt.title('Output/prediction')
                        else:
                            plt.title(str.format('Output/prediction: {0}',decision))

                        # display max
                        vmax_value = max(1, predictions.max())

                        plt.imshow((predictions[0, :, :, 0] > 0) * predictions[0, :, :, 0], cmap="jet", vmax=vmax_value)
                        plt.suptitle(str(img_name[0]))

                        plt.show(block=0)

                        out_prefix = ''

                        if decision_op is not None:
                            out_prefix = '%.3f_' % decision

                        plt.savefig(str.format("{0}/{1}result_{2}.pdf", eval_dir, out_prefix, name), bbox_inches='tight')

                    samples_speed_eval.append(end_time_run - start_time_run)

                    if decision_op is None:
                        pass
                    else:
                        samples_outcome.append((decision, np.max(label)))

                    step += 1
                    if step % 20 == 0:
                        duration = time.time() - start_time
                        sec_per_batch = duration / 20.0
                        examples_per_sec = self.batch_size / sec_per_batch
                        print('%s: [%d batches out of %d] (%.1f examples/sec; %.3f '
                              'sec/batch)' % (datetime.now(), step, num_iter,
                                              examples_per_sec, sec_per_batch))
                        start_time = time.time()

                if len(samples_outcome) > 0:
                    from sklearn.metrics import precision_recall_curve, roc_curve, auc

                    samples_outcome = np.matrix(np.array(samples_outcome))

                    idx = np.argsort(samples_outcome[:,0],axis=0)
                    idx = idx[::-1]
                    samples_outcome = np.squeeze(samples_outcome[idx, :])
                    samples_names = np.array(samples_names)
                    samples_names = samples_names[idx]

                    np.save(str.format("{0}/samples_outcome.npy", eval_dir), samples_outcome)
                    np.save(str.format("{0}/samples_names.npy", eval_dir), samples_names)

                    P = np.sum(samples_outcome[:, 1])

                    TP = np.cumsum(samples_outcome[:, 1] == 1).astype(np.float32).T
                    FP = np.cumsum(samples_outcome[:, 1] == 0).astype(np.float32).T

                    recall = TP / P
                    precision = TP / (TP + FP)

                    f_measure = 2 * np.multiply(recall, precision) / (recall + precision)


                    idx = np.argmax(f_measure)

                    best_f_measure = f_measure[idx]
                    best_thr = samples_outcome[idx,0]
                    best_FP = FP[idx]
                    best_FN = P - TP[idx]

                    precision_, recall_, thresholds = precision_recall_curve(samples_outcome[:, 1], samples_outcome[:, 0])
                    FPR, TPR, _ = roc_curve(samples_outcome[:, 1], samples_outcome[:, 0])
                    AUC = auc(FPR,TPR)
                    AP = auc(recall_, precision_)

                    print('AUC=%f, and AP=%f, with best thr=%f at f-measure=%.3f and FP=%d, FN=%d' % (AUC, AP, best_thr, best_f_measure, best_FP, best_FN))

                    plt.figure(1)
                    plt.clf()
                    plt.plot(recall, precision)
                    plt.title('Average Precision=%.4f' % AP)
                    plt.xlabel('Recall')
                    plt.ylabel('Precision')
                    plt.savefig(str.format("{0}/precision-recall.pdf", eval_dir), bbox_inches='tight')

                    plt.figure(1)
                    plt.clf()
                    plt.plot(FPR, TPR)
                    plt.title('AUC=%.4f' % AUC)
                    plt.xlabel('False positive rate')
                    plt.ylabel('True positive rate')
                    plt.savefig(str.format("{0}/ROC.pdf", eval_dir), bbox_inches='tight')




            except Exception as e:  # pylint: disable=broad-except
                coord.request_stop(e)

            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)

        return samples_outcome,samples_names, samples_speed_eval
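The best-threshold search in `_eval_once` (sort by decision score in descending order, take cumulative TP/FP counts, then pick the maximum F-measure) can be written with plain 1-D NumPy arrays instead of `np.matrix`. This is a sketch of that logic under the assumption that labels are 0/1; `best_f_measure` is a hypothetical name. As in the original, each sorted sample's score acts as a candidate threshold, so tied scores are not merged the way sklearn merges them.

```python
import numpy as np

def best_f_measure(scores, labels):
    """Return (best F-measure, threshold, FP, FN) over all score thresholds."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    order = np.argsort(scores)[::-1]            # highest decision score first
    s, y = scores[order], labels[order]
    P = y.sum()                                 # number of positive samples
    TP = np.cumsum(y == 1).astype(np.float64)   # TP if everything up to k is called positive
    FP = np.cumsum(y == 0).astype(np.float64)
    recall = TP / P
    precision = TP / (TP + FP)
    with np.errstate(divide='ignore', invalid='ignore'):
        f = np.nan_to_num(2 * recall * precision / (recall + precision))
    idx = int(np.argmax(f))
    return f[idx], s[idx], FP[idx], P - TP[idx]

print(best_f_measure([0.9, 0.8, 0.7, 0.3, 0.2], [1, 1, 0, 1, 0]))
```

Here the best cut is at score 0.3, which catches all three positives at the cost of one false positive.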

evaluate

Calls _eval_once to evaluate the model, either once (run_once=True) or repeatedly every eval_interval_secs seconds.

    def evaluate(self, dataset, run_once = True, eval_interval_secs = 5, plot_results=True):
        """Evaluate model on Dataset for a number of steps."""
        with tf.Graph().as_default():
            # Get images and labels from the input_data.
            images, labels, img_names = self.input.add_inputs_nodes(dataset, False)

            # Number of classes in the Dataset label set plus 1.
            # Label 0 is reserved for an (unused) background class.
            num_classes = dataset.num_classes() + 1

            # Build a Graph that computes the logits predictions from the
            # inference model.
            with tf.name_scope('%s_%d' % (self.TOWER_NAME, 0)) as scope:
                net, decision,  _ = self.model.get_inference(images, num_classes, scope=scope)

            # Restore the moving average version of the learned variables for eval.
            variable_averages = tf.train.ExponentialMovingAverage(self.model.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()

            eval_dir = os.path.join(self.eval_dir, dataset.subset)
            try:
                os.makedirs(eval_dir)
            except:
                pass

            while True:
                samples_outcome,samples_names, samples_speed_eval = self._eval_once(eval_dir, variables_to_restore, net, decision, images, labels, img_names, dataset.num_examples_per_epoch(),plot_results)
                if run_once:
                    break
                time.sleep(eval_interval_secs)

            num_params = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])

        return samples_outcome,samples_names, samples_speed_eval,num_params

if __name__=='__main__'

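Parses the command-line arguments, builds the model and the trainer for every fold and dataset, then runs training and, when the decision net is enabled, evaluation.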

import tensorflow as tf

from segdec_model import SegDecModel
from segdec_data import InputData

if __name__ == '__main__':

    import argparse, glob, shutil

    def str2bool(v):
        return v.lower() in ("yes", "true", "t", "1")

    parser = argparse.ArgumentParser()

    # add boolean parser to allow using 'false' in arguments
    parser.register('type', 'bool', str2bool)

    parser.add_argument('--folds',type=str, help="Comma delimited list of ints identifying which folds to use.")
    parser.add_argument('--gpu', type=str, help="Comma delimited list of ints identifying which GPU ids to use.")
    parser.add_argument('--storage_dir', help='Path to your storage dir where segdec_train (tensorboard info) and segdec_eval (results) will be stored.',
                        type=str,
                        default='/opt/workspace/host_storage_hdd/')
    parser.add_argument('--dataset_dir', help='Path to your input_data dirs.',
                        type=str,
                        default='/opt/workspace/host_storage_hdd/')
    parser.add_argument('--datasets', help='Comma delimited list of input_data names to use, e.g., "Dataset1,Dataset2".',
                        type=str, default=','.join(['KolektorSDD']))
    parser.add_argument('--name_prefix',type=str, default=None)
    parser.add_argument('--train_subset', type=str, default="train_pos")
    parser.add_argument('--pretrained_model', type=str, default=None)
    parser.add_argument('--pretrained_main_folder', type=str, default=None)

    parser.add_argument('--size_height', type=int, default=2*704)
    parser.add_argument('--size_width', type=int, default=2*256)

    parser.add_argument('--seg_net_type', type=str, default='MSE')

    parser.add_argument('--input_rotation', type='bool', default=False)

    parser.add_argument('--with_seg_net', type='bool', default=True)
    parser.add_argument('--with_decision_net', type='bool', default=False)
    parser.add_argument('--lr', type=float, default=0)
    parser.add_argument('--max_steps', type=int, default=6600)

    parser.add_argument('--channels', type=int, default=1)
    parser.add_argument('--pos_weights', type=float, default=1)

    parser.add_argument('--ensure_posneg_balance', type='bool', default=True)

    args = parser.parse_args()

    main_storage_dir = args.storage_dir
    main_dataset_folder = args.dataset_dir
    dataset_list = args.datasets.split(",")
    train_subset = args.train_subset
    pretrained_model = args.pretrained_model
    pretrained_main_folder = args.pretrained_main_folder
    pos_weights = args.pos_weights
    ensure_posneg_balance = args.ensure_posneg_balance

    size_height = args.size_height
    size_width = args.size_width
    channels = args.channels

    seg_net_type = args.seg_net_type

    input_rotation = args.input_rotation

    with_seg_net = args.with_seg_net
    with_decision_net = args.with_decision_net

    max_steps = args.max_steps
    lr = args.lr

    if seg_net_type == 'MSE':
        # use lr=0.005 for the mean squared error loss
        lr_val = 0.005
        use_corss_entropy_seg_net = False
    elif seg_net_type == 'ENTROPY':
        # use a bigger lr for the sigmoid cross-entropy loss
        lr_val = 0.1
        use_corss_entropy_seg_net = True
    else:
        raise Exception('Unknown SEG-NET type; allowed only: \'MSE\' or \'ENTROPY\'')


    if lr > 0:
        lr_val = lr

    folds = [int(i) for i in args.folds.split(",")]
    for i in folds:
        if i >= 0:
            fold_name = 'fold_%d' % i
        else:
            fold_name = ''

        for d in dataset_list:

            run_name = os.path.join(d, fold_name if args.name_prefix is None else os.path.join(args.name_prefix, fold_name))

            dataset_folder = os.path.join(main_dataset_folder, d)
            print("running", dataset_folder, run_name)

            if with_decision_net is False:
                # segmentation net only
                net_model = SegDecModel(decision_net=SegDecModel.DECISION_NET_NONE,
                                        use_corss_entropy_seg_net=use_corss_entropy_seg_net,
                                        positive_weight=pos_weights)
            else:
                # segmentation net followed by the full decision net
                net_model = SegDecModel(decision_net=SegDecModel.DECISION_NET_FULL,
                                        use_corss_entropy_seg_net=use_corss_entropy_seg_net,
                                        positive_weight = pos_weights)
            current_pretrained_model = pretrained_model

            if current_pretrained_model is None and pretrained_main_folder is not None:
                current_pretrained_model = os.path.join(pretrained_main_folder,fold_name)

            train = SegDecTrain(net_model,
                                storage_dir=main_storage_dir,
                                run_string=run_name,
                                image_size=(size_height,size_width,channels),  # NOTE size should be divisible by 16 !!!
                                batch_size=1,
                                learning_rate=lr_val,
                                max_steps=max_steps,
                                max_epochs=1200,
                                visible_device_list=args.gpu,
                                pretrained_model_checkpoint_path=current_pretrained_model,
                                train_segmentation_net=with_seg_net,
                                train_decision_net=with_decision_net,
                                use_random_rotation=input_rotation,
                                ensure_posneg_balance=ensure_posneg_balance)

            dataset_fold_folder = os.path.join(dataset_folder,fold_name)

            # Run training
            train.train(InputData(train_subset, dataset_fold_folder))

            if with_decision_net:
                # Run evaluation on test data
                samples_outcome_test,samples_names_test, samples_speed_eval,num_params = train.evaluate(InputData('test', dataset_fold_folder))

                np.save(os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test', 'results_decision_net.npy'), samples_outcome_test)
                np.save(os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test', 'results_decision_net_names.npy'), samples_names_test)
                np.save(os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test', 'results_decision_net_speed_eval.npy'), samples_speed_eval)

                # Copy results from test dir of specific fold into common folder for this input_data
                src_dir = os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test')
                dst_dir = os.path.join(main_storage_dir, 'segdec_eval', d if args.name_prefix is None else os.path.join(d,args.name_prefix))
                for src_file in glob.glob(os.path.join(src_dir, '*.pdf')):
                    shutil.copy(src_file, dst_dir)

segdec_print_eval.py

This script computes the evaluation metrics: precision, recall, F1 score (F-measure), AUC and AP (average precision).

calc_confusion_mat

Computes FP, FN, TN and TP as per-sample boolean masks (D: binary decisions, Y: ground-truth labels).

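def calc_confusion_mat(D, Y):
    # D: binary decisions, Y: ground-truth labels (0/1).
    # Use the builtin bool: the np.bool alias is deprecated (removed in NumPy 1.24).
    FP = (D != Y) & (Y.astype(bool) == False)
    FN = (D != Y) & (Y.astype(bool) == True)
    TN = (D == Y) & (Y.astype(bool) == False)
    TP = (D == Y) & (Y.astype(bool) == True)

    return FP, FN, TN, TP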
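A small usage example of the confusion masks: thresholding made-up scores at 0.5 and counting each mask with `.sum()`, as `evaluate_decision` does. The function is repeated here (with the builtin `bool`) so the snippet runs on its own; the scores and labels are illustrative only.

```python
import numpy as np

def calc_confusion_mat(D, Y):
    Y = Y.astype(bool)
    FP = (D != Y) & ~Y   # predicted defect, actually clean
    FN = (D != Y) & Y    # missed defect
    TN = (D == Y) & ~Y
    TP = (D == Y) & Y
    return FP, FN, TN, TP

P = np.array([0.9, 0.8, 0.4, 0.6, 0.1])   # decision scores
Y = np.array([1, 1, 1, 0, 0])             # ground truth
FP, FN, TN, TP = calc_confusion_mat(P >= 0.5, Y)
print(FP.sum(), FN.sum(), TN.sum(), TP.sum())  # 1 1 1 2
```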

get_performance_eval

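Computes the metrics: PR and ROC curves, AUC, AP, the confusion counts at the best-F-measure threshold, and the counts at the threshold where FN=0 (FP@FN=0).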

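import numpy as np
from sklearn.metrics import precision_recall_curve, roc_curve, auc, average_precision_score

def get_performance_eval(P,Y):
    precision_, recall_, thresholds = precision_recall_curve(Y.astype(np.int32), P)
    FPR, TPR, _ = roc_curve(Y.astype(np.int32), P)
    AUC = auc(FPR, TPR)
    AP = average_precision_score(Y.astype(np.int32), P)

    # precision_/recall_ have one more entry than thresholds (final point: precision=1, recall=0),
    # so drop it to keep the indices aligned with thresholds
    f_measure = 2 * (precision_[:-1] * recall_[:-1]) / (precision_[:-1] + recall_[:-1] + 1e-10)

    # threshold with the best F-measure
    best_idx = np.argmax(f_measure)
    thr = thresholds[best_idx]

    FP, FN, TN, TP = calc_confusion_mat(P >= thr, Y)

    # largest threshold that still gives recall=1 (FN=0), used for FP@FN=0
    thr_fn0 = thresholds[recall_[:-1] >= 1].max()
    FP_, FN_, TN_, TP_ = calc_confusion_mat(P >= thr_fn0, Y)

    F_measure = (2 * TP.sum()) / float(2 * TP.sum() + FP.sum() + FN.sum())

    return TP, FP, FN, TN, TP_, FP_, FN_, TN_, F_measure, AUC, AP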
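FP@FN=0 answers: "if no defective part may be missed, how many good parts get flagged?" Assuming sklearn's convention that a sample is positive when `score >= threshold`, the largest threshold with recall 1 should be the lowest score among the positive samples. The sketch below computes it directly in NumPy without sklearn; `fp_at_fn0` is a hypothetical helper, and the data are made up.

```python
import numpy as np

def fp_at_fn0(P, Y):
    """Count FP at the highest threshold that still catches every positive."""
    P = np.asarray(P, dtype=np.float64)
    Y = np.asarray(Y).astype(bool)
    thr = P[Y].min()          # lowest positive score => every positive has P >= thr
    D = P >= thr
    fp = int((D & ~Y).sum())
    fn = int((~D & Y).sum())  # 0 by construction
    return fp, fn, thr

print(fp_at_fn0([0.9, 0.8, 0.4, 0.6, 0.1], [1, 1, 1, 0, 0]))
```

The clean sample scored 0.6 sits above the weakest defect (0.4), so it becomes the single false positive.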

evaluate_decision

Pools the decision-net outputs of all folds and prints the metrics (AP, FP/FN, FP@FN=0).

def evaluate_decision(data_dir, folds_list = [0,1,2]):

    PD_decision_net = None

    num_params_list = []

    for f in folds_list:
        if f >= 0:
            fold_name = 'fold_%d' % f
        else:
            fold_name = ''

        sample_outcomes = np.load(os.path.join(data_dir, fold_name, 'test', 'results_decision_net.npy'))

        if len(sample_outcomes) > 0:
            PD_decision_net = np.concatenate((PD_decision_net, sample_outcomes)) if PD_decision_net is not None else sample_outcomes

        num_params_filename = os.path.join(data_dir, fold_name, 'test', 'decision_net_num_params.npy')
        if os.path.exists(num_params_filename):
            num_params_list.append(np.load(num_params_filename))

    results = None

    if PD_decision_net is not None:

        TP, FP, FN, TN, TP_, FP_, FN_, TN_, F_measure, AUC, AP = get_performance_eval(PD_decision_net[:,0], PD_decision_net[:,1])

        print("AP: %.03f, FP/FN: %d/%d, FP@FN=0: %d" % (AP, FP.sum(), FN.sum(), FP_.sum()))

        results = {'TP': TP.sum(),
                   'FP': FP.sum(),
                   'FN': FN.sum(),
                   'TN': TN.sum(),
                   'FP@FN=0': FP_.sum(),
                   'f-measure': F_measure,
                   'AUC': AUC,
                   'AP': AP}

    return results
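`evaluate_decision` expects one `results_decision_net.npy` per fold under `<data_dir>/fold_<k>/test/`, each holding rows of (decision score, label), and concatenates them before computing metrics, so AP and FP/FN are computed once over all test images rather than averaged per fold. The sketch below reproduces only the loading and pooling step on synthetic files in a temporary directory; `load_pooled_outcomes` is a hypothetical name and the numbers are made up for illustration.

```python
import os
import tempfile
import numpy as np

def load_pooled_outcomes(data_dir, folds_list=(0, 1, 2)):
    """Concatenate per-fold (score, label) rows, like evaluate_decision does."""
    pooled = []
    for f in folds_list:
        fold_name = 'fold_%d' % f if f >= 0 else ''
        path = os.path.join(data_dir, fold_name, 'test', 'results_decision_net.npy')
        outcomes = np.load(path)
        if len(outcomes) > 0:
            pooled.append(np.asarray(outcomes).reshape(-1, 2))
    return np.concatenate(pooled) if pooled else None

# Synthetic scores/labels, only to show the expected directory layout.
demo_folds = {0: [[0.9, 1], [0.2, 0], [0.4, 0]], 1: [[0.8, 1], [0.1, 0]]}
with tempfile.TemporaryDirectory() as tmp:
    for k, rows in demo_folds.items():
        test_dir = os.path.join(tmp, 'fold_%d' % k, 'test')
        os.makedirs(test_dir)
        np.save(os.path.join(test_dir, 'results_decision_net.npy'), np.array(rows))
    pooled = load_pooled_outcomes(tmp, folds_list=[0, 1])
print(pooled.shape)  # (5, 2)
```

The pooled columns are then passed as `PD_decision_net[:,0]` (scores) and `PD_decision_net[:,1]` (labels) to `get_performance_eval`.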
