keras通过model.fit_generator训练模型(节省内存)

Keras分批训练技巧
本文详细介绍如何使用Keras的fit_generator函数进行分批训练,有效解决大规模数据集导致的内存溢出问题。通过生成器(generator)和Sequence类实现数据的动态加载,降低内存消耗,提升训练效率。


前言
前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的。

如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。

1.fit_generator函数简介

fit_generator(generator, 
				steps_per_epoch=None, 
				epochs=1, 
				verbose=1, 
				callbacks=None, 
				validation_data=None, 
				validation_steps=None, 
				class_weight=None, 
				max_queue_size=10, 
				workers=1, 
				use_multiprocessing=False, 
				shuffle=True, 
				initial_epoch=0)

参数:

  • generator:一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例。这是我们实现的重点,后面会着介绍生成器和sequence的两种实现方式。
  • steps_per_epoch:这个是我们在每个epoch中需要执行多少次生成器来生产数据,fit_generator函数没有batch_size这个参数,是通过steps_per_epoch来实现的,每次生产的数据就是一个batch,因此steps_per_epoch的值我们通过会设为(样本数/batch_size)。如果我们的generator是sequence类型,那么这个参数是可选的,默认使用len(generator) 。
  • epochs:即我们训练的迭代次数。
  • verbose:0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行
  • callbacks:在训练时调用的一系列回调函数。
  • validation_data:和我们的generator类似,只是这个使用于验证的,不参与训练。
  • validation_steps:和前面的steps_per_epoch类似。
  • class_weight:可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。(感觉这个参数用的比较少
  • max_queue_size:整数。生成器队列的最大尺寸。默认为10.
  • workers:整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。
  • use_multiprocessing:布尔值。如果 True,则使用基于进程的多线程。默认为False。
  • shuffle:是否在每轮迭代之前打乱 batch 的顺序。 只能与Sequence(keras.utils.Sequence) 实例同用。
  • initial_epoch: 开始训练的轮次(有助于恢复之前的训练)

2.generator实现

2.1生成器的实现方式

样例代码:

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image

def process_x(path):
    img = Image.open(path)
    img = img.resize((96,96))
    img = img.convert('RGB')
    img = np.array(img
评论 4
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值