大部分使用 tensorflow 的同学会使用 fit() 或者 fit_generator() 方法训练模型, 这两个 api 对于刚接触深度学习的同学非常友好和方便,但是由于其是非常深度的封装,对于希望自定义训练过程的同学就显得不是那么方便,而且,对于 GAN 这种需要分步进行训练的模型,也无法直接使用 fit 或者 fit_generator 直接训练的。因此,tensorflow 提供了 train_on_batch 这个 api,对一个 mini-batch 的数据进行梯度更新。
总结优点如下:
- 更精细自定义训练过程,更精准的收集 loss 和 metrics
- 分步训练模型-GAN的实现
- 多GPU训练保存模型更加方便
- 更多样的数据加载方式
函数原型:
y_pred = Model.train_on_batch(
x,
y=