基于Theano的深度学习(Deep Learning)框架Keras学习随笔-10-回调

本文介绍了Keras中的回调函数(Callbacks),它允许在模型训练过程中查看内部信息和统计数据。详细讲解了回调函数基类、内置回调如模型检查点(model checkpoints)的使用,以及如何创建自定义回调函数来记录损失率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        原地址:http://blog.youkuaiyun.com/niuwei22007/article/details/49229909

        Callbacks(回调函数)是一组用于在模型训练期间指定阶段被调用的函数。可以通过回调函数查看在模型训练过程中的模型内部信息和统计数据。可以通过传递一个回调函数list给fit()函数,然后相关的回调函数就可以在指定的阶段被调用了。

一、Callbacks基类

keras.callbacks.Callback()

属性:

  • Params:字典类型。训练参数(例如verbosity, batch size, numberof epochs…)
  • Model:keras.models.Model类型。被训练模型的引用。

方法:

  • on_train_begin(logs={}): 训练前调用
  • on_train_end(logs={}):训练后调用
  • on_epoch_begin(epoch, logs={}):第epoch次前调用
  • on_epoch_end(epoch, logs={}): 第epoch次后调用
  • on_batch_begin(batch, logs={}):第batch块前调用
  • on_batch_end(batch, logs={}):第batch块后调用

        上面方法中的logs是一个字典,会记录着当前batch或epoch的训练数据(比如误差率、准确率、batch_size等)。一般的,fit()方法会包含以下数据:

  • on_epoch_end: 记录会包含误差率(验证可用)、准确率(验证和精度监测可用)
  • on_batch_begin: 记录包含当前样本batch的size。
  • on_batch_end: 记录包含误差率,准确率(精度监测可用) 

二、可用的回调函数

keras.callbacks.ModelCheckpoint(filepath,verbose=0, save_best_only=False)

        用户每次epoch之后保存模型数据。如果save_best_only=True,则最近验证误差最好的模型数据会被保存下来。filepath是由epoch和logs的键构成的。比如filepath=weights.{epoch:02d}-{val_loss:.2f}.hdf5,那么会保存很多带有epoch和val_loss信息的文件;当然也可以是某个路径。

keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0)
        当monitor不再有改善的时候就会停止训练,这个可以通过patience看出来。  

三、创建自己的回调函数

        我们可以通过扩展基类keras.callbacks.Callback创建一个普通的回调函数。回调函数必须要通过self.model属性与训练model联系起来。下面是一个简单的回调函数,用于保存每次batch训练后的误差率:

classLossHistory(keras.callbacks.Callback):
    defon_train_begin(self, logs={}):
        self.losses = []
 
    defon_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

四、记录损失率实例介绍

# 用于记录损失率的回调函数
classLossHistory(keras.callbacks.Callback):
    defon_train_begin(self, logs={}):
        self.losses = []
 
    defon_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
# 定义一个模型
model = Sequential()
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
 
# 定义一个回调实例对象
history = LossHistory()
# 模型训练
model.fit(X_train,Y_train, batch_size=128, nb_epoch=20, verbose=0, callbacks=[history])
 
# 回调输出
print history.losses
'''
[0.66047596406559383, 0.3547245744908703,..., 0.25953155204159617, 0.25901699725311789]
''' 

五、内置回调函数(model checkpoints)实例介绍

from keras.callbacks import ModelCheckpoint
 
model = Sequential()
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
 
'''
每次epoch之后,如果验证误差减少,则保存模型数据。
'''
checkpointer =ModelCheckpoint(filepath="/tmp/weights.hdf5", verbose=1, save_best_only=True)
model.fit(X_train, Y_train, batch_size=128, nb_epoch=20, verbose=0, validation_data=(X_test, Y_test), callbacks=[checkpointer])
 


参考资料:

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值