(tensorflow)——tf.app.run()使用方法的解释

转载:https://blog.youkuaiyun.com/TwT520Ly/article/details/79759448

在一些github上公开的代码中,我们经常会看到这样的程序

if __name__ == '__main__':
    tf.app.run()

像网上的大多数文章一样,先粘贴一下run()的源码:

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(_sys.argv[:1] + flags_passthrough))


_allowed_symbols = [
    'run',
    # Allowed submodule.
    'flags',
]

remove_undocumented(__name__, _allowed_symbols)

源码中写的很清楚,首先加载flags的参数项,然后执行main()函数,其中参数使用tf.app.flags.FLAGS定义的。

tf.app.flags.FLAGS

# fila_name: temp.py
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')

print('string: ', FLAGS.string)
print('learning_rate: ', FLAGS.learning_rate)
print('flag: ', FLAGS.flag)

输出:

string:  train
learning_rate:  0.001
flag:  True

如果在命令行中执行python3 temp.py --help,得到输出:

usage: temp.py [-h] [--string STRING] [--learning_rate LEARNING_RATE]
               [--flag [FLAG]] [--noflag]

optional arguments:
  -h, --help            show this help message and exit
  --string STRING       This is a string
  --learning_rate LEARNING_RATE
                        This is the rate in training
  --flag [FLAG]         This is a flag
  --noflag

如果要对FLAGS的默认值进行修改,只要输入命令:

python3 temp.py --string 'test' --learning_rate 0.2 --flag False

联合使用

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')

def main(unuse_args):
    print('string: ', FLAGS.string)
    print('learning_rate: ', FLAGS.learning_rate)
    print('flag: ', FLAGS.flag)

if __name__ == '__main__':
    tf.app.run()

主函数中的tf.app.run()会调用main,并传递参数,因此必须在main函数中设置一个参数位置。如果要更换main名字,只需要在tf.app.run()中传入一个指定的函数名即可。

def test(args):
    # test
    ...
if __name__ == '__main__':
    tf.app.run(test)
### TensorFlow DTensor 使用指南 #### 安装与配置环境 为了使用DTensor,在环境中安装正确版本的TensorFlow至关重要。确保输入了正确的包名为`tensorflow_datasets`,而不是其他相似名称,并确认TensorFlow版本与`tensorflow_datasets`版本兼容[^2]。 可以通过命令行工具`pip`来查看当前已安装的TensorFlow版本: ```bash pip show tensorflow ``` 如果尚未安装最新版TensorFlow,则建议更新至支持DTensor功能的新版本。 #### 创建分布式策略 在Keras中利用基于TensorFlow的流行深度学习库实现分布式训练时,可以探索如何通过它使用分布式张量[^1]。创建一个分布策略对象作为模型构建的基础部分,这允许程序定义数据并行性和模型并行性的逻辑抽象层。 ```python import tensorflow as tf from keras import layers, models strategy = tf.distribute.MultiWorkerMirroredStrategy() with strategy.scope(): model = models.Sequential([ layers.Dense(64, activation='relu', input_shape=(784,)), layers.Dense(10, activation='softmax') ]) ``` 上述代码展示了怎样设置一个多工作节点镜像策略实例,并在其作用域内建立神经网络架构;此方式适用于跨多个GPU或TPU设备执行高效同步SGD算法优化过程中的参数更新操作。 #### 配置DTensor API 对于更复杂的场景,比如涉及不同维度上的分割模式或者混合精度计算等情况,可以直接调用DTensor提供的API接口完成自定义布局设计等工作。下面给出一段简单的例子说明如何初始化DTensor网格以及指定特定形状下的张量切分方案: ```python mesh = dtensor.create_mesh([("batch", 8), ("model", 4)]) layout = dtensor.Layout.replicated(mesh=mesh, rank=2) tensor = tf.constant([[1., 2.], [3., 4.]]) dtensor_tensor = dtensor.copy_to_mesh(tensor=tensor, layout=layout) ``` 这里先声明了一个由两个轴组成的虚拟网格结构——其中一个是批次大小方向(batch),另一个则是模型内部组件数量(model);接着为即将产生的二阶张量指定了全复制形式(`replicated`)的数据放置规则;最后把常规格式转换成带有明确位置信息的对象表示法以便后续处理阶段能够识别到各自所属的工作单元组别及其相对应的部分视图映射关系。 #### 实现端到端应用案例 当掌握了基础概念和技术细节之后就可以尝试搭建完整的机器学习流水线了。从加载公开可用的数据集开始直到最终评估预测性能指标为止,整个流程都应当遵循最佳实践原则以保证结果的有效性和可重复性。以下是有关MNIST手写数字分类任务的一个简化版示范脚本片段: ```python def main(_): datasets, info = tfds.load(name="mnist", with_info=True, as_supervised=True) train_data = prepare_dataset(datasets['train']) test_data = prepare_dataset(datasets['test']) with strategy.scope(): optimizer = optimizers.Adam() loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True) @tf.function def distributed_train_step(dist_inputs): per_replica_losses = strategy.run(train_step, args=(dist_inputs,)) return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) checkpoint_callback = callbacks.ModelCheckpoint(filepath="./checkpoints/mnist_model") for epoch in range(EPOCHS): total_loss = 0.0 num_batches = 0 for batch_x, batch_y in train_data: total_loss += float(distributed_train_step((batch_x, batch_y))) num_batches += 1 print(f'Epoch {epoch}, Loss: {total_loss / num_batches}') model.evaluate(test_data) model.save_weights('./weights/') if __name__ == '__main__': app.run(main) ``` 这段代码综合运用了前面介绍过的知识点,实现了多机环境下大规模图像样本集合上进行卷积特征提取、损失函数最小化求解等一系列核心环节的操作步骤描述。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值