关于keras使用fit_generator中遇到StopIteration

Keras中使用fit_generator遇到报错:StopIteration

在keras中训练模型时,训练到一半报错StopIteration

遇到的问题

在这里插入图片描述

_________________________________________________________________
Epoch 1/1000
16107/16107 [==============================] - 15052s 935ms/step - loss: 2.5982 - acc: 0.8210 - val_loss: 2.8054 - val_acc: 0.7029
Epoch 2/1000
    1/16107 [..............................] - ETA: 4:10:08 - loss: 2.5140 - acc: 0.9133Traceback (most recent call last):
  File "train_generator_v1.py", line 104, in <module>
    run()
  File "train_generator_v1.py", line 97, in run
    validation_data=(X, Y), shuffle=True, callbacks=cbks)
  File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/engine/training.py", line 1415, in fit_generator
    initial_epoch=initial_epoch)
  File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/engine/training_generator.py", line 177, in fit_generator
    generator_output = next(output_generator)
  File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/utils/data_utils.py", line 785, in get
    raise StopIteration()
StopIteration

解决办法

给自己自定义的生成器内部嵌套一个while True:

下面是keras官方使用文档中fit_generator的demo

def generate_arrays_from_file(path):
   while True: # 请注意位置
       with open(path) as f:
           for line in f:
               x1, x2, y = process_line(line)
               yield ({'input_1': x1, 'input_2': x2}, {'output': y})
       f.close()

model.fit_generator(generate_arrays_from_file('/my_file.txt'),steps_per_epoch=10000, epochs=10)

问题原理

我们希望知道为什么会出现这样的问题,就得先了解两个事情

什么是 StopIteration

python中,迭代是一个特殊的特征,这一方法__next__可以通过内置函数next(iterator)访问,当遍历结束时,迭代器协议会抛出StopIteration异常。

    def __next__(self):
        'Return the next item from the iterator. When exhausted, raise StopIteration'
        raise StopIteration

我们可以从文档中找到__next__的说明,当迭代器中已无下一个元素,抛出StopIteration

接下来,我们尝试抛出这个异常

创建一个简单的生成器,再进行多次迭代

>>>a = (i for i in [1,2])
>>>a.next()
1
>>>a.next()
2
>>>a.next()

Traceback (most recent call last):
File “”, line 1, in
StopIteration

生成器迭代两次后,在第三次抛出了异常,表明迭代器已无下一个元素,它就像把魔术帽中的鸽子放出来,预备好的鸽子没有的时候就会告诉你帽子空了这个操作无法执行。

为什么迭代结束的异常会在fit_generator的调用时出现

我们使用fit的时候,直接把x,y传入内存,每周期都在重复使用,但由于内存不足我们利用fit_generator来调用生成器获取数据。

fit_generator并没有选择每周期调用传入的生成器,而是选择了把fit_generator包装成一个GeneratorEnqueuer,在训练期间,不断调用GeneratorEnqueuer.get()来获取数据。

class GeneratorEnqueuer(SequenceEnqueuer):
    """Builds a queue out of a data generator.

    The provided generator can be finite in which case the class will throw
    a `StopIteration` exception.

    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
    """

我们可以从GeneratorEnqueuer的说明中看到,“提供的生成器可以是有限的,那样的话将会抛出StopIteration”,fit_generator, evaluate_generator, predict_generator三个方法同样通过GeneratorEnqueuer实现。

也许,我们还希望了解一下文档为什么这样去规范用法,我们从GeneratorEnqueuer._data_generator_task方法中看到GeneratorEnqueuer是加载数据的。

try:
    if (self.queue is not None and
            self.queue.qsize() < self.max_queue_size):
        # On all OSes, avoid **SYSTEMATIC** error
        # in multithreading mode:
        # `ValueError: generator already executing`
        # => Serialize calls to
        # infinite iterator/generator's next() function
        generator_output = next(self._generator)
        self.queue.put((True, generator_output))
    else:
        time.sleep(self.wait_time)
except StopIteration:
    break

先解释下变量,self.queueGeneratorEnqueuer的内置变量,self.max_queue_sizefit_generator参数列表传入的max_queue_size,默认为10

我们再一次看到了文档中有说明,需要无限的迭代器/生成器next()方法。

很明显,如果self._generaotr是个有限的生成器的时候,next(self._generator)最终必然抛出异常,try捕捉到异常后break

break?对,程序继续走下去了,但是在get方法中判断到问题,然后抛出异常,这是不就是大家追溯到抛出异常的地方

if all_finished and self.queue.empty():
    raise StopIteration()
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值