keras.callback fit_generator

本文深入解析了Keras中fit_generator函数的使用方法与训练逻辑,包括参数解释、回调函数的作用及模型训练流程。理解如何利用生成器进行高效的数据加载与训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.fit_generator

fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose=1, callbacks=[], validation_data=None, nb_val_samples=None, class_weight=None, max_q_size=10)

函数的参数是:

generator:生成器函数,生成器的输出应该为:
一个形如(inputs,targets)的tuple

一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到samples_per_epoch时,记一个epoch结束

    def next_train(self):
        while 1:
            ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
            self.cur_train_index += self.minibatch_size
            if self.cur_train_index >= self.val_split:
                self.cur_train_index = self.cur_train_index % 32
                (self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
                    [self.X_text, self.Y_data, self.Y_len], self.val_split)
            yield ret

这里写成了一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。

samples_per_epoch:整数,当模型处理的样本达到此数目时计一个epoch结束,执行下一个epoch

verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

validation_data:具有以下三种形式之一

生成验证集的生成器

一个形如(inputs,targets)的tuple

一个形如(inputs,targets,sample_weights)的tuple

nb_val_samples:仅当validation_data是生成器时使用,用以限制在每个epoch结束时用来验证模型的验证集样本数,功能类似于samples_per_epoch

max_q_size:生成器队列的最大容量

函数返回一个History对象

2.fit_generator 训练逻辑过程

model.fit_generator 训练入口函数(参考上面的函数原型定义)

   callbacks.on_train_begin()
     while epoch < epochs:
             callbacks.on_epoch_begin(epoch)
             while steps_done < steps_per_epoch:
             	#generator_output是一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。
                 generator_output = next(output_generator)       #生成器next函数取输入数据进行训练,每次取一个batch大小的量
                 callbacks.on_batch_begin(batch_index, batch_logs)
                 outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
                 callbacks.on_batch_end(batch_index, batch_logs)	
              end of while steps_done < steps_per_epoch	
              self.evaluate_generator(...)          #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估	
              callbacks.on_epoch_end(epoch, epoch_logs)          #在这个执行过程中实现模型数据的保存操作
     end of while epoch < epochs	
     callbacks.on_train_end()
``
# 回调函数
通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。eras的回调函数是一个类

```python

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

3.类属性

params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)

model:keras.models.Model对象,为正在训练的模型的引用回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy’]。

在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

``` import os from keras.models import Model from keras.layers import Dense, Dropout from keras.applications.inception_resnet_v2 import InceptionResNetV2 from keras.callbacks import ModelCheckpoint, TensorBoard from keras.optimizers import Adam from keras import backend as K from utils.data_loader import train_generator, val_generator ''' Below is a modification to the TensorBoard callback to perform batchwise writing to the tensorboard, instead of only at the end of the batch. ''' class TensorBoardBatch(TensorBoard): def __init__(self, *args, **kwargs): super(TensorBoardBatch, self).__init__(*args) # conditionally import tensorflow iff TensorBoardBatch is created self.tf = __import__('tensorflow') def on_batch_end(self, batch, logs=None): logs = logs or {} for name, value in logs.items(): if name in ['batch', 'size']: continue summary = self.tf.Summary() summary_value = summary.value.add() summary_value.simple_value = value.item() summary_value.tag = name self.writer.add_summary(summary, batch) self.writer.flush() def on_epoch_end(self, epoch, logs=None): logs = logs or {} for name, value in logs.items(): if name in ['batch', 'size']: continue summary = self.tf.Summary() summary_value = summary.value.add() summary_value.simple_value = value.item() summary_value.tag = name self.writer.add_summary(summary, epoch * self.batch_size) self.writer.flush() def earth_mover_loss(y_true, y_pred): cdf_ytrue = K.cumsum(y_true, axis=-1) cdf_ypred = K.cumsum(y_pred, axis=-1) samplewise_emd = K.sqrt(K.mean(K.square(K.abs(cdf_ytrue - cdf_ypred)), axis=-1)) return K.mean(samplewise_emd) image_size = 224 base_model = InceptionResNetV2(input_shape=(image_size, image_size, 3), include_top=False, pooling='avg') for layer in base_model.layers: layer.trainable = False x = Dropout(0.75)(base_model.output) x = Dense(10, activation='softmax')(x) model = Model(base_model.input, x) model.summary() optimizer = Adam(lr=1e-3) model.compile(optimizer, loss=earth_mover_loss) # load weights from trained model if it exists if os.path.exists('weights/inception_resnet_weights.h5'): model.load_weights('weights/inception_resnet_weights.h5') # load pre-trained NIMA(Inception ResNet V2) classifier weights # if os.path.exists('weights/inception_resnet_pretrained_weights.h5'): # model.load_weights('weights/inception_resnet_pretrained_weights.h5', by_name=True) checkpoint = ModelCheckpoint('weights/inception_resnet_weights.h5', monitor='val_loss', verbose=1, save_weights_only=True, save_best_only=True, mode='min') tensorboard = TensorBoardBatch() callbacks = [checkpoint, tensorboard] batchsize = 100 epochs = 20 model.fit_generator(train_generator(batchsize=batchsize), steps_per_epoch=(250000. // batchsize), epochs=epochs, verbose=1, callbacks=callbacks, validation_data=val_generator(batchsize=batchsize), validation_steps=(5000. // batchsize))```我要使用NIMA测算图片的美学评分和技术评分,但是需要根据美学和技术分别先训练出他们的权重,请问上述代码怎么训练美学评分的权重
最新发布
03-18
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值