转载请注明出处:https://blog.youkuaiyun.com/q_z_r_s/article/details/88315620
机器感知 一个专注于SLAM、三维重建、机器视觉等相关技术文章分享的公众号 ![]() |
Caffe代码执行流分析
- caffe.cpp
main()
caffe::GlobalInit(&argc, &argv):解析参数
GetBrewFunction(caffe::string(argv[1]))():根据传入参数调用train()或test()
shared_ptr<caffe::Solver<float> > \
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)):构造solver,默认SGD
- sgd_solver.hpp
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) { PreSolve(); }:先进入Solver类的构造函数,然后PreSolve
- solver.cpp
Solver<Dtype>::Solver(const SolverParameter& param)
: net_(), callbacks_(), requested_early_exit_(false) {
Init(param);
}:其实就是执行Init(param),其中的核心代码:InitTrainNet(); InitTestNets();
void Solver<Dtype>::InitTrainNet() {
...
net_.reset(new Net<Dtype>(net_param));
...
}:从SolverParameter中读取net_param,然后构建net,转而执行net的构造函数
- net.cpp
Net<Dtype>::Net(const NetParameter& param) {
Init(param);
}:根据传入的NetParameter构建Net,构建Net中的各个层及它们之间的联系。
至此网络模型就大致构建完毕了!!!调回train()函数继续往下执行
- caffe.cpp
train()::solver->Solve():最终执行此函数对网络模型可学习参数进行优化!!!
- solver.cpp
void Solver<Dtype>::Solve(const char* resume_file) {
...
int start_iter = iter_;
Step(param_.max_iter() - iter_);
...
}:实际执行优化的代码
从配置文件建立各个layer之间的输入输出
//caffe.proto中定义SolverParameter生成的数据结构中
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
...
// 实例化具体的Solver,例如SGD等
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
...
如前述部分,实例化具体的Solver时会构造Net,下面主要分析Net的构建
void Solver<Dtype>::Init::
-->void Solver<Dtype>::InitTrainNet() {
...
NetParameter net_param;
...
net_.reset(new Net<Dtype>(net_param));
}
上述代码将跳转至Net类中构建各个具体的层
template <typename Dtype>
Net<Dtype>::Net(const NetParameter& param) {
Init(param);
}
void Net<Dtype>::Init(const NetParameter& in_param) {
// 根据 inlude/exclude 参数过滤掉一些层
NetParameter filtered_param;
FilterNet(in_param, &filtered_param);
// 重点在这里
InsertSplits(filtered_param, ¶m);
...
// For each layer, set up its input and output
// 这几个变量很重要
bottom_vecs_.resize(param.layer_size());
top_vecs_.resize(param.layer_size());
bottom_id_vecs_.resize(param.layer_size());
param_id_vecs_.resize(param.layer_size());
top_id_vecs_.resize(param.layer_size());
bottom_need_backward_.resize(param.layer_size());
}
void InsertSplits(const NetParameter& param, NetParameter* param_split) {
// 完成输入输出关系的绑定
}
接下来看看整个网络是如何运行起来的
// 回到 caffe::main 函数中
int train() {
...
solver->Solve();
...
}
接下来进入Solver::Solve
void Solver<Dtype>::Solve(const char* resume_file) {
...
// 如果有预训练权值等就加载进来
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}
...
Step(param_.max_iter() - iter_);
...
}
下面进入 Solver::Step
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
// 如果是加载了 solver state,则iter_不是从零开始的
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
...
while (iter_ < stop_iter) {
...
// 一些 hook 函数
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
...
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
// 关键就是这步了
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
...
下面进入 Net::ForwardBackward
// 只分析 Forward
Dtype ForwardBackward() {
Dtype loss;
Forward(&loss);
Backward();
return loss;
}
// 跳转到这里执行
template <typename Dtype>
const vector<Blob<Dtype>*>& Net<Dtype>::Forward(Dtype* loss) {
if (loss != NULL) {
*loss = ForwardFromTo(0, layers_.size() - 1);
} else {
ForwardFromTo(0, layers_.size() - 1);
}
return net_output_blobs_;
}
// 再跳转到这里执行真正的逐个 layer 的 forward
template <typename Dtype>
Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
CHECK_GE(start, 0);
CHECK_LT(end, layers_.size());
Dtype loss = 0;
for (int i = start; i <= end; ++i) {
for (int c = 0; c < before_forward_.size(); ++c) {
before_forward_[c]->run(i);
}
// 关键代码,bottom_vecs_[i], top_vecs[i] 这就是前述的
// 代码建立的各个 layer 与输入输出关系时建立的
// 至此整个 Net 如何从 ProtoBuffer 文件一步步到真正的可运行的网络已经很清晰了
// 当然这里的 Forward 执行的不是基类中的方法,而是 ProtoBuffer 文件中
// 指定的某个特定类型的 layer,即基类的派生类中实现的方法
Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
loss += layer_loss;
if (debug_info_) { ForwardDebugInfo(i); }
for (int c = 0; c < after_forward_.size(); ++c) {
after_forward_[c]->run(i);
}
}
return loss;
}
总结
- 根据 ProtoBuffer 文件建立Solver
- 在 Solver 构造函数只能构建整个 Net,主要是建立各个层的输入输出关系
- Solver::Step --> Net::ForwardBackward