#ifndef CAFFE_DATA_LAYERS_HPP_
#define CAFFE_DATA_LAYERS_HPP_
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/blocking_queue.hpp"
namespace caffe {
/**
* @brief Provides base for data layers that feed blobs to the Net.
*
* BaseDataLayer derives from Layer<Dtype>. Data layers take no bottom blobs:
* Reshape() and both Backward passes are no-ops here, and subclasses supply
* their own setup through DataLayerSetUp().
*
* TODO(dox): thorough documentation for Forward and proto params.
*/
template <typename Dtype>
// BaseDataLayer: common base class for data layers; inherits from Layer<Dtype>.
class BaseDataLayer : public Layer<Dtype> {
public:
// Explicit constructor taking the layer's LayerParameter (defined out of line).
explicit BaseDataLayer(const LayerParameter& param);
// LayerSetUp: performs the setup common to all data layers and then calls
// DataLayerSetUp() for layer-specific setup (definition is out of line).
// Apart from BasePrefetchingDataLayer, subclasses should not override this
// method; put custom setup in DataLayerSetUp() instead.
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
// Always returns true: data layers are meant to be shared across the solvers
// that run in parallel.
virtual inline bool ShareInParallel() const { return true; }
// Hook for layer-specific setup, invoked from LayerSetUp(); the default
// implementation does nothing.
virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {}
// No-op: data layers have no bottom blobs to reshape against. Top blobs are
// presumably shaped during setup / batch loading in subclasses -- TODO confirm.
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {}
// Backward pass on CPU: intentionally empty -- a data layer has no bottom
// blobs to propagate gradients into.
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}
// Backward pass on GPU: intentionally empty, same as Backward_cpu.
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}
protected:
// Data-transformation settings (presumably copied from the LayerParameter in
// the constructor -- TODO confirm in the .cpp).
TransformationParameter transform_param_;
// DataTransformer instance, held through a shared_ptr.
shared_ptr<DataTransformer<Dtype> > data_transformer_;
// Presumably true when the layer emits a label top in addition to data --
// verify against LayerSetUp.
bool output_labels_;
};
/**
 * @brief One prefetch unit: a data blob paired with a label blob.
 *
 * Instances are used as fixed prefetch buffers by data layers.
 */
template <typename Dtype>
class Batch {
 public:
  /// Input data for this batch.
  Blob<Dtype> data_;
  /// Labels for the same batch (presumably left unused when the layer emits
  /// no label top -- confirm in the load_batch implementations).
  Blob<Dtype> label_;
};
/**
* @brief Base for data layers that prefetch batches ahead of Forward.
*
* Combines the data-layer behavior of BaseDataLayer<Dtype> with InternalThread.
* InternalThreadEntry() is presumably the body that InternalThread runs on its
* worker thread. Concrete subclasses implement load_batch() to fill one Batch.
*/
template <typename Dtype>
class BasePrefetchingDataLayer :
public BaseDataLayer<Dtype>, public InternalThread {
public:
// Explicit constructor taking the layer's LayerParameter (defined out of line).
explicit BasePrefetchingDataLayer(const LayerParameter& param);
// LayerSetUp: performs the common data-layer setup and calls DataLayerSetUp()
// for layer-specific setup (definition is out of line). Subclasses should not
// override it. Although not spelled `virtual` here, it overrides the virtual
// BaseDataLayer::LayerSetUp and is therefore itself virtual; nothing in the
// language prevents further overriding, so this is a convention only.
void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
// Forward pass on CPU.
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
// Forward pass on GPU.
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
// Number of batches fetched ahead of time so Forward does not wait on data
// loading; it also sizes the prefetch_ buffer pool.
// Prefetches batches (asynchronously if to GPU memory)
// NOTE(review): in-class initialized static const member -- if it is ever
// ODR-used (e.g. bound to a const int&), an out-of-class definition is needed.
static const int PREFETCH_COUNT = 3;
protected:
// Worker-thread body (presumably overrides InternalThread's entry hook).
virtual void InternalThreadEntry();
// Fills *batch (its data_ and, presumably, label_ blobs) with the next batch;
// pure virtual, so every concrete prefetching layer must implement it.
virtual void load_batch(Batch<Dtype>* batch) = 0;
// Fixed pool of PREFETCH_COUNT batch buffers.
Batch<Dtype> prefetch_[PREFETCH_COUNT];
// Queues of pointers to batch buffers. Judging by the names, prefetch_free_
// holds buffers ready to be refilled and prefetch_full_ holds loaded buffers
// waiting to be consumed by Forward -- confirm in the .cpp.
BlockingQueue<Batch<Dtype>*> prefetch_free_;
BlockingQueue<Batch<Dtype>*> prefetch_full_;
// Scratch blob, presumably the destination for DataTransformer output --
// TODO confirm.
Blob<Dtype> transformed_data_;
};
} // namespace caffe
#endif // CAFFE_DATA_LAYERS_HPP_
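// Caffe代码解读(二):base_data_layer.hpp -- "Caffe code walkthrough (part 2): base_data_layer.hpp"
// 最新推荐文章于 2019-01-15 13:24:25 发布 -- web-page residue: "latest recommended article published 2019-01-15 13:24:25"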