使用回调
Keras中的回调是在训练期间(在epoch开始时,batch结束时,epoch结束时等)与不同时间点调用的对象,可用于实现以下行为:
- 在训练期间的不同时间点进行验证(除了内置的按时间段验证)
- 定期检查模型是否超过某个精度阈值
- 在训练似乎停滞不前时,改变模型的学习率
- 在训练似乎停滞不前时,对顶层进行微调
- 训练结束或超出某个性能阈值时发送电子邮件或即时消息通知等等
可使用的内置回调有:
- ModelCheckpoint:定期保存模型
- EarlyStopping:当训练不再改进验证指标时停止培训
- TensorBoard:定期编写可在TensorBoard中显示的模型日志(更多细节见“可视化”)
- CSVLogger:将丢失和指标数据流式传输到CSV文件
例1 提前终止
model = get_compiled_model()
# 1)提前终止
callbacks = [
keras.callbacks.EarlyStopping(
# 不再提升的关注指标
monitor='val_loss',
# 不再提升的阈值
min_delta=1e-2,
# 不再提升的轮次
patience=2,
verbose=1)
]
model.fit(x_train, y_train,
epochs=20,
batch_size=64,
callbacks=callbacks,
validation_split=0.2)
EarlyStopping使用
- monitor: 监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。例如monitor是’acc’,同时其变化范围在70%-90%之间,所以对于小于0.01%的变化不关心。加上观察到训练过程中存在抖动的情况(即先下降后上升),所以可以适当增大容忍程度。
- min_delta:增大或减小的阈值,只有大于这个部分才算作improvement。这个值的大小取决于monitor,也反映了容忍程度。
- patience:能够容忍多少个epoch内都没有improvement。这个设置其实是在抖动和真正的准确率下降之间做tradeoff。如果patience设的大,那么最终得到的准确率要略低于模型可以达到的最高准确率。如果patience设的小,那么模型很可能在前期抖动,还在全图搜索的阶段就停止了,准确率一般很差。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,比其稍大些设置patience。在learning rate变化的情况下,建议要略小于最大的抖动epoch number。
- mode: 就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。如monitor是’acc’,所以mode=’max’。
例2 模型保存
# checkpoint模型回调
model = get_compiled_model()
check_callback = keras.callbacks.ModelCheckpoint(
# 模型路径
filepath='mymodel_{epoch}.h5',
# 是否保存最佳
save_best_only=True,
# 监控指标
monitor='val_loss',
# 进度条类型
verbose=1
)
model.fit(x_train, y_train,
epochs=3,
batch_size=64,
callbacks=[check_callback],
validation_split=0.2)
例3 学习率调整
# 动态调整学习率
initial_learning_rate = 0.1
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
# 初始学习率
initial_learning_rate,
# 延迟步数
decay_steps=10000,
# 调整百分比
decay_rate=0.96,
staircase=True
)
optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)
model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(),
metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(x_train, y_train,
epochs=3,
batch_size=64,
callbacks=[check_callback],
validation_split=0.2)
例4 训练可视化
# 使用tensorboard
tensorboard_cbk = keras.callbacks.TensorBoard(log_dir='.\logs')
model = get_compiled_model()
model.fit(x_train, y_train,
epochs=5,
batch_size=64,
callbacks=[tensorboard_cbk],
validation_split=0.2)
运行如下命令
tensorboard --logdir=/full_path_to_your_logs