18、使用 tf.app.flags 接口定义命令行参数

本文介绍如何使用TensorFlow的tf.app.flags接口定义并修改命令行参数,包括各种类型的参数如整数、浮点数、字符串及布尔值,并演示如何在Python脚本中设置这些参数及其默认值。

一、使用 tf.app.flags 接口定义命令行参数

  • 众所周知,深度学习有很多的 Hyperparameter 需要调优,TensorFlow 底层使用了python-gflags项目,然后封装成tf.app.flags接口
  • 使用tf.app.flags接口可以非常方便的调用自带的DEFINE_string, DEFINE_boolean, DEFINE_integer, DEFINE_float设置不同类型的命令行参数及其默认值,在实际项目中一般会提前定义命令行参数,如下所示:
# coding: utf-8
# filename: flags.py
import tensorflow as tf

# 定义一个全局对象来获取参数的值,在程序中使用(eg:FLAGS.iteration)来引用参数
FLAGS = tf.app.flags.FLAGS

# 定义命令行参数,第一个是:参数名称,第二个是:参数默认值,第三个是:参数描述
tf.app.flags.DEFINE_integer("iteration", 200000, "Iterations to train [2e5]")
tf.app.flags.DEFINE_integer("disp_freq", 1000, "Display the current results every display_freq iterations [1e3]")
tf.app.flags.DEFINE_integer("save_freq", 2000, "Save the checkpoints every save_freq iterations [2e3]")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate of for adam [0.001]")
tf.app.flags.DEFINE_integer("train_batch_size", 64, "The size of batch images [64]")
tf.app.flags.DEFINE_integer("val_batch_size", 100, "The size of batch images [100]")
tf.app.flags.DEFINE_integer("height", 48, "The height of image to use. [48]")
tf.app.flags.DEFINE_integer("width", 160, "The width of image to use. [160]")
tf.app.flags.DEFINE_integer("depth", 3, "Dimension of image color. [3]")
tf.app.flags.DEFINE_string("data_dir", "/path/to/data_sets/", "Directory of dataset in the form of TFRecords.")
tf.app.flags.DEFINE_string("checkpoint_dir", "/path/to/checkpoint_save_dir/", "Directory name to save the checkpoints [checkpoint]")
tf.app.flags.DEFINE_string("model_name", "40w_grtr", "Model name. [40w_grtr]")
tf.app.flags.DEFINE_string("gpu_id", "0", "Which GPU to be used. [0]")
tf.app.flags.DEFINE_boolean("continue_train", False, "True for continue training.[False]")
tf.app.flags.DEFINE_boolean("per_image_standardization", True, "True for per_image_standardization.[True]")


# 定义主函数
def main(argv=None):  
    print(FLAGS.iteration)
    print(FLAGS.learning_rate)
    print(FLAGS.data_dir)
    print(FLAGS.continue_train)


# 执行main函数
if __name__ == '__main__':
    tf.app.run()  

二、执行程序的方法

1、使用程序中的默认参数
  • python flags.py

这里写图片描述

2、在命令行更改程序中的默认参数
  • python flags.py --iteration=500000 --learning_rate=0.01 --data_dir='/home/test/' --continue_train=True

这里写图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值