Callback
1. 先看一下callback类源码
class Callback(object):
def __init__(self):
self.validation_data = None
self.model = None
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
def on_epoch_begin(self, epoch, logs=None):
pass
def on_epoch_end(self, epoch, logs=None):
pass
def on_batch_begin(self, batch, logs=None):
pass
def on_batch_end(self, batch, logs=None):
pass
def on_train_begin(self, logs=None):
pass
def on_train_end(self, logs=None):
pass
我们在调用是可以这样写(保存最优模型)
# 回调类1
class Evaluator1(keras.callbacks.Callback):
def __init__(self):
self.best_val_acc = 0.
# 每迭代一次,调用一次
def on_epoch_end(self, epoch, logs=None):
val_acc = evaluate(valid_generator)
if val_acc > self.best_val_acc:
self.best_val_acc = val_acc
model.save_weights(r'data/best_model.weights')
test_acc = evaluate(test_generator)
print(
u'val_acc: %.5f, best_val_acc: %.5f, test_acc: %.5f\n' %
(val_acc, self.best_val_acc, test_acc)
)
2. 自带的回调方法
2.1 ModelCheckpoint
在每个训练期之后保存模型。相当于在on_epoch_end()中保存了模型
参数:

2.2 EarlyStopping
当被监测的数量不再提升,则停止训练。
参数:

2.3 TensorBoard
学习中…
本文介绍了在Keras中自定义回调函数Callback的使用,通过`Evaluator1`类展示了如何在每个训练周期结束时保存最佳模型,并评估验证集和测试集的准确性。同时提到了内置的回调方法如ModelCheckpoint用于模型保存,EarlyStopping用于提前终止训练以及TensorBoard用于可视化训练过程。这些回调方法有助于优化模型性能并监控训练状态。
313

被折叠的 条评论
为什么被折叠?



