从Keras源码看模型实现
本文以Keras自带的examples/addtion_rnn.py为例,theano为后台,分析Keras深度学习框架的源码,梳理模型训练的过程。(由于本人强行学习速成Keras,如有疏漏还望指出。)
从addtion_rnn运行至
model.fit(X_train,y_train,batch_size=BATCH_SIZE,nb_epoch=1,validation_data=(X_val,y_val))
函数(即开始训练)开始,我们来看看发生了什么。
model.fit实际上是class Sequential(Model)中的fit函数,分别再调用继承来的Model.fit函数。
Model.fit函数中比较重要的代码是下面几个部分:
self._make_train_function()该代码调用了class Model中的_make_train_function函数,这个函数的两个重要功能在于其中调用了
training_updates=self.optimizer.get_updates(self._collected_trainable_weights,self.constraints,self.total_loss)
其实看到optimizer就可以大概知道

本文通过分析Keras的addition_rnn示例,探讨Sequential模型的fit函数,讲解Model.fit如何调用_make_train_function、optimizer.get_gradients等,揭示深度学习模型训练中的反向传播和梯度下降过程。
最低0.47元/天 解锁文章
3702

被折叠的 条评论
为什么被折叠?



