tf.train.MonitoredTrainingSession()解析【精】

最近看了下cifar10源码,训练代码中使用了tf.train.SessionRunHook(),tf.train.MonitoredTrainingSession();查看官方API后终于有些眉目了,特记录备忘。

 

首先,先讲下tf.train.MonitoredTrainingSession();

 一 .MonitoredTrainingSession()

 

首先,tf.train.MonitorSession()从单词的字面意思理解是用于监控训练的回话,返回值是tf.train.MonitorSession()类的一个实例Object, tf.train.MonitorSession()会在下面讲。

MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100
)

Args:

  •  is_chief:用于分布式系统中,用于判断该系统是否是chief,如果为True,它将负责初始化并恢复底层TensorFlow会话。如果为False,它将等待chief初始化或恢复TensorFlow会话。

  •  checkpoint_dir:一个字符串。指定一个用于恢复变量的checkpoint文件路径。

  •  scaffold:用于收集或建立支持性操作的脚手架。如果未指定,则会创建默认一个默认的scaffold。它用于完成图表

  •  hooks:SessionRunHook对象的可选列表。可自己定义SessionRunHook对象,也可用已经预定义好的SessionRunHook对象,如:tf.train.StopAtStepHook()设置停止训练的条件;tf.train.NanTensorHook(loss):如果loss的值为Nan则停止训练;

  •  chief_only_hooks:SessionRunHook对象列表。如果is_chief== True,则激活这些挂钩,否则忽略。

  •  

     save_checkpoint_secs:用默认的checkpoint saver保存checkpoint的频率(以秒为单位)。如果save_checkpoint_secs设置为None,不保存checkpoint。

  • save_summaries_steps:使用默认summaries saver将摘要写入磁盘的频率(以全局步数表示)。如果save_summaries_steps和save_summaries_secs都设置为None,则不使用默认的summaries saver保存summaries。默认为100

  •  

    save_summaries_secs:使用默认summaries saver将摘要写入磁盘的频率(以秒为单位)。如果save_summaries_steps和save_summaries_secs都设置为None,则不使用默认的摘要保存。默认未启用。

  •  

    config:用于配置会话的tf.ConfigProtoproto的实例。它是tf.Session的构造函数的config参数。

  •  

     stop_grace_period_secs:调用close()后线程停止的秒数。

  •  

     log_step_count_steps:记录全局步/秒的全局步数的频率

Returns:          

       一个MonitoredSession() 实例。


下面主要介绍tf.train.MonitoredSession()类

 

二tf.train.MonitoredSession()类

官方文档给的定义是:

Session-like object that handles initialization, recovery and hooks.

是一个处理初始化,模型恢复,和处理Hooks的类似与Session的类。

Args:

  •  

    session_creator:制定用于创建回话的ChiefSessionCreator

  •  

    hooks:tf.train.SessionRunHook()实例的列表

Returns:          

         一个MonitoredSession 实例。

 

Example usage:

saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
                      hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)

初始化:在创建一个MonitoredSession时,会按顺序执行以下操作:

  • 调用[Hooks]列表中每一个Hook的begin()函数
  • 通过scaffold.finalize()完成图graph的定义
  • 创建会话
  • 用Scaffold提供的初始化操作(op)来初始化模型
  • 如果给定checkpoint_dir中存在checkpoint文件,则用checkpoint恢复变量
  • 启动队列线程
  • 调用hook.after_create_session()

Run:当调用run()函数时,按顺序执行以下操作

  • 调用hook.before_run()
  • 用合并后的fetches 和feed_dict调用TensorFlow的session.run() (这里是真正调用tf.Session().run(fetches ,feed_dict))
  • 调用hook.after_run()
  • 返回用户需要的session.run()的结果
  • 如果发生了AbortedError或者UnavailableError,则在再次执行run()之前恢复或者重新初始化会话

Exit:当调用close()退出时,按顺序执行下列操作

  • 调用hook.end()
  • 关闭队列线程queuerunners和会话session
  • 在monitored_session的上下文中,抑制由于处理完所有输入抛出的OutOf Range错误。

 

tf.train.MonitoredTrainingSession()中的tf.train.SessRunHook()介绍与使用请看我的另一篇博客

才疏学浅,如有错误欢迎留言指出!

 

 

<think>好的,我现在要解决用户遇到的TensorFlow 2.x中的AttributeError问题,具体是找不到'tf.train.exponential_decay'属性。首先,我得回忆一下TensorFlow 1.x和2.x之间的主要变化,特别是API的变动。 用户提供的引用内容提到了类似的问题,比如找不到Optimizer、contrib模块等,这些通常是因为从TF1升级到TF2后,许多模块和函数被移动或删除了。例如,在引用[5]中,用户将tf.train.Optimizer改为tf.keras.optimizers.Optimizer,这表明在TF2中,部分优化器被整合到Keras模块里了。所以,我需要检查exponential_decay函数是否也发生了类似的迁移。 首先,我应该确认在TF2中,exponential_decay是否存在。根据之前的经验,TF2推荐使用Keras的API,而很多tf.train中的函数可能被移到其他地方。查阅官方文档会发现,exponential_decay现在属于tf.keras.optimizers.schedules中的ExponentialDecay类。因此,用户可能需要改用新的方式定义学习率衰减。 接下来,用户原来的代码可能是这样写的: learning_rate = tf.train.exponential_decay(...) 这在TF2中会报错,因为tf.train下已经没有这个函数了。正确的做法是使用tf.keras.optimizers.schedules.ExponentialDecay,然后将其传递给优化器。例如: lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(...) optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule) 这样就能正确使用学习率衰减了。 另外,用户可能还在使用TF1兼容模式,即使用tf.compat.v1,这时候可以暂时用tf.compat.v1.train.exponential_decay,但这不是长久之计,应该迁移到TF2的native API。引用[4]中的错误也提到了类似的情况,比如使用tf.train.GradientDescentOptimizer会导致错误,需要替换成Keras的优化器。 还需要检查用户的TensorFlow版本是否正确。如果用户误装了TF1,或者版本过旧,也可能出现这个问题。建议用户升级到最新的TF2版本,并确保代码中没有其他TF1的遗留代码,比如Session、placeholder等,这些在TF2中默认都被移除了。 总结解决步骤: 1. 检查TensorFlow版本,确保是2.x。 2. 替换tf.train.exponential_decay为Keras中的ExponentialDecay类。 3. 调整优化器的使用方式,将学习率计划传递给优化器。 4. 如果必须使用TF1兼容代码,启用v1兼容模式,但建议迁移到TF2原生API。 此外,可能还需要处理其他相关的API变化,比如优化器的调用方式,变量初始化等,如引用[4]中的错误提示。需要确保整个代码都符合TF2的规范,避免混合使用新旧API。</think>### TensorFlow 2.x 中解决 `'tf.train.exponential_decay'` 属性错误的方法 #### 1. **问题原因分析** 在 TensorFlow 2.x 中,许多 1.x 版本的 API 被重构或移至新模块。`tf.train.exponential_decay` 已迁移至 `tf.keras.optimizers.schedules.ExponentialDecay`[^5]。直接调用 `tf.train.exponential_decay` 会触发 `AttributeError`,因为 `tf.train` 模块不再包含此函数[^1][^4]。 --- #### 2. **解决方案** ##### 方法 1:使用 TensorFlow 2.x 推荐的新 API ```python import tensorflow as tf # 定义学习率衰减策略 initial_learning_rate = 0.1 lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate, decay_steps=1000, decay_rate=0.96, staircase=True ) # 将学习率传递给优化器 optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule) ``` ##### 方法 2:临时使用 TensorFlow 1.x 兼容模式 (不推荐长期使用,仅作过渡) ```python import tensorflow as tf # 启用 TensorFlow 1.x 兼容模式 tf.compat.v1.disable_eager_execution() # 仅需在需要旧版行为时调用 # 使用旧版 API learning_rate = tf.compat.v1.train.exponential_decay( learning_rate=0.1, global_step=global_step, decay_steps=1000, decay_rate=0.96, staircase=True ) ``` --- #### 3. **验证 TensorFlow 版本** ```python import tensorflow as tf print(tf.__version__) # 确保输出为 2.x.x ``` 如果版本低于 2.x,需升级: ```bash pip install --upgrade tensorflow ``` --- #### 4. **其他潜在问题排查** - **错误调用优化器**:如果代码中存在 `tf.train.Optimizer`,需替换为 `tf.keras.optimizers`[^5]。 - **混合新旧 API**:避免同时使用 `tf.Session` 或 `tf.placeholder` 等已弃用 API[^4]。 ---
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值