Learning TF: DeepLabv3+ Code Reading 6 (train_utils)

Reading the DeepLabv3+ code: train_utils.py

1. get_model_learning_rate()

import tensorflow as tf


def get_model_learning_rate(
    learning_policy,             # Learning rate policy for training.
    base_learning_rate,          # The base learning rate for model training.
    learning_rate_decay_step,    # Decay the base learning rate at a fixed step.
    learning_rate_decay_factor,  # The rate to decay the base learning rate.
    training_number_of_steps,    # Number of steps for training.
    learning_power,              # Power used for 'poly' learning policy.
    slow_start_step,             # Train the model with a small learning rate
                                 # for the first few steps.
    slow_start_learning_rate,    # The learning rate employed during slow start.
    slow_start_burnin_type='none'):  # Burnin type for the slow start stage:
                                     # 'none' means no burnin; 'linear' means
                                     # the learning rate increases linearly
                                     # from slow_start_learning_rate and
                                     # reaches base_learning_rate after
                                     # slow_start_step steps.
  """Gets model's learning rate.

  Computes the model's learning rate for different learning policies.
  Right now, only "step" and "poly" are supported.
  (1) The learning policy for "step" is computed as follows:
    current_learning_rate = base_learning_rate *
      learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
  See tf.train.exponential_decay for details.
  (2) The learning policy for "poly" is computed as follows:
    current_learning_rate = base_learning_rate *
      (1 - global_step / training_number_of_steps) ^ learning_power

  """
  global_step = tf.train.get_or_create_global_step()
  # When burnin is used, shift the step count so that the decay schedules
  # below only start counting after the slow start stage ends.
  adjusted_global_step = global_step

  if slow_start_burnin_type != 'none':
    adjusted_global_step -= slow_start_step

  if learning_policy == 'step':
    learning_rate = tf.train.exponential_decay(
        base_learning_rate,
        adjusted_global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    learning_rate = tf.train.polynomial_decay(
        base_learning_rate,
        adjusted_global_step,
        training_number_of_steps,
        end_learning_rate=0,
        power=learning_power)
  else:
    raise ValueError('Unknown learning policy.')

  adjusted_slow_start_learning_rate = slow_start_learning_rate
  if slow_start_burnin_type == 'linear':
    # Do linear burnin. Increase linearly from slow_start_learning_rate and
    # reach base_learning_rate after (global_step >= slow_start_step).
    adjusted_slow_start_learning_rate = (
        slow_start_learning_rate +
        (base_learning_rate - slow_start_learning_rate) *
        tf.to_float(global_step) / slow_start_step)
  elif slow_start_burnin_type != 'none':
    raise ValueError('Unknown burnin type.')

  # Employ small learning rate at the first few steps for warm start.
  return tf.where(global_step < slow_start_step,
                  adjusted_slow_start_learning_rate, learning_rate)
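The final tf.where selects adjusted_slow_start_learning_rate while global_step < slow_start_step and switches to the regular schedule afterwards, so the model warms up gently during the first steps.

To make the two decay formulas in the docstring concrete, here is a minimal plain-Python sketch. It is not part of the repo, and the constant values are illustrative picks, not DeepLab defaults:

# Plain-Python sketch of the two schedules; all constants are
# illustrative values chosen for this example.
BASE_LR = 0.007        # base_learning_rate
DECAY_STEP = 2000      # learning_rate_decay_step
DECAY_FACTOR = 0.1     # learning_rate_decay_factor
TRAIN_STEPS = 30000    # training_number_of_steps
POWER = 0.9            # learning_power

def step_lr(global_step):
  # Mirrors tf.train.exponential_decay with staircase=True.
  return BASE_LR * DECAY_FACTOR ** (global_step // DECAY_STEP)

def poly_lr(global_step):
  # Mirrors tf.train.polynomial_decay with end_learning_rate=0.
  return BASE_LR * (1.0 - float(global_step) / TRAIN_STEPS) ** POWER

for step in (0, 10000, 29999):
  print(step, step_lr(step), poly_lr(step))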
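For context, this helper is consumed in deeplab/train.py roughly like the following. This is a sketch of the call site from the training script; the flag names shown are assumptions based on the flags defined there, so check train.py for the exact invocation:

learning_rate = train_utils.get_model_learning_rate(
    FLAGS.learning_policy, FLAGS.base_learning_rate,
    FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
    FLAGS.training_number_of_steps, FLAGS.learning_power,
    FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

The returned learning_rate is a tensor that depends on the global step, so the optimizer automatically follows the schedule as training progresses.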