1. solver.cpp中,主要是Solver的初始化函数Solver::Init(const SolverParameter& param),以及训练网络net_的初始化InitTrainNets() ,和测试网络test_nets_初始化InitTestNets() 。还有就是关于手动中断训练的相关函数GetRequestedAction()等。最最重要的,应该是Solver::solve(),而Solver::solve()中最主要的是step()函数。
2. void Solver<Dtype>::step(int iters)
这一部分主要包括:
(1)
TestAll(); //通过前向传播计算测试的loss,和想要的accuracy。
(2)
for (int i = 0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward(); //通过前向后向传播,计算loss 和 梯度.
}
loss /= param_.iter_size();
(3)
ApplyUpdate(); //根据之前的计算,更新weights.
(4)
Snapshot(); //序列化model参数并存储
3. 接下来,就是了解net 是如何进行ForwardBackward() 了。