Keras中使用fit_generator遇到报错:StopIteration
在keras中训练模型时,训练到一半报错StopIteration
遇到的问题
_________________________________________________________________
Epoch 1/1000
16107/16107 [==============================] - 15052s 935ms/step - loss: 2.5982 - acc: 0.8210 - val_loss: 2.8054 - val_acc: 0.7029
Epoch 2/1000
1/16107 [..............................] - ETA: 4:10:08 - loss: 2.5140 - acc: 0.9133Traceback (most recent call last):
File "train_generator_v1.py", line 104, in <module>
run()
File "train_generator_v1.py", line 97, in run
validation_data=(X, Y), shuffle=True, callbacks=cbks)
File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/engine/training.py", line 1415, in fit_generator
initial_epoch=initial_epoch)
File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/engine/training_generator.py", line 177, in fit_generator
generator_output = next(output_generator)
File "/home/limin/.conda/envs/py36-2/lib/python3.6/site-packages/keras/utils/data_utils.py", line 785, in get
raise StopIteration()
StopIteration
解决办法
给自己自定义的生成器内部嵌套一个while True:
下面是keras官方使用文档中fit_generator的demo
def generate_arrays_from_file(path): while True: # 请注意位置 with open(path) as f: for line in f: x1, x2, y = process_line(line) yield ({'input_1': x1, 'input_2': x2}, {'output': y}) f.close() model.fit_generator(generate_arrays_from_file('/my_file.txt'),steps_per_epoch=10000, epochs=10)
问题原理
我们希望知道为什么会出现这样的问题,就得先了解两个事情
什么是 StopIteration
python中,迭代是一个特殊的特征,这一方法__next__
可以通过内置函数next(iterator)
访问,当遍历结束时,迭代器协议会抛出StopIteration异常。
def __next__(self):
'Return the next item from the iterator. When exhausted, raise StopIteration'
raise StopIteration
我们可以从文档中找到__next__
的说明,当迭代器中已无下一个元素,抛出StopIteration。
接下来,我们尝试抛出这个异常
创建一个简单的生成器,再进行多次迭代
>>>a = (i for i in [1,2])
>>>a.next()
1
>>>a.next()
2
>>>a.next()
Traceback (most recent call last):
File “”, line 1, in
StopIteration
生成器迭代两次后,在第三次抛出了异常,表明迭代器已无下一个元素,它就像把魔术帽中的鸽子放出来,预备好的鸽子没有的时候就会告诉你帽子空了这个操作无法执行。
为什么迭代结束的异常会在fit_generator
的调用时出现
我们使用fit
的时候,直接把x,y传入内存,每周期都在重复使用,但由于内存不足我们利用fit_generator
来调用生成器获取数据。
fit_generator
并没有选择每周期调用传入的生成器,而是选择了把fit_generator
包装成一个GeneratorEnqueuer
,在训练期间,不断调用GeneratorEnqueuer.get()
来获取数据。
class GeneratorEnqueuer(SequenceEnqueuer):
"""Builds a queue out of a data generator.
The provided generator can be finite in which case the class will throw
a `StopIteration` exception.
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
"""
我们可以从GeneratorEnqueuer
的说明中看到,“提供的生成器可以是有限的,那样的话将会抛出StopIteration
”,fit_generator
, evaluate_generator
, predict_generator
三个方法同样通过GeneratorEnqueuer
实现。
也许,我们还希望了解一下文档为什么这样去规范用法,我们从GeneratorEnqueuer._data_generator_task
方法中看到GeneratorEnqueuer
是加载数据的。
try:
if (self.queue is not None and
self.queue.qsize() < self.max_queue_size):
# On all OSes, avoid **SYSTEMATIC** error
# in multithreading mode:
# `ValueError: generator already executing`
# => Serialize calls to
# infinite iterator/generator's next() function
generator_output = next(self._generator)
self.queue.put((True, generator_output))
else:
time.sleep(self.wait_time)
except StopIteration:
break
先解释下变量,self.queue
是GeneratorEnqueuer
的内置变量,self.max_queue_size
是fit_generator
参数列表传入的max_queue_size
,默认为10
我们再一次看到了文档中有说明,需要无限的迭代器/生成器next()
方法。
很明显,如果self._generaotr是个有限的生成器的时候,next(self._generator)
最终必然抛出异常,try捕捉到异常后break
break?对,程序继续走下去了,但是在get方法中判断到问题,然后抛出异常,这是不就是大家追溯到抛出异常的地方
if all_finished and self.queue.empty():
raise StopIteration()