Keras中Callbacks的使用

本文详细介绍了Keras库中的各种Callback类,如ModelCheckpoint用于定期保存模型和权重,EarlyStopping监控指标停止训练,TensorBoard提供可视化支持。通过实例演示了如何使用ModelCheckpoint在每个epoch和自定义命名checkpoint保存模型。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Keras中Callbacks的使用

模块和方法

Modules

experimental module: Public API for tf.keras.callbacks.experimental namespace.

Classes

class BaseLogger: Callback that accumulates epoch averages of metrics. 记录epoch的累计平均metrics

class CSVLogger: Callback that streams epoch results to a CSV file. 文件流将epoch结果写入csv

class Callback: Abstract base class used to build new callbacks.

class CallbackList: Container abstracting a list of callbacks.

class EarlyStopping: Stop training when a monitored metric has stopped improving.

class History: Callback that records events into a History object.

class LambdaCallback: Callback for creating simple, custom callbacks on-the-fly.

class LearningRateScheduler: Learning rate scheduler.

class ModelCheckpoint: Callback to save the Keras model or model weights at some frequency.

class ProgbarLogger: Callback that prints metrics to stdout.

class ReduceLROnPlateau: Reduce learning rate when a metric has stopped improving.

class RemoteMonitor: Callback used to stream events to a server.

class TensorBoard: Enable visualizations for TensorBoard.

class TerminateOnNaN: Callback that terminates training when a NaN loss is encountered.

比较常用的方法

ModelCheckpoint

作用:Callback to save the Keras model or model weights at some frequency.

回调以保存某个频率(epoch)下的Keras模型或者模型的权重。

例子1

#初始化模型
model = Classifier()
#设置保存路径
checkpoint_filepath = 'E:/Python_Workspace/Saved_models/checkpoint'
#设置 modelcheckpoint参数
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath = checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only = True)
#编译模型
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy'])
#训练模型
history = model.fit(train_x,train_y,epochs=500,batch_size=256,validation_data=(val_x, val_y),callbacks=model_checkpoint_callback)
# 用保存的权重恢复模型 
model.load_weights(checkpoint_filepath)

例子 2 训练一个新模型,每五个 epochs 保存一次唯一命名的 checkpoint :

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=5*batch_size)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)
# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
参数说明
tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, **kwargs
)

filepath 保存模型文件的路径

monitor 监测的metric的名称 如 loss ,val_loss,acc,val_acc 一般是这几个。

history.historyhistory = model.fit()

可以通过查看metric有哪些

verbose 用于记录训练日志 verbose = 0 为不在标准输出流输出日志信息
verbose = 1 为输出进度条记录
verbose = 2 为每个epoch输出一行记录

save_best_only 仅需要保存最佳模型 true 或者false

mode {‘auto’, ‘min’, ‘max’} 监测量的目标最大或者最小

save_weights_only 是否保存完整的模型

save_freq 保存模型或者权重的时刻 默认是epoch,即每个epoch结束时保存 ,若设置为integer类型,则是N个批次后保存

保存的文件

若保存权重则会出现以下文件:
在这里插入图片描述

这些文件是什么?

上述代码将权重存储到 checkpoint—— 格式化文件的集合中,这些文件仅包含二进制格式的训练权重。 Checkpoints 包含:

  • 一个或多个包含模型权重的分片。
  • 一个索引文件,指示哪些权重存储在哪个分片中。

如果您在一台计算机上训练模型,您将获得一个具有如下后缀的分片:.data-00000-of-00001

保存整个模型

整个模型可以保存为两种不同的文件格式(SavedModelHDF5)。TensorFlow SavedModel 格式是 TF2.x 中的默认文件格式。但是,模型能够以 HDF5 格式保存。

Keras使用 HDF5标准提供了一种基本的保存格式。

加载模型可以使用

new_model = tf.keras.models.load_model('my_model.h5')

保存自定义模型

odel(‘my_model.h5’)


​    [保存自定义模型](https://tensorflow.google.cn/tutorials/keras/save_and_load?hl=zh-cn#%E4%BF%9D%E5%AD%98%E8%87%AA%E5%AE%9A%E4%B9%89%E5%AF%B9%E8%B1%A1)
后续再更==
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值