Module: tf.keras.callbacks的例子

本文档介绍了如何在训练模型时使用tf.keras.callbacks模块,包括积累指标平均值、CSV记录、学习率调度等。通过示例展示了在fashion_mnist数据集上的应用,但因迭代次数限制未充分展示效果。训练完成后,模型会自动保存,并提供方法检查训练状态和使用TensorBoard进行可视化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在训练模型的过程中可能要做一些事。
文档:Module: tf.keras.callbacks
classes:
BaseLogger: Callback that accumulates epoch averages of metrics.
CSVLogger: Callback that streams epoch results to a csv file.
Callback: Abstract base class used to build new callbacks.
EarlyStopping: Stop training when a monitored quantity has stopped improving.
History: Callback that records events into a History object.
LambdaCallback: Callback for creating simple, custom callbacks on-the-fly.
LearningRateScheduler: Learning rate scheduler.
ModelCheckpoint: Save the model after every epoch.
ProgbarLogger: Callback that prints metrics to stdout.
ReduceLROnPlateau: Reduce learning rate when a metric has stopped improving.
RemoteMonitor: Callback used to stream events to a server.
TensorBoard: Enable visualizations for TensorBoard.
TerminateOnNaN: Callback that terminates training when a NaN loss is encountered.

记录使用EarlyStoppingModelCheckpointTensorBoard的方法

# callbacks 是在训练过程中进行监听的,因此是在fit函数中对callbacks进行添加
# 定义一个callbacks的数组,然后在fit函数中将callbacks的数组当作参数传进去

logdir = './callbacks'
if not os.path.exists(logdir):
    os.mkdir(logdir)
logdir = os.path.join("callbacks") # 好像只是在window上要添加这行,否则出错

output_model_file = os.path.join(logdir,
                                 "fashion_mnist_model.h5")  # 输出的model文件

callbacks = [
    keras.callbacks.TensorBoard(logdir),
    keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]

history = model.fit(
    x_train_scaled
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值