TensorFlow 中 tf.app.flags.FLAGS 的用法介绍

本文介绍如何使用 TensorFlow 中 tf.app.flags.FLAGS 进行命令行参数配置,通过定义字符串、整数及浮点型参数,并在主函数中读取这些参数,实现灵活调整训练数据路径、最大句子长度等关键设置。
部署运行你感兴趣的模型镜像

下面介绍 tf.app.flags.FLAGS 的使用,主要是在用命令行执行程序时,需要传些参数,代码如下:

新建一个名为:app_flags.py 的文件。

#coding:utf-8

# 学习使用 tf.app.flags 使用,全局变量
# 可以再命令行中运行也是比较方便,如果只写 python app_flags.py 则代码运行时默认程序里面设置的默认设置
# 若 python app_flags.py --train_data_path <绝对路径 train.txt> --max_sentence_len 100
#    --embedding_size 100 --learning_rate 0.05  代码再执行的时候将会按照上面的参数来运行程序

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string("train_data_path", "/home/yongcai/chinese_fenci/train.txt", "training data dir")
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
tf.app.flags.DEFINE_integer("max_sentence_len", 80, "max num of tokens per query")
tf.app.flags.DEFINE_integer("embedding_size", 50, "embedding size")

tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")


def main(unused_argv):
    train_data_path = FLAGS.train_data_path
    print("train_data_path", train_data_path)
    max_sentence_len = FLAGS.max_sentence_len
    print("max_sentence_len", max_sentence_len)
    embdeeing_size = FLAGS.embedding_size
    print("embedding_size", embdeeing_size)
    abc = tf.add(max_sentence_len, embdeeing_size)

    init = tf.global_variables_initializer()

    #with tf.Session() as sess:
        #sess.run(init)
        #print("abc", sess.run(abc))

    sv = tf.train.Supervisor(logdir=FLAGS.log_dir, init_op=init)
    with sv.managed_session() as sess:
        print("abc:", sess.run(abc))

        # sv.saver.save(sess, "/home/yongcai/tmp/")


# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
if __name__ == '__main__':
    tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)


调用方法:

其中参数可以根据需求进行修改。

python app_flags.py --train_data_path <绝对路径 train.txt> --max_sentence_len 100 --embedding_size 100 --learning_rate 0.05

如果这样调用:

python app_flags.py

则会执行程序时会自动调用程序中 default 中的参数。




您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

### TensorFlowflagsFLAGS 的代码实现及用途 TensorFlow 提供了 `tf.flags`(在较新的版本中已迁移到 `absl-py` 的 `flags` 模块)来方便用户定义和解析命令行参数。以下是关于 `flags` 和 `FLAGS` 的详细解读。 #### 1. 定义与初始化 `tf.flags` 或 `absl.flags` 提供了一个简单的方式用于定义命令行参数。这些参数可以是布尔型、整型、浮点型或字符串型等。以下是一个示例代码,展示如何定义和使用这些参数: ```python import tensorflow as tf # 创建一个 Flags 对象 flags = tf.app.flags # 使用 DEFINE_* 方法定义参数 flags.DEFINE_string('model_dir', './models', '模型保存路径') flags.DEFINE_integer('num_epochs', 10, '训练的轮数') flags.DEFINE_float('learning_rate', 0.001, '学习率') flags.DEFINE_boolean('use_gpu', False, '是否使用 GPU') # 获取 FLAGS 对象 FLAGS = flags.FLAGS ``` #### 2. 参数的使用 定义好参数后,可以通过 `FLAGS` 对象访问这些参数。例如: ```python if FLAGS.use_gpu: with tf.device('/gpu:0'): print(f"Using GPU with learning rate {FLAGS.learning_rate}") else: print(f"Using CPU with learning rate {FLAGS.learning_rate}") print(f"Model will be saved in {FLAGS.model_dir}") print(f"Training for {FLAGS.num_epochs} epochs") ``` #### 3. 命令行传参 运行脚本时,可以通过命令行传递参数值。例如,假设脚本名为 `train.py`,可以通过以下方式运行: ```bash python train.py --model_dir=./new_models --num_epochs=20 --learning_rate=0.01 --use_gpu=True ``` 上述命令会覆盖默认的参数值。 #### 4. 内部实现机制 `DEFINE_*` 系列函数实际上是将参数注册到一个全局的参数表中。当脚本运行时,`tf.app.run()` 或 `flags.FLAGS.parse_args()` 会解析命令行参数,并更新这些全局变量的值[^1]。 #### 5. 示例完整代码 以下是一个完整的示例,展示如何定义和使用 `flags` 和 `FLAGS`: ```python import tensorflow as tf # 创建 Flags 对象 flags = tf.app.flags # 定义参数 flags.DEFINE_string('data_dir', './data', '数据集路径') flags.DEFINE_integer('batch_size', 64, '批量大小') flags.DEFINE_float('dropout_rate', 0.5, 'Dropout 概率') flags.DEFINE_boolean('verbose', True, '是否打印详细日志') # 获取 FLAGS 对象 FLAGS = flags.FLAGS def main(_): print(f"Data directory: {FLAGS.data_dir}") print(f"Batch size: {FLAGS.batch_size}") print(f"Dropout rate: {FLAGS.dropout_rate}") if FLAGS.verbose: print("Verbose mode is enabled") if __name__ == '__main__': tf.app.run() ``` 运行脚本时,可以通过命令行传递参数: ```bash python script.py --data_dir=./custom_data --batch_size=128 --dropout_rate=0.3 --verbose=False ``` #### 6. 注意事项 - 在 TensorFlow 2.x 中,`tf.flags` 已被废弃,建议使用 `absl-py` 的 `flags` 模块。 - 如果需要支持更多的参数类型或自定义行为,可以扩展 `DEFINE_*` 函数的功能。 - 在多线程环境中,确保 `FLAGS` 的访问是线程安全的。 --- ###
评论 5
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值