Using Callbacks in Keras
Modules and methods
Modules
experimental
module: Public API for tf.keras.callbacks.experimental namespace.
Classes
class BaseLogger
: Callback that accumulates epoch averages of metrics.
class CSVLogger
: Callback that streams epoch results to a CSV file.
class Callback
: Abstract base class used to build new callbacks.
class CallbackList
: Container abstracting a list of callbacks.
class EarlyStopping
: Stop training when a monitored metric has stopped improving.
class History
: Callback that records events into a History object.
class LambdaCallback
: Callback for creating simple, custom callbacks on-the-fly.
class LearningRateScheduler
: Learning rate scheduler.
class ModelCheckpoint
: Callback to save the Keras model or model weights at some frequency.
class ProgbarLogger
: Callback that prints metrics to stdout.
class ReduceLROnPlateau
: Reduce learning rate when a metric has stopped improving.
class RemoteMonitor
: Callback used to stream events to a server.
class TensorBoard
: Enable visualizations for TensorBoard.
class TerminateOnNaN
: Callback that terminates training when a NaN loss is encountered.
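Beyond the built-in classes above, the abstract Callback base class makes it easy to define your own. Below is a minimal sketch of a custom callback that records the training loss at the end of every epoch; the name LossHistory is hypothetical, not part of Keras.

```python
import tensorflow as tf

# A minimal custom callback: subclass Callback and override one of its hooks.
class LossHistory(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        # `logs` is a dict of metric values for the epoch that just finished
        self.losses.append((logs or {}).get("loss"))
```

An instance is then passed to `model.fit(..., callbacks=[LossHistory()])` just like any built-in callback.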
Commonly used callbacks
ModelCheckpoint
Purpose: Callback to save the Keras model or model weights at some frequency (typically per epoch).
Example 1

```python
import tensorflow as tf
from tensorflow import keras

# Initialize the model (Classifier is assumed to be defined elsewhere)
model = Classifier()
# Path where the checkpoint will be saved
checkpoint_filepath = 'E:/Python_Workspace/Saved_models/checkpoint'
# Configure the ModelCheckpoint callback
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)
# Compile the model
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])
# Train the model; note that `callbacks` expects a list
history = model.fit(train_x, train_y, epochs=500, batch_size=256,
                    validation_data=(val_x, val_y),
                    callbacks=[model_checkpoint_callback])
# Restore the model from the saved weights
model.load_weights(checkpoint_filepath)
```
Example 2: train a new model, saving a uniquely named checkpoint every five epochs:
```python
import os
import tensorflow as tf

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32

# Create a callback that saves the model's weights every 5 epochs.
# An integer `save_freq` counts batches, so 5*batch_size corresponds to
# 5 epochs only when there are `batch_size` batches per epoch.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=5*batch_size)

# Create a new model instance (create_model is assumed to be defined elsewhere)
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images,
          train_labels,
          epochs=50,
          batch_size=batch_size,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)

# Find the most recent checkpoint
latest = tf.train.latest_checkpoint(checkpoint_dir)

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
```
Parameter reference

```python
tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, **kwargs
)
```
filepath: path where the model file is saved.
monitor: name of the metric to monitor, typically one of loss, val_loss, accuracy, or val_accuracy.
To see which metric names are available, run `history = model.fit(...)` and inspect `history.history`.
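As a concrete illustration, the sketch below trains a tiny throwaway model on made-up random data (the model and data are assumptions, not from this document) and prints the keys of `history.history`, which are exactly the names `monitor` accepts:

```python
import numpy as np
import tensorflow as tf

# Tiny toy model and random data, just to inspect the available metric names.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

history = model.fit(x, y, validation_data=(x, y), epochs=2, verbose=0)
# Each key is a valid value for ModelCheckpoint's `monitor` argument.
print(sorted(history.history.keys()))
```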
verbose: logging verbosity. verbose=0 prints nothing to the standard output stream, verbose=1 shows a progress bar, and verbose=2 prints one line per epoch.
save_best_only: True or False; if True, only the best model (according to the monitored metric) is kept.
mode: one of {'auto', 'min', 'max'}, whether the monitored quantity should be minimized or maximized.
save_weights_only: if True, only the model's weights are saved; if False, the full model is saved.
save_freq: when to save. The default 'epoch' saves at the end of every epoch; an integer N saves after every N batches.
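Because an integer `save_freq` counts batches rather than epochs, converting "save every N epochs" into a batch count requires the number of steps per epoch. A small sketch, with a made-up dataset size:

```python
import math

num_train_samples = 60000   # hypothetical dataset size
batch_size = 256
steps_per_epoch = math.ceil(num_train_samples / batch_size)  # batches per epoch

save_every_n_epochs = 5
save_freq = save_every_n_epochs * steps_per_epoch  # pass this to ModelCheckpoint
print(save_freq)  # 1175
```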
Saved files
If only the weights are saved, the following files appear.
What are these files?
The code above stores the weights in a collection of checkpoint-formatted files that contain only the trained weights in a binary format. Checkpoints contain:
- One or more shards that contain the model's weights.
- An index file that indicates which weights are stored in which shard.
If you train a model on a single machine, you will get one shard with a suffix like .data-00000-of-00001.
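The shard and index files are easy to see directly. This sketch saves the weights of a toy model into a temporary directory and lists what appears (the model and paths are illustrative assumptions):

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "cp-0000.ckpt")

# Toy model; save_weights with a non-.h5 path uses the TF checkpoint format.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.save_weights(ckpt_path)

# Expect an index file plus one data shard on a single machine,
# e.g. 'cp-0000.ckpt.index' and 'cp-0000.ckpt.data-00000-of-00001'.
print(sorted(os.listdir(ckpt_dir)))
```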
Saving the entire model
The entire model can be saved in two different file formats: SavedModel and HDF5. The TensorFlow SavedModel format is the default in TF 2.x, but models can also be saved in the HDF5 format.
Keras provides a basic save format based on the HDF5 standard.
A saved model can be loaded with:

```python
new_model = tf.keras.models.load_model('my_model.h5')
```
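A full round trip in the HDF5 format can be sketched as follows, using a toy model (the model itself is an assumption; the file name matches the line above):

```python
import numpy as np
import tensorflow as tf

# Build and save a toy model in HDF5 format
# (stores architecture + weights + optimizer state in one file).
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.compile(optimizer="adam", loss="mse")
model.save("my_model.h5")

# Reload it; built-in layers need no extra class definitions.
new_model = tf.keras.models.load_model("my_model.h5")

# The restored model reproduces the original's predictions.
x = np.ones((1, 3), dtype="float32")
print(np.allclose(model.predict(x), new_model.predict(x)))
```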
Saving custom models
[Saving custom objects](https://tensorflow.google.cn/tutorials/keras/save_and_load?hl=zh-cn#%E4%BF%9D%E5%AD%98%E8%87%AA%E5%AE%9A%E4%B9%89%E5%AF%B9%E8%B1%A1)
More to come.