keras通过model.fit_generator训练模型(节省内存)
前言
前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的。
如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。
1.fit_generator函数简介
fit_generator(generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0)
参数:
- generator:一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例。这是我们实现的重点,后面会着介绍生成器和sequence的两种实现方式。
- steps_per_epoch:这个是我们在每个epoch中需要执行多少次生成器来生产数据,fit_generator函数没有batch_size这个参数,是通过steps_per_epoch来实现的,每次生产的数据就是一个batch,因此steps_per_epoch的值我们通过会设为(样本数/batch_size)。如果我们的generator是sequence类型,那么这个参数是可选的,默认使用len(generator) 。
- epochs:即我们训练的迭代次数。
- verbose:0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行
- callbacks:在训练时调用的一系列回调函数。
- validation_data:和我们的generator类似,只是这个使用于验证的,不参与训练。
- validation_steps:和前面的steps_per_epoch类似。
- class_weight:可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。(感觉这个参数用的比较少)
- max_queue_size:整数。生成器队列的最大尺寸。默认为10.
- workers:整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。
- use_multiprocessing:布尔值。如果 True,则使用基于进程的多线程。默认为False。
- shuffle:是否在每轮迭代之前打乱 batch 的顺序。 只能与Sequence(keras.utils.Sequence) 实例同用。
- initial_epoch: 开始训练的轮次(有助于恢复之前的训练)
2.generator实现
2.1生成器的实现方式
样例代码:
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
def process_x(path):
img = Image.open(path)
img = img.resize((96,96))
img = img.convert('RGB')
img = np.array(img
Keras分批训练技巧

本文详细介绍如何使用Keras的fit_generator函数进行分批训练,有效解决大规模数据集导致的内存溢出问题。通过生成器(generator)和Sequence类实现数据的动态加载,降低内存消耗,提升训练效率。
最低0.47元/天 解锁文章





