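While training a model, you may want to do some extra work along the way, such as saving checkpoints, stopping early, or logging metrics. In Keras this is done with callbacks.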
Documentation: Module: tf.keras.callbacks
Classes:
BaseLogger: Callback that accumulates epoch averages of metrics.
CSVLogger: Callback that streams epoch results to a csv file.
Callback: Abstract base class used to build new callbacks.
EarlyStopping: Stop training when a monitored quantity has stopped improving.
History: Callback that records events into a History object.
LambdaCallback: Callback for creating simple, custom callbacks on-the-fly.
LearningRateScheduler: Learning rate scheduler.
ModelCheckpoint: Save the model after every epoch.
ProgbarLogger: Callback that prints metrics to stdout.
ReduceLROnPlateau: Reduce learning rate when a metric has stopped improving.
RemoteMonitor: Callback used to stream events to a server.
TensorBoard: Enable visualizations for TensorBoard.
TerminateOnNaN: Callback that terminates training when a NaN loss is encountered.
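To make the `patience` and `min_delta` arguments of EarlyStopping (used in the code later in this note) concrete, here is a minimal pure-Python sketch of the stopping rule. It is a simplified model, not Keras's actual implementation: it assumes the monitored quantity is a loss to be minimized (like the default `val_loss`), and the function name `early_stopping_epoch` is hypothetical.

```python
import math


def early_stopping_epoch(val_losses, patience=5, min_delta=1e-3):
    """Simplified model of EarlyStopping in 'min' mode (hypothetical helper).

    Returns the 0-based index of the epoch at which training would stop,
    or None if the stopping condition is never met.
    """
    best = math.inf  # best monitored value seen so far
    wait = 0         # consecutive epochs without a significant improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement larger than min_delta: record it and reset the counter.
            best = loss
            wait = 0
        else:
            # Change too small (or worse): count one more "bad" epoch.
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Epochs 3..7 never beat the best value (0.35) by more than 1e-3,
# so the patience of 5 runs out at epoch index 7.
losses = [0.50, 0.40, 0.35, 0.3495, 0.3493, 0.3494, 0.3492, 0.3491, 0.30]
print(early_stopping_epoch(losses))  # → 7
```

The point of `min_delta` is visible here: the loss still wobbles slightly downward after epoch index 2, but those changes are smaller than `1e-3`, so they do not reset the patience counter.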
How to use EarlyStopping, ModelCheckpoint and TensorBoard:
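import os
from tensorflow import keras
# Callbacks listen to the training process, so they are attached in fit():
# define a list of callbacks, then pass that list to fit() as the callbacks argument.
logdir = './callbacks'
if not os.path.exists(logdir):
    os.mkdir(logdir)
logdir = os.path.join("callbacks")  # Apparently this line is only needed on Windows; without it an error occurs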
output_model_file = os.path.join(logdir,
                                 "fashion_mnist_model.h5")  # path of the saved model file
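callbacks = [
    keras.callbacks.TensorBoard(logdir),  # write logs that TensorBoard can visualize
    # keep only the best model seen so far (monitors val_loss by default)
    keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),
    # stop once val_loss has not improved by more than 1e-3 for 5 consecutive epochs
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]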
history = model.fit(
x_train_scaled