Tensorflow 2.* 网络训练(三) keras.callbacks 回调函数

回调函数在TensorFlow 2.x的模型训练中扮演重要角色,能够监控性能、中断训练、调整学习率等。本文详细介绍了BaseLogger、CSVLogger、EarlyStopping、LearningRateScheduler、ModelCheckpoint、ReduceLROnPlateau等内置回调,并提供自定义Callback的示例,同时提及了TensorBoard的使用和如何在model.fit中应用回调。

回调函数(callback)是在调用fit 时传入模型的一个对象(即实现特定方法的类实例),它在训练过程中的不同时间点都会被模型调用。它可以访问关于模型状态与性能的所有可用数据,还可以采取行动:中断训练、保存模型、加载一组不同的权重或改变模型的状态。
keras中也提供了丰富的回调API,我们可以根据需求自定义相关的对象。

BaseLogger

顾名思义,基础日志,用于记录每个epoch的平均metrics.
该回调函数在每个模型中都会被自动调用

CSVLogger

记录每个epoch的结果到csv文件中

csvlogger = tf.keras.callbacks.CSVLogger(filename, separator=',', append=False)
参数 注解
fiename 保存的csv文件名,如run/log.csv
separator 字符串,csv分隔符
append 默认为False,为True时csv文件如果存在则继续写入,为False时总是覆盖csv文件

结果如下:包括训练集和验证集合的loss,以及学习率
在这里插入图片描述

EarlyStopping

当metric停止提升时,停止训练

earlystoppoing = tf.keras.callbacks.EarlyStopping(
				    monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
				    baseline=None, restore_best_weights=False)
参数 注解
monitor 监视标准
min_delta 监控标准在训练过程中允许的最小改变量,即小于此值后便认为性能没有提升
patience 相比上一个epoch训练监控标准没有提升(小于min_delta),则经过patience个epoch后停止训练(还是很形象的patience 耐心)
verbose 信息展示模式
mode ‘auto’,‘min’,‘max’之一,与monitor对应。比如loss对应min,acc对应max;auto根据monitor的名称自动定义
baseline 监控标准的基准线,当训练过程相对基准没有提升则停止训练
restore_best_weights 布尔型,Ture,重载最好的监控量的epoch对应的weight,False,最后一步的权重

LearningRateScheduler

定义学习率日程表,即自定义不同epcoh的学习率

# 前十组为0.001 后面指数减少
def scheduler(epoch):
  if epoch < 10:
    return 0.001
  else:
    return 0.001 * tf.math.exp(0.1 * (10 - epoch))

learningreatescheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)

ModelCheckpoint

按照定义以一定频率存储模型或者权值文件

modelcheckpoint=tf.keras.callbacks.ModelCheckpoint(
		    filepath, monitor='val_loss', verbose=0, save_best_only=False
你提到的导入语句: ```python from tensorflow.keras.callbacks import ... ``` 通常用于从 Keras 中导入回调函数Callbacks),例如 `ModelCheckpoint`, `EarlyStopping`, `TensorBoard` 等。这些回调在模型训练过程中用于监控、记录、保存模型状态或提前终止训练等。 --- ## ⚠️ 可能出现的警告 在 TensorFlow 2.16+ 中,你可能会看到类似如下警告: ``` WARNING:tensorflow:From /path/to/your/code.py:1: The name tf.keras.callbacks.ModelCheckpoint is deprecated. Please use tf.keras.callbacks.ModelCheckpoint instead. ``` ⚠️ 实际上这个警告并不影响功能,它只是提示某些 API 已被标记为“legacy”,因为 TensorFlow 正在逐步将内部 Keras 模块与独立的 [Keras](https://github.com/keras-team/keras) 包统一。 --- ## ✅ 推荐用法:标准导入方式(兼容性好) 目前大多数回调类仍然可以直接使用: ```python from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau callbacks = [ ModelCheckpoint('model.h5', save_best_only=True), EarlyStopping(patience=3), TensorBoard(log_dir='./logs'), ReduceLROnPlateau(factor=0.2, patience=2) ] model.fit(..., callbacks=callbacks) ``` ✅ **注意:** 在 TF 2.16+ 中,虽然有些回调没有显式移入 `legacy` 模块,但它们的实现可能已经改用新的 Keras 核心机制,所以建议关注官方文档更新。 --- ## 🔁 如何避免警告信息? ### 方法一:使用 `tf.keras.callbacks.LegacyCallback` 如果你希望完全避免警告,可以尝试直接使用 `legacy` 模块(部分回调支持): ```python from tensorflow.keras.callbacks import legacy as legacy_callbacks # 示例: early_stopping = legacy_callbacks.EarlyStopping(patience=3) ``` > 注意:不是所有回调都支持 `.legacy` 路径,因此仍推荐使用标准路径导入。 --- ## 🧪 常见回调及其用途 | 回调名称 | 功能 | |----------|------| | `ModelCheckpoint` | 每隔一定 epoch 或最佳性能时保存模型 | | `EarlyStopping` | 当验证损失不再改善时提前停止训练 | | `TensorBoard` | 启动 TensorBoard 日志记录 | | `ReduceLROnPlateau` | 当性能停滞时自动降低学习率 | | `CSVLogger` | 将每个 epoch 的训练日志写入 CSV 文件 | | `LearningRateScheduler` | 自定义学习率调度器 | --- ## ✅ 示例代码:完整使用回调的训练流程 ```python import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping # 构建简单模型 model = Sequential([ Dense(64, activation='relu', input_shape=(10,)), Dense(1) ]) model.compile(optimizer='adam', loss='mse') # 准备数据 import numpy as np x = np.random.rand(1000, 10) y = np.random.rand(1000, 1) # 设置回调 callbacks = [ ModelCheckpoint('best_model.h5', save_best_only=True), EarlyStopping(patience=5) ] # 开始训练 model.fit(x, y, epochs=100, validation_split=0.2, callbacks=callbacks) ``` --- ##
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值