caffe代码阅读5: Data_layers的实现细节

本文详细解析了Caffe框架中的数据层(Data Layers),包括各数据层的类结构、功能及其实现细节。介绍了BaseDataLayer、BasePrefetchingDataLayer等关键类的作用,并对DataLayer等进行了深入分析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、Data_layers.hpp文件的作用简介


Data_layers.hpp在目前caffe的master分支中已经不能存在了,分散到各个文件中去了。
而之前是存在于cafferoot\include\caffe中。现在已经变成了各个类的名称的头文件了。这里做个提醒

首先给出这个文件中所包含的几个与数据读取有关的类。
分别为:
BaseDataLayer     数据层的基类,继承自通用的类Layer
Batch      Batch实际上就是一个data_和label_类标
BasePrefetchingDataLayer     是预取层的基类,继承自BaseDataLayer和InternalThread,包含能够读取一批数据的能力
DataLayer      DataLayer才是主角,继承自BasePrefetchingDataLayer ,使用DataReader来进行数据共享,从而实现并行化
DummyDataLayer      该类是继承自Layer,通过Filler产生数据
HDF5DataLayer         从HDF5中读取,继承自Layer
HDF5OutputLayer     将数据写入到HDF5文件,继承自Layer
ImageDataLayer        从图像文件中读取数据,这个应该比较常用,继承自BasePrefetchingDataLayer
MemoryDataLayer    从内存中读取数据,这里指已经从数据文件或者图像文件中读取到了数据,然后输入到该层,继承自BaseDataLayer
WindowDataLayer    从图像文件的窗口获取数据,需要指定窗口数据文件,继承自BasePrefetchingDataLayer

二、Data_layers文件的的详细介绍

上述类虽然在同一个头文件中进行的定义,但是却都是在不同的cpp文件进行的实现。
下面给出类的实现文件
BaseDataLayer和BasePrefetchingDataLayer    对应于:base_data_layer.cpp      base_data_layer.cu
DataLayer  对应于:data_layer.cpp
DummyDataLayer  对应于:dummy_data_layer.cpp
HDF5DataLayer   HDF5OutputLayer  对应于: hdf5_data_layer.cpp hdf5_data_layer.cu   以及hdf5_output_layer.cpp   hdf5_output_layer.cu
ImageDataLayer  对应于:image_data_layer.cpp
MemoryDataLayer 对应于:memory_data_layer.cpp
WindowDataLayer 对应于 window_data_layer.cpp

接下来对这些类进行详细阐述:

(1)BaseDataLayer的类定义以及实现如下:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. /** 
  2.  * @brief Provides base for data layers that feed blobs to the Net. 
  3.  * 
  4.  * TODO(dox): thorough documentation for Forward and proto params. 
  5.  * 数据层的基类 
  6.  */  
  7. template <typename Dtype>  
  8. class BaseDataLayer : public Layer<Dtype> {  
  9.  public:  
  10.   // 显式构造函数  
  11.   explicit BaseDataLayer(const LayerParameter& param);  
  12.   // LayerSetUp: implements common data layer setup functionality, and calls  
  13.   // DataLayerSetUp to do special data layer setup for individual layer types.  
  14.   // This method may not be overridden except by the BasePrefetchingDataLayer.  
  15.   // 该函数只能被BasePrefetchingDataLayer层进行重载  
  16.   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  17.       const vector<Blob<Dtype>*>& top);  
  18.   // Data layers should be shared by multiple solvers in parallel  
  19.   // 数据是否需要给多个并行solver进行共享  
  20.   virtual inline bool ShareInParallel() const { return true; }  
  21.   
  22.   // 数据层的初始化  
  23.   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  24.       const vector<Blob<Dtype>*>& top) {}  
  25.   
  26.   // 数据层是没有输入的(即bottoms),所以reshape只是形式  
  27.   // Data layers have no bottoms, so reshaping is trivial.  
  28.   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  29.       const vector<Blob<Dtype>*>& top) {}  
  30.   
  31.   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  32.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  33.   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  34.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  35.   
  36.  protected:  
  37.   // 对输入的数据进行变换的参数,这其中包括是否需要mirror,是否需要crop  
  38.   // 是否需要减去meanfile,是否需要scale  
  39.   TransformationParameter transform_param_;  
  40.   // 实际执行数据变换类的指针(一个Transform函数加上参数即可完成对数据的变换,参数是数据哈)  
  41.   shared_ptr<DataTransformer<Dtype> > data_transformer_;  
  42.   bool output_labels_;  
  43. };  
具体的实现:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // 构造函数就是初始化数据变换参数  
  2. template <typename Dtype>  
  3. BaseDataLayer<Dtype>::BaseDataLayer(const LayerParameter& param)  
  4.     : Layer<Dtype>(param),  
  5.       transform_param_(param.transform_param()) {  
  6. }  
  7.   
  8. // 初始化的时候根据top的大小来确定,如果是1表明只输出数据,而不输出类标  
  9. template <typename Dtype>  
  10. void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  11.       const vector<Blob<Dtype>*>& top) {  
  12.   if (top.size() == 1) {  
  13.     output_labels_ = false;  
  14.   } else {  
  15.     output_labels_ = true;  
  16.   }  
  17.   // 初始化一个DataTransformer实例,便于对数据进行预处理  
  18.   data_transformer_.reset(  
  19.       new DataTransformer<Dtype>(transform_param_, this->phase_));  
  20.   // 初始化种子  
  21.   data_transformer_->InitRand();  
  22.   // The subclasses should setup the size of bottom and top  
  23.   // 执行数据层的初始化  
  24.   DataLayerSetUp(bottom, top);  
  25. }  

(2)BasePrefetchingDataLayer类的定义以及实现如下:

BasePrefetchingDataLayer类的定义如下:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // BasePrefetchingDataLayer层是继承于BaseDataLayer的  
  2. // 是预取层的基类  
  3. template <typename Dtype>  
  4. class BasePrefetchingDataLayer :  
  5.     public BaseDataLayer<Dtype>, public InternalThread {  
  6.  public:  
  7.   explicit BasePrefetchingDataLayer(const LayerParameter& param);  
  8.   // LayerSetUp: implements common data layer setup functionality, and calls  
  9.   // DataLayerSetUp to do special data layer setup for individual layer types.  
  10.   // This method may not be overridden.  
  11.   void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  12.       const vector<Blob<Dtype>*>& top);  
  13.   
  14.   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  15.       const vector<Blob<Dtype>*>& top);  
  16.   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,  
  17.       const vector<Blob<Dtype>*>& top);  
  18.   
  19.   // Prefetches batches (asynchronously if to GPU memory)  
  20.   static const int PREFETCH_COUNT = 3;  
  21.   
  22.  protected:  
  23.   virtual void InternalThreadEntry();  
  24.   // 多了load_batch函数,该函数是纯虚函数,继承该函数的类都需要实现的  
  25.   virtual void load_batch(Batch<Dtype>* batch) = 0;  
  26.   // 还有prefetch数组,prefetch_free_,prefetch_full_  
  27.   Batch<Dtype> prefetch_[PREFETCH_COUNT];  
  28.   BlockingQueue<Batch<Dtype>*> prefetch_free_;  
  29.   BlockingQueue<Batch<Dtype>*> prefetch_full_;  
  30.   
  31.   Blob<Dtype> transformed_data_;  
  32. };  
  33.   
  34.   
  35. BasePrefetchingDataLayer类的具体实现如下:  
  36. // 构造函数,初始化预取的队列,free和full  
  37. template <typename Dtype>  
  38. BasePrefetchingDataLayer<Dtype>::BasePrefetchingDataLayer(  
  39.     const LayerParameter& param)  
  40.     : BaseDataLayer<Dtype>(param),  
  41.       prefetch_free_(), prefetch_full_() {  
  42.   for (int i = 0; i < PREFETCH_COUNT; ++i) {  
  43.     prefetch_free_.push(&prefetch_[i]);  
  44.   }  
  45. }  
  46.   
  47. // 进行层的初始化  
  48. template <typename Dtype>  
  49. void BasePrefetchingDataLayer<Dtype>::LayerSetUp(  
  50.     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {  
  51.     // 首先执行基类BaseDataLayer的层初始化  
  52.   BaseDataLayer<Dtype>::LayerSetUp(bottom, top);  
  53.   // Before starting the prefetch thread, we make cpu_data and gpu_data  
  54.   // calls so that the prefetch thread does not accidentally make simultaneous  
  55.   // cudaMalloc calls when the main thread is running. In some GPUs this  
  56.   // seems to cause failures if we do not so.  
  57.   // 在开启预取线程的时候,需要让cpu数据和gpu数据分配空间  
  58.   // 这样才能够避免在某些GPU上出现问题  
  59.   
  60.   // 首先是CPU  
  61.   for (int i = 0; i < PREFETCH_COUNT; ++i) {  
  62.     prefetch_[i].data_.mutable_cpu_data();  
  63.     if (this->output_labels_) {  
  64.       prefetch_[i].label_.mutable_cpu_data();  
  65.     }  
  66.   }  
  67. #ifndef CPU_ONLY  
  68.   // 然后是GPU  
  69.   if (Caffe::mode() == Caffe::GPU) {  
  70.     for (int i = 0; i < PREFETCH_COUNT; ++i) {  
  71.       prefetch_[i].data_.mutable_gpu_data();  
  72.       if (this->output_labels_) {  
  73.         prefetch_[i].label_.mutable_gpu_data();  
  74.       }  
  75.     }  
  76.   }  
  77. #endif  
  78.   DLOG(INFO) << "Initializing prefetch";  
  79.   // 初始化随机数种子  
  80.   this->data_transformer_->InitRand();  
  81.   // 开启线程  
  82.   StartInternalThread();  
  83.   DLOG(INFO) << "Prefetch initialized.";  
  84. }  
  85.   
  86. // 在StartInternalThread开启线程后就会执行下面自己定义的函数  
  87. // 这个就是自己定义的函数,让线程去执行的  
  88. template <typename Dtype>  
  89. void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {  
  90. #ifndef CPU_ONLY  
  91.   cudaStream_t stream;  
  92.   if (Caffe::mode() == Caffe::GPU) {  
  93.       // 创建非阻塞流  
  94.     CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));  
  95.   }  
  96. #endif  
  97.   
  98.   try {  
  99.     while (!must_stop()) {  
  100.         // 弹出一个batch  
  101.       Batch<Dtype>* batch = prefetch_free_.pop();  
  102.         // 装载batch  
  103.       load_batch(batch);  
  104. #ifndef CPU_ONLY  
  105.       if (Caffe::mode() == Caffe::GPU) {  
  106.           // 如果GPU模式开始,则推送到GPU  
  107.         batch->data_.data().get()->async_gpu_push(stream);  
  108.         // 检查是否成功  
  109.         CUDA_CHECK(cudaStreamSynchronize(stream));  
  110.       }  
  111. #endif  
  112.       // 将装好的batch压入full队列  
  113.       prefetch_full_.push(batch);  
  114.     }  
  115.   } catch (boost::thread_interrupted&) {  
  116.     // Interrupted exception is expected on shutdown  
  117.   }  
  118. #ifndef CPU_ONLY  
  119.   if (Caffe::mode() == Caffe::GPU) {  
  120.       // 销毁流  
  121.     CUDA_CHECK(cudaStreamDestroy(stream));  
  122.   }  
  123. #endif  
  124. }  
  125.   
  126. template <typename Dtype>  
  127. void BasePrefetchingDataLayer<Dtype>::Forward_cpu(  
  128.     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {  
  129.     // 传递的时候是从full队列中弹出一个数据  
  130.   Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");  
  131.   // Reshape to loaded data.  
  132.   // 根据batch的形状改变数据形状  
  133.   top[0]->ReshapeLike(batch->data_);  
  134.   // Copy the data  
  135.   // 将batch数据复制到top[0]  
  136.   caffe_copy(batch->data_.count(), batch->data_.cpu_data(),  
  137.              top[0]->mutable_cpu_data());  
  138.   DLOG(INFO) << "Prefetch copied";  
  139.   if (this->output_labels_) {  
  140.       // 输出类标的话  
  141.     // Reshape to loaded labels.  
  142.     // 根据batch中类标的形状改变top[1]的形状  
  143.     top[1]->ReshapeLike(batch->label_);  
  144.     // Copy the labels.  
  145.     // 复制类标到top[1]  
  146.     caffe_copy(batch->label_.count(), batch->label_.cpu_data(),  
  147.         top[1]->mutable_cpu_data());  
  148.   }  
  149.   // 将该batch压入free队列  
  150.   prefetch_free_.push(batch);  
  151. }  
  152.   
  153.   
  154. // 如果没有GPU的话则在BasePrefetchingDataLayer类中生成一个Forward函数  
  155. // 该函数并不前传,而是直接报错  
  156. #ifdef CPU_ONLY  
  157. STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward);  
  158. #endif  
  159. // 初始化层  
  160. INSTANTIATE_CLASS(BaseDataLayer);  
  161. INSTANTIATE_CLASS(BasePrefetchingDataLayer);  

(3)DataLayer类的定义以及实现如下:

数据层的主要功能是:
首先给出类的定义

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // DataLayer才是主角,继承自BasePrefetchingDataLayer  
  2. template <typename Dtype>  
  3. class DataLayer : public BasePrefetchingDataLayer<Dtype> {  
  4.  public:  
  5.   explicit DataLayer(const LayerParameter& param);  
  6.   virtual ~DataLayer();  
  7.   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  8.       const vector<Blob<Dtype>*>& top);  
  9.   // DataLayer uses DataReader instead for sharing for parallelism  
  10.   // 多了下面几个  
  11.   virtual inline bool ShareInParallel() const { return false; }  
  12.   virtual inline const char* type() const { return "Data"; }  
  13.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  14.   virtual inline int MinTopBlobs() const { return 1; }  
  15.   virtual inline int MaxTopBlobs() const { return 2; }  
  16.   
  17.  protected:  
  18.   virtual void load_batch(Batch<Dtype>* batch);  
  19.   
  20.   DataReader reader_;  
  21. };  
具体的实现如下:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3. #endif  // USE_OPENCV  
  4. #include <stdint.h>  
  5.   
  6. #include <string>  
  7. #include <vector>  
  8.   
  9. #include "caffe/common.hpp"  
  10. #include "caffe/data_layers.hpp"  
  11. #include "caffe/layer.hpp"  
  12. #include "caffe/proto/caffe.pb.h"  
  13. #include "caffe/util/benchmark.hpp"  
  14. #include "caffe/util/io.hpp"  
  15.   
  16. namespace caffe {  
  17.   
  18. // 初始化DataReader,层参数  
  19. template <typename Dtype>  
  20. DataLayer<Dtype>::DataLayer(const LayerParameter& param)  
  21.   : BasePrefetchingDataLayer<Dtype>(param),  
  22.     reader_(param) {  
  23. }  
  24.   
  25. // 析构函数停止内部线程  
  26. template <typename Dtype>  
  27. DataLayer<Dtype>::~DataLayer() {  
  28.   this->StopInternalThread();  
  29. }  
  30.   
  31. // 数据层的初始化  
  32. template <typename Dtype>  
  33. void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  34.       const vector<Blob<Dtype>*>& top) {  
  35.   // 从层参数中读取batch_size  
  36.   const int batch_size = this->layer_param_.data_param().batch_size();  
  37.   // Read a data point, and use it to initialize the top blob.  
  38.   // 从reader_中获取一个数据  
  39.   Datum& datum = *(reader_.full().peek());  
  40.   
  41.   // Use data_transformer to infer the expected blob shape from datum.  
  42.   // 用数据来推断blob的形状存放到top_shape  
  43.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  44.   this->transformed_data_.Reshape(top_shape);  
  45.   // Reshape top[0] and prefetch_data according to the batch_size.  
  46.   // 既然获取了数据的形状(channel,height,width),那么这里再设置一下batch_size  
  47.   // top_shape[0]=batch_size  
  48.   // top_shape[1]=channel  
  49.   // top_shape[2]=height  
  50.   // top_shape[3]=width  
  51.   top_shape[0] = batch_size;  
  52.   // 根据形状设置top[0]的形状  
  53.   top[0]->Reshape(top_shape);  
  54.   
  55.   // 设置预取数据的形状  
  56.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  57.     this->prefetch_[i].data_.Reshape(top_shape);  
  58.   }  
  59.   LOG(INFO) << "output data size: " << top[0]->num() << ","  
  60.       << top[0]->channels() << "," << top[0]->height() << ","  
  61.       << top[0]->width();  
  62.   // label  
  63.   // 如果输出类标的话则把top[1]的形状也弄一下  
  64.   if (this->output_labels_) {  
  65.     vector<int> label_shape(1, batch_size);  
  66.     top[1]->Reshape(label_shape);  
  67.     for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  68.       this->prefetch_[i].label_.Reshape(label_shape);  
  69.     }  
  70.   }  
  71. }  
  72.   
  73. // This function is called on prefetch thread  
  74. // 这个函数是在自己定义的线程执行函数内部执行的  
  75. template<typename Dtype>  
  76. void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  77.   CPUTimer batch_timer;  
  78.   batch_timer.Start();  
  79.   double read_time = 0;  
  80.   double trans_time = 0;  
  81.   CPUTimer timer;  
  82.   CHECK(batch->data_.count());  
  83.   CHECK(this->transformed_data_.count());  
  84.   
  85.   // Reshape according to the first datum of each batch  
  86.   // on single input batches allows for inputs of varying dimension.  
  87.   // 意思是像以下这种做法这样的话,每个batch的数据的维度可以不一样  
  88.   // 从参数文件获取batch_size  
  89.   const int batch_size = this->layer_param_.data_param().batch_size();  
  90.   // 获取第一个数据  
  91.   Datum& datum = *(reader_.full().peek());  
  92.   // Use data_transformer to infer the expected blob shape from datum.  
  93.   // 使用第一个数据推断blob的形状  
  94.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  95.   this->transformed_data_.Reshape(top_shape);  
  96.   // Reshape batch according to the batch_size.  
  97.   top_shape[0] = batch_size;  
  98.   batch->data_.Reshape(top_shape);  
  99.   
  100.   // top_data存数据  
  101.   Dtype* top_data = batch->data_.mutable_cpu_data();  
  102.   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables  
  103.   
  104.   // top_label存类标  
  105.   if (this->output_labels_) {  
  106.     top_label = batch->label_.mutable_cpu_data();  
  107.   }  
  108.   
  109.   // 对这批数据进行处理  
  110.   for (int item_id = 0; item_id < batch_size; ++item_id) {  
  111.     timer.Start();  
  112.     // get a datum  
  113.     Datum& datum = *(reader_.full().pop("Waiting for data"));  
  114.     read_time += timer.MicroSeconds();  
  115.     timer.Start();  
  116.     // Apply data transformations (mirror, scale, crop...)  
  117.     // 对于给定批的数据获取offset,这里调用的是给定batchid,然后获取offset  
  118.     int offset = batch->data_.offset(item_id);  
  119.     this->transformed_data_.set_cpu_data(top_data + offset);  
  120.     this->data_transformer_->Transform(datum, &(this->transformed_data_));  
  121.     // Copy label.  
  122.     // 复制类标  
  123.     if (this->output_labels_) {  
  124.       top_label[item_id] = datum.label();  
  125.     }  
  126.     // 数据传输时间  
  127.     trans_time += timer.MicroSeconds();  
  128.   
  129.     // 将数据指针压到free队列  
  130.     reader_.free().push(const_cast<Datum*>(&datum));  
  131.   }  
  132.   timer.Stop();  
  133.   batch_timer.Stop();  
  134.   DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  
  135.   DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";  
  136.   DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";  
  137. }  
  138.   
  139. INSTANTIATE_CLASS(DataLayer);  
  140. REGISTER_LAYER_CLASS(Data);  
  141.   
  142. }  // namespace caffe  

(4)DummyDataLayer类的定义与实现介绍:

Dummy数据层的主要功能就是根据所给定的Filler产生数据,然后前向传
首先给出定义
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. /** 
  2.  * @brief Provides data to the Net generated by a Filler. 
  3.  * 
  4.  * TODO(dox): thorough documentation for Forward and proto params. 
  5.  * 该类是继承自Layer,通过Filler产生数据 
  6.  */  
  7. template <typename Dtype>  
  8. class DummyDataLayer : public Layer<Dtype> {  
  9.  public:  
  10.   explicit DummyDataLayer(const LayerParameter& param)  
  11.       : Layer<Dtype>(param) {}  
  12.   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  13.       const vector<Blob<Dtype>*>& top);  
  14.   // Data layers should be shared by multiple solvers in parallel  
  15.   virtual inline bool ShareInParallel() const { return true; }  
  16.   // Data layers have no bottoms, so reshaping is trivial.  
  17.   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  18.       const vector<Blob<Dtype>*>& top) {}  
  19.   
  20.   virtual inline const char* type() const { return "DummyData"; }  
  21.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  22.   virtual inline int MinTopBlobs() const { return 1; }  
  23.   
  24.  protected:  
  25.   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  26.       const vector<Blob<Dtype>*>& top);  
  27.   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  28.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  29.   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  30.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  31.   
  32.   vector<shared_ptr<Filler<Dtype> > > fillers_;  
  33.   vector<bool> refill_;  
  34. };  
接下来给出详细的定义:
首先给出FillerParameter的定义,里面指定了值的类型,值是啥,最小是啥,最大是啥,平均值、方差是啥、是否稀疏、以及将扇入个数还是扇出个数还是所有的加起来求均值作为分母
[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. message FillerParameter {  
  2.   // The filler type.  
  3.   optional string type = 1 [default = 'constant'];  
  4.   optional float value = 2 [default = 0]; // the value in constant filler  
  5.   optional float min = 3 [default = 0]; // the min value in uniform filler  
  6.   optional float max = 4 [default = 1]; // the max value in uniform filler  
  7.   optional float mean = 5 [default = 0]; // the mean value in Gaussian filler  
  8.   optional float std = 6 [default = 1]; // the std value in Gaussian filler  
  9.   // The expected number of non-zero output weights for a given input in  
  10.   // Gaussian filler -- the default -1 means don't perform sparsification.  
  11.   optional int32 sparse = 7 [default = -1];  
  12.   // Normalize the filler variance by fan_in, fan_out, or their average.  
  13.   // Applies to 'xavier' and 'msra' fillers.  
  14.   enum VarianceNorm {  
  15.     FAN_IN = 0;  
  16.     FAN_OUT = 1;  
  17.     AVERAGE = 2;  
  18.   }  
  19.   optional VarianceNorm variance_norm = 8 [default = FAN_IN];  
  20. }  
再看看该类的参数
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. </pre><pre name="code" class="plain">// DummyDataLayer fills any number of arbitrarily shaped blobs with random  
  2. // (or constant) data generated by "Fillers" (see "message FillerParameter").  
  3. message DummyDataParameter {  
  4.   // This layer produces N >= 1 top blobs.  DummyDataParameter must specify 1 or N  
  5.   // shape fields, and 0, 1 or N data_fillers.  
  6.   //  
  7.   // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used.  
  8.   // If 1 data_filler is specified, it is applied to all top blobs.  If N are  
  9.   // specified, the ith is applied to the ith top blob.  
  10.   repeated FillerParameter data_filler = 1;  
  11.   repeated BlobShape shape = 6;  
  12.   
  13.   // 4D dimensions -- deprecated.  Use "shape" instead.  
  14.   repeated uint32 num = 2;  
  15.   repeated uint32 channels = 3;  
  16.   repeated uint32 height = 4;  
  17.   repeated uint32 width = 5;  
  18. }  
接下来给出具体的实现
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #include <vector>  
  2.   
  3. #include "caffe/filler.hpp"  
  4. #include "caffe/layer.hpp"  
  5. #include "caffe/vision_layers.hpp"  
  6.   
  7. namespace caffe {  
  8.   
  9. template <typename Dtype>  
  10. void DummyDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  11.       const vector<Blob<Dtype>*>& top) {  
  12.   // 输出有几个  
  13.   const int num_top = top.size();  
  14.   // 获取该层的参数  
  15.   const DummyDataParameter& param = this->layer_param_.dummy_data_param();  
  16.   // 有几个filler  
  17.   const int num_data_filler = param.data_filler_size();  
  18.   // 检查filler的个数,要么为0、1、或者等于输出的个数  
  19.   CHECK(num_data_filler == 0 || num_data_filler == 1 ||  
  20.         num_data_filler == num_top)  
  21.       << "Number of data fillers must be 0, 1 or equal to the number of tops: "  
  22.       << num_top << "; you specified " << num_data_filler << " data fillers.";  
  23.   
  24.   // 判断是否全部为0  
  25.   const bool legacy_dims = param.num_size() || param.channels_size() ||  
  26.                            param.height_size() || param.width_size();  
  27.   // 下面就是检查参数是不是满足要求,1或者0或者等于num_top  
  28.   if (legacy_dims) {// 如果不是全部为0  
  29.     CHECK_EQ(0, param.shape_size())  
  30.         << "Both shape and legacy fields were specified";  
  31.     // Using deprecated 4D output dim specifiers.  
  32.     CHECK(param.num_size() == 1 || param.num_size() == num_top)  
  33.         << "Must specify 'num' once, or once per top blob "  
  34.         << "(" << num_top << "); specified " << param.num_size() << ".";  
  35.     CHECK(param.channels_size() == 1 || param.channels_size() == num_top)  
  36.         << "Must specify 'channels' once, or once per top blob "  
  37.         << "(" << num_top << "); specified " << param.channels_size() << ".";  
  38.     CHECK(param.height_size() == 1 || param.height_size() == num_top)  
  39.         << "Must specify 'height' once, or once per top blob "  
  40.         << "(" << num_top << "); specified " << param.height_size() << ".";  
  41.     CHECK(param.width_size() == 1 || param.width_size() == num_top)  
  42.         << "Must specify 'width' once, or once per top blob "  
  43.         << "(" << num_top << "); specified " << param.width_size() << ".";  
  44.   } else {  
  45.     CHECK(param.shape_size() == 1 || param.shape_size() == num_top)  
  46.         << "Must specify 'shape' once, or once per top blob "  
  47.         << "(" << num_top << "); specified " << param.shape_size() << ".";  
  48.   }  
  49.   // refill_[i] tells Forward i whether or not to actually refill top Blob i.  
  50.   // If refill_[i] is false, Forward does nothing for Blob i. We use this to  
  51.   // avoid wastefully refilling "constant" Blobs in every forward pass.  
  52.   // We first fill refill_ in with the INVERSE of its final values.  
  53.   // The first time we run Forward from the LayerSetUp method, we'll fill only  
  54.   // Blobs for which refill_ is normally false.  These Blobs will never be  
  55.   // filled again.  
  56.   // refill_表明是不是需要填充Blob,如果refill_[i]=false,那么就不会Blob i做任何事  
  57.   //  
  58.   refill_.clear();  
  59.   fillers_.clear();  
  60.   // 要么是0,要么是1  
  61.   if (num_data_filler <= 1) {  
  62.       // 定义了生成数据的参数  
  63.       // 比如均值、方差等,详细请看其定义  
  64.     FillerParameter filler_param;  
  65.     if (num_data_filler == 0) {  
  66.       // 如果没有指定,那么就是常数值填充  
  67.       filler_param.set_type("constant");  
  68.       filler_param.set_value(0);  
  69.     } else {  
  70.       // 否则复制filler到filler_param  
  71.       filler_param.CopyFrom(param.data_filler(0));  
  72.     }  
  73.     // Refill on each iteration iff not using a constant filler,  
  74.     // but use the inverse of this rule for the first run.  
  75.     // 如果  
  76.     refill_.resize(1);  
  77.     refill_[0] = (strcmp(filler_param.type().c_str(), "constant") == 0);  
  78.     fillers_.resize(1);  
  79.     // 实例化填充器  
  80.     fillers_[0].reset(GetFiller<Dtype>(filler_param));  
  81.   } else {// 如果等于=num_top  
  82.     refill_.resize(num_top);  
  83.     fillers_.resize(num_top);  
  84.     for (int i = 0; i < num_top; ++i) {  
  85.       fillers_[i].reset(GetFiller<Dtype>(param.data_filler(i)));  
  86.       // Refill on each iteration iff not using a constant filler,  
  87.       // but use the inverse of this rule for the first run.  
  88.       refill_[i] =  
  89.           (strcmp(param.data_filler(i).type().c_str(), "constant") == 0);  
  90.     }  
  91.   }  
  92.   
  93.   // 改变形状  
  94.   for (int i = 0; i < num_top; ++i) {  
  95.     if (legacy_dims) {  
  96.       const int num = (param.num_size() == 1) ? param.num(0) : param.num(i);  
  97.       const int channels =  
  98.           (param.channels_size() == 1) ? param.channels(0) : param.channels(i);  
  99.       const int height =  
  100.           (param.height_size() == 1) ? param.height(0) : param.height(i);  
  101.       const int width =  
  102.           (param.width_size() == 1) ? param.width(0) : param.width(i);  
  103.       top[i]->Reshape(num, channels, height, width);  
  104.     } else {  
  105.       const int shape_index = (param.shape_size() == 1) ? 0 : i;  
  106.       top[i]->Reshape(param.shape(shape_index));  
  107.     }  
  108.   }  
  109.   // Run Forward once, with refill_ inverted, to fill the constant Blobs.  
  110.   // 执行forward_cpu  
  111.   this->Forward(bottom, top);  
  112.   // Invert the inverted refill_ values to refill the desired (non-constant)  
  113.   // Blobs in every usual forward pass.  
  114.   for (int i = 0; i < refill_.size(); ++i) {  
  115.     refill_[i] = !refill_[i];  
  116.   }  
  117. }  
  118.   
  119. // Forward里调用了该函数  
  120. template <typename Dtype>  
  121. void DummyDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  122.       const vector<Blob<Dtype>*>& top) {  
  123.       // 调用fillers_来进行錐ill  
  124.   for (int i = 0; i < top.size(); ++i) {  
  125.     const int filler_id = (fillers_.size() > 1) ? i : 0;  
  126.     if (refill_[filler_id]) {  
  127.       fillers_[filler_id]->Fill(top[i]);  
  128.     }  
  129.   }  
  130. }  
  131.   
  132. // 初始化类  
  133. // 注册类  
  134. INSTANTIATE_CLASS(DummyDataLayer);  
  135. REGISTER_LAYER_CLASS(DummyData);  
  136.   
  137. }  // namespace caffe  

(5)HDF5DataLayer类的定义以及实现如下:

HDF5数据层的主要功能是从给定的HDF5文件列表读取数据,然后设置top,即向前传播的数据。

首先给出类的定义:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. template <typename Dtype>  
  2. class HDF5DataLayer : public Layer<Dtype> {  
  3.  public:  
  4.   explicit HDF5DataLayer(const LayerParameter& param)  
  5.       : Layer<Dtype>(param) {}  
  6.   virtual ~HDF5DataLayer();  
  7.   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  8.       const vector<Blob<Dtype>*>& top);  
  9.   // Data layers should be shared by multiple solvers in parallel  
  10.   virtual inline bool ShareInParallel() const { return true; }  
  11.   // Data layers have no bottoms, so reshaping is trivial.  
  12.   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  13.       const vector<Blob<Dtype>*>& top) {}  
  14.   
  15.   virtual inline const char* type() const { return "HDF5Data"; }  
  16.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  17.   virtual inline int MinTopBlobs() const { return 1; }  
  18.   
  19.  protected:  
  20.   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  21.       const vector<Blob<Dtype>*>& top);  
  22.   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,  
  23.       const vector<Blob<Dtype>*>& top);  
  24.   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  25.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  26.   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  27.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}  
  28.   // 从HDF5文件读取数据  
  29.   virtual void LoadHDF5FileData(const char* filename);  
  30.   
  31.   std::vector<std::string> hdf_filenames_;  
  32.   unsigned int num_files_;  
  33.   unsigned int current_file_;  
  34.   hsize_t current_row_;  
  35.   std::vector<shared_ptr<Blob<Dtype> > > hdf_blobs_;  
  36.   // 存放的是数据的索引,可以对索引进行shuffle  
  37.   std::vector<unsigned int> data_permutation_;  
  38.   // 存放的是文件名字的索引,可以对索引进行shuffle  
  39.   std::vector<unsigned int> file_permutation_;  
  40. };  
接下来给出类的具体实现:
给出实现之前先给出HDF5的操作
头文件:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #ifndef CAFFE_UTIL_HDF5_H_  
  2. #define CAFFE_UTIL_HDF5_H_  
  3.   
  4. #include <string>  
  5.   
  6. #include "hdf5.h"  
  7. #include "hdf5_hl.h"  
  8.   
  9. #include "caffe/blob.hpp"  
  10.   
  11. namespace caffe {  
  12.   
  13. // 获取HDF5文件的信息以及数据的维度  
  14. template <typename Dtype>  
  15. void hdf5_load_nd_dataset_helper(  
  16.     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,  
  17.     Blob<Dtype>* blob);  
  18.   
  19. // float类型的获取数据维度和信息的包裹函数  
  20. template <typename Dtype>  
  21. void hdf5_load_nd_dataset(  
  22.     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,  
  23.     Blob<Dtype>* blob);  
  24.   
  25. // double类型的获取数据维度和信息的包裹函数  
  26. template <typename Dtype>  
  27. void hdf5_save_nd_dataset(  
  28.     const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob,  
  29.     bool write_diff = false);  
  30.   
  31. // 读取int和存储int,读取字符串和存储字符串到文件  
  32. int hdf5_load_int(hid_t loc_id, const string& dataset_name);  
  33. void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i);  
  34. string hdf5_load_string(hid_t loc_id, const string& dataset_name);  
  35. void hdf5_save_string(hid_t loc_id, const string& dataset_name,  
  36.                       const string& s);  
  37.   
  38. // 获取链接数  
  39. int hdf5_get_num_links(hid_t loc_id);  
  40. // 根据名字找到索引  
  41. string hdf5_get_name_by_idx(hid_t loc_id, int idx);  
  42.   
  43. }  // namespace caffe  
  44.   
  45. #endif   // CAFFE_UTIL_HDF5_H_  
cpp文件:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #include "caffe/util/hdf5.hpp"  
  2.   
  3. #include <string>  
  4. #include <vector>  
  5.   
  6. namespace caffe {  
  7.   
  8. // Verifies format of data stored in HDF5 file and reshapes blob accordingly.  
  9. // 获取HDF5文件的信息以及数据的维度  
  10. template <typename Dtype>  
  11. void hdf5_load_nd_dataset_helper(  
  12.     hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,  
  13.     Blob<Dtype>* blob) {  
  14.   // Verify that the dataset exists.  
  15.   // 检查是否存在  
  16.   CHECK(H5LTfind_dataset(file_id, dataset_name_))  
  17.       << "Failed to find HDF5 dataset " << dataset_name_;  
  18.   // Verify that the number of dimensions is in the accepted range.  
  19.   herr_t status;  
  20.   int ndims;  
  21.   // 获取数据维度  
  22.   status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);  
  23.   CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;  
  24.   CHECK_GE(ndims, min_dim);  
  25.   CHECK_LE(ndims, max_dim);  
  26.   
  27.   // Verify that the data format is what we expect: float or double.  
  28.   std::vector<hsize_t> dims(ndims);  
  29.   H5T_class_t class_;  
  30.   // 获取数据信息  
  31.   status = H5LTget_dataset_info(  
  32.       file_id, dataset_name_, dims.data(), &class_, NULL);  
  33.   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;  
  34.   switch (class_) {  
  35.   case H5T_FLOAT:  
  36.     LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_FLOAT";  
  37.     break;  
  38.   case H5T_INTEGER:  
  39.     LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_INTEGER";  
  40.     break;  
  41.   case H5T_TIME:  
  42.     LOG(FATAL) << "Unsupported datatype class: H5T_TIME";  
  43.   case H5T_STRING:  
  44.     LOG(FATAL) << "Unsupported datatype class: H5T_STRING";  
  45.   case H5T_BITFIELD:  
  46.     LOG(FATAL) << "Unsupported datatype class: H5T_BITFIELD";  
  47.   case H5T_OPAQUE:  
  48.     LOG(FATAL) << "Unsupported datatype class: H5T_OPAQUE";  
  49.   case H5T_COMPOUND:  
  50.     LOG(FATAL) << "Unsupported datatype class: H5T_COMPOUND";  
  51.   case H5T_REFERENCE:  
  52.     LOG(FATAL) << "Unsupported datatype class: H5T_REFERENCE";  
  53.   case H5T_ENUM:  
  54.     LOG(FATAL) << "Unsupported datatype class: H5T_ENUM";  
  55.   case H5T_VLEN:  
  56.     LOG(FATAL) << "Unsupported datatype class: H5T_VLEN";  
  57.   case H5T_ARRAY:  
  58.     LOG(FATAL) << "Unsupported datatype class: H5T_ARRAY";  
  59.   default:  
  60.     LOG(FATAL) << "Datatype class unknown";  
  61.   }  
  62.   
  63.   // 设置blob的维度  
  64.   vector<int> blob_dims(dims.size());  
  65.   for (int i = 0; i < dims.size(); ++i) {  
  66.     blob_dims[i] = dims[i];  
  67.   }  
  68.   blob->Reshape(blob_dims);  
  69. }  
  70.   
  71. // float类型的获取数据维度和信息的包裹函数  
  72. template <>  
  73. void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,  
  74.         int min_dim, int max_dim, Blob<float>* blob) {  
  75.   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);  
  76.   herr_t status = H5LTread_dataset_float(  
  77.     file_id, dataset_name_, blob->mutable_cpu_data());  
  78.   CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;  
  79. }  
  80.   
  81. // double类型的获取数据维度和信息的包裹函数  
  82. template <>  
  83. void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,  
  84.         int min_dim, int max_dim, Blob<double>* blob) {  
  85.   hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);  
  86.   herr_t status = H5LTread_dataset_double(  
  87.     file_id, dataset_name_, blob->mutable_cpu_data());  
  88.   CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;  
  89. }  
  90.   
  91.   
  92. // 存放float类型到hdf5文件  
  93. template <>  
  94. void hdf5_save_nd_dataset<float>(  
  95.     const hid_t file_id, const string& dataset_name, const Blob<float>& blob,  
  96.     bool write_diff) {  
  97.   // blob信息放到dims  
  98.   int num_axes = blob.num_axes();  
  99.   hsize_t *dims = new hsize_t[num_axes];  
  100.   for (int i = 0; i < num_axes; ++i) {  
  101.     dims[i] = blob.shape(i);  
  102.   }  
  103.   
  104.   // 获取数据指针  
  105.   const float* data;  
  106.   if (write_diff) {  
  107.     data = blob.cpu_diff();  
  108.   } else {  
  109.     data = blob.cpu_data();  
  110.   }  
  111.   
  112.   // 存放数据到hdf5  
  113.   herr_t status = H5LTmake_dataset_float(  
  114.       file_id, dataset_name.c_str(), num_axes, dims, data);  
  115.   CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;  
  116.   delete[] dims;  
  117. }  
  118.   
  119. // 存放double类型到hdf5文件  
  120. template <>  
  121. void hdf5_save_nd_dataset<double>(  
  122.     hid_t file_id, const string& dataset_name, const Blob<double>& blob,  
  123.     bool write_diff) {  
  124.   int num_axes = blob.num_axes();  
  125.   hsize_t *dims = new hsize_t[num_axes];  
  126.   for (int i = 0; i < num_axes; ++i) {  
  127.     dims[i] = blob.shape(i);  
  128.   }  
  129.   const double* data;  
  130.   if (write_diff) {  
  131.     data = blob.cpu_diff();  
  132.   } else {  
  133.     data = blob.cpu_data();  
  134.   }  
  135.   herr_t status = H5LTmake_dataset_double(  
  136.       file_id, dataset_name.c_str(), num_axes, dims, data);  
  137.   CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;  
  138.   delete[] dims;  
  139. }  
  140.   
  141. // 读取string到字符串  
  142. string hdf5_load_string(hid_t loc_id, const string& dataset_name) {  
  143.   // Get size of dataset  
  144.   size_t size;  
  145.   H5T_class_t class_;  
  146.   herr_t status = \  
  147.     H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);  
  148.   CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;  
  149.   char *buf = new char[size];  
  150.   status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);  
  151.   CHECK_GE(status, 0)  
  152.     << "Failed to load int dataset with name " << dataset_name;  
  153.   string val(buf);  
  154.   delete[] buf;  
  155.   return val;  
  156. }  
  157.   
  158. // 保存string到字符串  
  159. void hdf5_save_string(hid_t loc_id, const string& dataset_name,  
  160.                       const string& s) {  
  161.   herr_t status = \  
  162.     H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());  
  163.   CHECK_GE(status, 0)  
  164.     << "Failed to save string dataset with name " << dataset_name;  
  165. }  
  166.   
  167. // 载入int类型  
  168. int hdf5_load_int(hid_t loc_id, const string& dataset_name) {  
  169.   int val;  
  170.   herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);  
  171.   CHECK_GE(status, 0)  
  172.     << "Failed to load int dataset with name " << dataset_name;  
  173.   return val;  
  174. }  
  175.   
  176. // 存储int类型  
  177. void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {  
  178.   hsize_t one = 1;  
  179.   herr_t status = \  
  180.     H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);  
  181.   CHECK_GE(status, 0)  
  182.     << "Failed to save int dataset with name " << dataset_name;  
  183. }  
  184.   
  185. // 获取链接数  
  186. int hdf5_get_num_links(hid_t loc_id) {  
  187.   H5G_info_t info;  
  188.   herr_t status = H5Gget_info(loc_id, &info);  
  189.   CHECK_GE(status, 0) << "Error while counting HDF5 links.";  
  190.   return info.nlinks;  
  191. }  
  192.   
  193. // 通过名字找到索引  
  194. string hdf5_get_name_by_idx(hid_t loc_id, int idx) {  
  195.   ssize_t str_size = H5Lget_name_by_idx(  
  196.       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);  
  197.   CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;  
  198.   char *c_str = new char[str_size+1];  
  199.   ssize_t status = H5Lget_name_by_idx(  
  200.       loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,  
  201.       H5P_DEFAULT);  
  202.   CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;  
  203.   string result(c_str);  
  204.   delete[] c_str;  
  205.   return result;  
  206. }  
  207.   
  208. }  // namespace caffe  
  209.   
  210. 给出具体实现:  
  211. /* 
  212. TODO: 
  213. - load file in a separate thread ("prefetch") 
  214. - can be smarter about the memcpy call instead of doing it row-by-row 
  215.   :: use util functions caffe_copy, and Blob->offset() 
  216.   :: don't forget to update hdf5_daa_layer.cu accordingly 
  217. - add ability to shuffle filenames if flag is set 
  218. */  
  219. #include <fstream>  // NOLINT(readability/streams)  
  220. #include <string>  
  221. #include <vector>  
  222.   
  223. #include "hdf5.h"  
  224. #include "hdf5_hl.h"  
  225. #include "stdint.h"  
  226.   
  227. #include "caffe/data_layers.hpp"  
  228. #include "caffe/layer.hpp"  
  229. #include "caffe/util/hdf5.hpp"  
  230.   
  231. namespace caffe {  
  232.   
  233. template <typename Dtype>  
  234. HDF5DataLayer<Dtype>::~HDF5DataLayer<Dtype>() { }  
  235.   
  236. // Load data and label from HDF5 filename into the class property blobs.  
  237. // 读取HDF5文件数据到hdf_blobs  
  238. template <typename Dtype>  
  239. void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {  
  240.   DLOG(INFO) << "Loading HDF5 file: " << filename;  
  241.   // 打开文件  
  242.   hid_t file_id = H5Fopen(filename, H5F_ACC_RDONLY, H5P_DEFAULT);  
  243.   if (file_id < 0) {  
  244.     LOG(FATAL) << "Failed opening HDF5 file: " << filename;  
  245.   }  
  246.   
  247.   int top_size = this->layer_param_.top_size();  
  248.   hdf_blobs_.resize(top_size);  
  249.   
  250.   const int MIN_DATA_DIM = 1;  
  251.   const int MAX_DATA_DIM = INT_MAX;  
  252.   
  253.   for (int i = 0; i < top_size; ++i) {  
  254.     hdf_blobs_[i] = shared_ptr<Blob<Dtype> >(new Blob<Dtype>());  
  255.     // message LayerParameter {  
  256.     // optional string name = 1; // the layer name  
  257.     // optional string type = 2; // the layer type  
  258.     // repeated string bottom = 3; // the name of each bottom blob  
  259.     // repeated string top = 4; // the name of each top blob  
  260.     hdf5_load_nd_dataset(file_id, this->layer_param_.top(i).c_str(),  
  261.         MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get());  
  262.   }  
  263.   
  264.   herr_t status = H5Fclose(file_id);  
  265.   CHECK_GE(status, 0) << "Failed to close HDF5 file: " << filename;  
  266.   
  267.   // MinTopBlobs==1 guarantees at least one top blob  
  268.   CHECK_GE(hdf_blobs_[0]->num_axes(), 1) << "Input must have at least 1 axis.";  
  269.   const int num = hdf_blobs_[0]->shape(0);  
  270.   for (int i = 1; i < top_size; ++i) {  
  271.     CHECK_EQ(hdf_blobs_[i]->shape(0), num);  
  272.   }  
  273.   // Default to identity permutation.  
  274.   data_permutation_.clear();  
  275.   data_permutation_.resize(hdf_blobs_[0]->shape(0));  
  276.   for (int i = 0; i < hdf_blobs_[0]->shape(0); i++)  
  277.     data_permutation_[i] = i;  
  278.   
  279.   // Shuffle if needed.  
  280.   // 将数据索引映射表进行shuffle  
  281.   if (this->layer_param_.hdf5_data_param().shuffle()) {  
  282.     std::random_shuffle(data_permutation_.begin(), data_permutation_.end());  
  283.     DLOG(INFO) << "Successully loaded " << hdf_blobs_[0]->shape(0)  
  284.                << " rows (shuffled)";  
  285.   } else {  
  286.     DLOG(INFO) << "Successully loaded " << hdf_blobs_[0]->shape(0) << " rows";  
  287.   }  
  288. }  
  289.   
  290. // 主要的功能就是读取HDF5文件,并且设置top blob的形状  
  291. template <typename Dtype>  
  292. void HDF5DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  293.       const vector<Blob<Dtype>*>& top) {  
  294.   // Refuse transformation parameters since HDF5 is totally generic.  
  295.   CHECK(!this->layer_param_.has_transform_param()) <<  
  296.       this->type() << " does not transform data.";  
  297.   // Read the source to parse the filenames.  
  298.   // 读取HDF列表文件  
  299.   const string& source = this->layer_param_.hdf5_data_param().source();  
  300.   LOG(INFO) << "Loading list of HDF5 filenames from: " << source;  
  301.   hdf_filenames_.clear();  
  302.   std::ifstream source_file(source.c_str());  
  303.   if (source_file.is_open()) {  
  304.     std::string line;  
  305.     while (source_file >> line) {  
  306.       hdf_filenames_.push_back(line);  
  307.     }  
  308.   } else {  
  309.     LOG(FATAL) << "Failed to open source file: " << source;  
  310.   }  
  311.   source_file.close();  
  312.   num_files_ = hdf_filenames_.size();  
  313.   current_file_ = 0;  
  314.   LOG(INFO) << "Number of HDF5 files: " << num_files_;  
  315.   CHECK_GE(num_files_, 1) << "Must have at least 1 HDF5 filename listed in "  
  316.     << source;  
  317.   
  318.   file_permutation_.clear();  
  319.   file_permutation_.resize(num_files_);  
  320.   // 文件名字是否shuffle  
  321.   // Default to identity permutation.  
  322.   for (int i = 0; i < num_files_; i++) {  
  323.     file_permutation_[i] = i;  
  324.   }  
  325.   
  326.   // Shuffle if needed.  
  327.   if (this->layer_param_.hdf5_data_param().shuffle()) {  
  328.     std::random_shuffle(file_permutation_.begin(), file_permutation_.end());  
  329.   }  
  330.   
  331.   // Load the first HDF5 file and initialize the line counter.  
  332.   // 从给定的文件名列表中的第一个文件名读取数据到hdf_blobs  
  333.   LoadHDF5FileData(hdf_filenames_[file_permutation_[current_file_]].c_str());  
  334.   // 设置行指针  
  335.   current_row_ = 0;  
  336.   
  337.   // Reshape blobs.  
  338.   // 根据读取的hdf_blobs形状改变top的形状  
  339.   const int batch_size = this->layer_param_.hdf5_data_param().batch_size();  
  340.   const int top_size = this->layer_param_.top_size();  
  341.   vector<int> top_shape;  
  342.   for (int i = 0; i < top_size; ++i) {  
  343.     top_shape.resize(hdf_blobs_[i]->num_axes());  
  344.     top_shape[0] = batch_size;  
  345.     for (int j = 1; j < top_shape.size(); ++j) {  
  346.       top_shape[j] = hdf_blobs_[i]->shape(j);  
  347.     }  
  348.     top[i]->Reshape(top_shape);  
  349.   }  
  350. }  
  351.   
  352. template <typename Dtype>  
  353. void HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  354.       const vector<Blob<Dtype>*>& top) {  
  355.   const int batch_size = this->layer_param_.hdf5_data_param().batch_size();  
  356.   for (int i = 0; i < batch_size; ++i, ++current_row_) {  
  357.       // 因为SetUp里面已经读取了第一个文件的数据了  
  358.     if (current_row_ == hdf_blobs_[0]->shape(0)) {  
  359.       if (num_files_ > 1) {// 如果文件数目大于1  
  360.         ++current_file_;  
  361.         // 如果current_file是最后一个文件的索引编号则  
  362.         if (current_file_ == num_files_) {  
  363.           current_file_ = 0;// 重置  
  364.           // 打乱文件索引,再来一遍  
  365.           if (this->layer_param_.hdf5_data_param().shuffle()) {  
  366.             std::random_shuffle(file_permutation_.begin(),  
  367.                                 file_permutation_.end());  
  368.           }  
  369.           DLOG(INFO) << "Looping around to first file.";  
  370.         }  
  371.         // 读取数据到hdf_blobs  
  372.         LoadHDF5FileData(  
  373.             hdf_filenames_[file_permutation_[current_file_]].c_str());  
  374.       }// end of if (current_row_  
  375.       current_row_ = 0;  
  376.       // 打乱数据顺序索引  
  377.       if (this->layer_param_.hdf5_data_param().shuffle())  
  378.         std::random_shuffle(data_permutation_.begin(), data_permutation_.end());  
  379.     }  
  380.     // 复制数据到top  
  381.     for (int j = 0; j < this->layer_param_.top_size(); ++j) {  
  382.       int data_dim = top[j]->count() / top[j]->shape(0);  
  383.       caffe_copy(data_dim,  
  384.           &hdf_blobs_[j]->cpu_data()[data_permutation_[current_row_]  
  385.             * data_dim], &top[j]->mutable_cpu_data()[i * data_dim]);  
  386.     }  
  387.   }  
  388. }  
  389.   
  390. #ifdef CPU_ONLY  
  391. STUB_GPU_FORWARD(HDF5DataLayer, Forward);  
  392. #endif  
  393.   
  394. INSTANTIATE_CLASS(HDF5DataLayer);  
  395. REGISTER_LAYER_CLASS(HDF5Data);  
  396.   
  397. }  // namespace caffe  

(6)HDF5OutputLayer类的定义以及实现如下:

HDF5输出层主要就是将传递过来的数据存储到HDF5文件,并没有向前传播数据啥的,也没有反传,仅仅是将前一层传输过来的bottom存储到文件。
HDF5输出层的定义:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. /** 
  2.  * @brief Write blobs to disk as HDF5 files. 
  3.  * 
  4.  * TODO(dox): thorough documentation for Forward and proto params. 
  5.  * 将数据写入到HDF5文件 
  6.  */  
  7. template <typename Dtype>  
  8. class HDF5OutputLayer : public Layer<Dtype> {  
  9.  public:  
  10.   explicit HDF5OutputLayer(const LayerParameter& param)  
  11.       : Layer<Dtype>(param), file_opened_(false) {}  
  12.   virtual ~HDF5OutputLayer();  
  13.   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  14.       const vector<Blob<Dtype>*>& top);  
  15.   // Data layers should be shared by multiple solvers in parallel  
  16.   virtual inline bool ShareInParallel() const { return true; }  
  17.   // Data layers have no bottoms, so reshaping is trivial.  
  18.   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  19.       const vector<Blob<Dtype>*>& top) {}  
  20.   
  21.   virtual inline const char* type() const { return "HDF5Output"; }  
  22.   // TODO: no limit on the number of blobs  
  23.   virtual inline int ExactNumBottomBlobs() const { return 2; }  
  24.   virtual inline int ExactNumTopBlobs() const { return 0; }  
  25.   
  26.   inline std::string file_name() const { return file_name_; }  
  27.   
  28.  protected:  
  29.   // HDF5输出层不前向传也不反向传,只是将前一层传递过来的数据写入HDF5文件  
  30.   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  31.       const vector<Blob<Dtype>*>& top);  
  32.   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,  
  33.       const vector<Blob<Dtype>*>& top);  
  34.   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  35.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);  
  36.   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  37.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);  
  38.   // 将bottom的数据存储到文件  
  39.   virtual void SaveBlobs();  
  40.   
  41.   bool file_opened_;  
  42.   std::string file_name_;  
  43.   hid_t file_id_;  
  44.   Blob<Dtype> data_blob_;  
  45.   Blob<Dtype> label_blob_;  
  46. };
HDF5输出层的实现如下:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #include <vector>  
  2.   
  3. #include "hdf5.h"  
  4. #include "hdf5_hl.h"  
  5.   
  6. #include "caffe/blob.hpp"  
  7. #include "caffe/common.hpp"  
  8. #include "caffe/layer.hpp"  
  9. #include "caffe/util/hdf5.hpp"  
  10. #include "caffe/vision_layers.hpp"  
  11.   
  12. namespace caffe {  
  13.   
  14. template <typename Dtype>  
  15. void HDF5OutputLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  16.     const vector<Blob<Dtype>*>& top) {  
  17.   // 参数文件中的文件名  
  18.   file_name_ = this->layer_param_.hdf5_output_param().file_name();  
  19.   // 打开文件  
  20.   file_id_ = H5Fcreate(file_name_.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,  
  21.                        H5P_DEFAULT);  
  22.   CHECK_GE(file_id_, 0) << "Failed to open HDF5 file" << file_name_;  
  23.   file_opened_ = true;// 设置文件打开标志  
  24. }  
  25.   
  26. template <typename Dtype>  
  27. HDF5OutputLayer<Dtype>::~HDF5OutputLayer<Dtype>() {  
  28.   if (file_opened_) {  
  29.     herr_t status = H5Fclose(file_id_);  
  30.     CHECK_GE(status, 0) << "Failed to close HDF5 file " << file_name_;  
  31.   }  
  32. }  
  33.   
  34. // 将blob存放到hdf5文件  
  35. // 数据和类标  
  36. template <typename Dtype>  
  37. void HDF5OutputLayer<Dtype>::SaveBlobs() {  
  38.   // TODO: no limit on the number of blobs  
  39.   LOG(INFO) << "Saving HDF5 file " << file_name_;  
  40.   CHECK_EQ(data_blob_.num(), label_blob_.num()) <<  
  41.       "data blob and label blob must have the same batch size";  
  42.   hdf5_save_nd_dataset(file_id_, HDF5_DATA_DATASET_NAME, data_blob_);  
  43.   hdf5_save_nd_dataset(file_id_, HDF5_DATA_LABEL_NAME, label_blob_);  
  44.   LOG(INFO) << "Successfully saved " << data_blob_.num() << " rows";  
  45. }  
  46.   
  47. // 实际上就是从bottom将输入过来的数据存放到hdf5文件  
  48. template <typename Dtype>  
  49. void HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  50.       const vector<Blob<Dtype>*>& top) {  
  51.   CHECK_GE(bottom.size(), 2);  
  52.   CHECK_EQ(bottom[0]->num(), bottom[1]->num());  
  53.   // 改变data_blob_的形状以及label_blob_的形状  
  54.   data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),  
  55.                      bottom[0]->height(), bottom[0]->width());  
  56.   label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),  
  57.                      bottom[1]->height(), bottom[1]->width());  
  58.   const int data_datum_dim = bottom[0]->count() / bottom[0]->num();  
  59.   const int label_datum_dim = bottom[1]->count() / bottom[1]->num();  
  60.   
  61.   // 从bottom[0]和[1]复制到data_blob_和label_blob_  
  62.   for (int i = 0; i < bottom[0]->num(); ++i) {  
  63.     caffe_copy(data_datum_dim, &bottom[0]->cpu_data()[i * data_datum_dim],  
  64.         &data_blob_.mutable_cpu_data()[i * data_datum_dim]);  
  65.     caffe_copy(label_datum_dim, &bottom[1]->cpu_data()[i * label_datum_dim],  
  66.         &label_blob_.mutable_cpu_data()[i * label_datum_dim]);  
  67.   }  
  68.   // 存放到文件  
  69.   SaveBlobs();  
  70. }  
  71.   
  72. // 不反传  
  73. template <typename Dtype>  
  74. void HDF5OutputLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,  
  75.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  76.   return;  
  77. }  
  78.   
  79. #ifdef CPU_ONLY  
  80. STUB_GPU(HDF5OutputLayer);  
  81. #endif  
  82.   
  83. INSTANTIATE_CLASS(HDF5OutputLayer);  
  84. REGISTER_LAYER_CLASS(HDF5Output);  
  85.   
  86. }  // namespace caffe  

(7)ImageDataLayer类的定义以及实现如下:

该层主要的功能是,从参数中给定的列表文件读取图像列表以及类标,读取图像的时候会进行预处理,然后前向传。
首先给出该层的参数的定义:
[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. message ImageDataParameter {  
  2.   // Specify the data source.  
  3.   // 列表文件包含图像的路径和对应的类标,以空格隔开  
  4.   optional string source = 1;  
  5.   // Specify the batch size.  
  6.   // 批大小  
  7.   optional uint32 batch_size = 4 [default = 1];  
  8.   // The rand_skip variable is for the data layer to skip a few data points  
  9.   // to avoid all asynchronous sgd clients to start at the same point. The skip  
  10.   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not  
  11.   // be larger than the number of keys in the database.  
  12.   // 随机调过一些数据  
  13.   optional uint32 rand_skip = 7 [default = 0];  
  14.   // 是否需要打乱数据顺序  
  15.   // Whether or not ImageLayer should shuffle the list of files at every epoch.  
  16.   optional bool shuffle = 8 [default = false];  
  17.   // It will also resize images if new_height or new_width are not zero.  
  18.   // 将图像resize到新的高度的宽度  
  19.   optional uint32 new_height = 9 [default = 0];  
  20.   optional uint32 new_width = 10 [default = 0];  
  21.   // Specify if the images are color or gray  
  22.   // 图像是否是彩色的  
  23.   optional bool is_color = 11 [default = true];  
  24.   // DEPRECATED. See TransformationParameter. For data pre-processing, we can do  
  25.   // simple scaling and subtracting the data mean, if provided. Note that the  
  26.   // mean subtraction is always carried out before scaling.  
  27.   // 是否需要对图像进行缩放  
  28.   optional float scale = 2 [default = 1];  
  29.   // 均值文件  
  30.   optional string mean_file = 3;  
  31.   // DEPRECATED. See TransformationParameter. Specify if we would like to randomly  
  32.   // crop an image.  
  33.   // crop的大小  
  34.   optional uint32 crop_size = 5 [default = 0];  
  35.   // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror  
  36.   // data.  
  37.   // 是否需要对图像进行镜像,所谓镜像就是左边复制到右边  
  38.   optional bool mirror = 6 [default = false];  
  39.   // 图像的根目录  
  40.   optional string root_folder = 12 [default = ""];  
  41. }  
首先给出类的定义:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. /** 
  2.  * @brief Provides data to the Net from image files. 
  3.  * 
  4.  * TODO(dox): thorough documentation for Forward and proto params. 
  5.  * 从图像文件中读取数据,这个应该比较常用 
  6.  * 从一个列表文件读取图像的路径和类标,列表文件的路径在层参数的配置文件中指定 
  7.  */  
  8. template <typename Dtype>  
  9. class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {  
  10.  public:  
  11.   explicit ImageDataLayer(const LayerParameter& param)  
  12.       : BasePrefetchingDataLayer<Dtype>(param) {}  
  13.   virtual ~ImageDataLayer();  
  14.   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  15.       const vector<Blob<Dtype>*>& top);  
  16.   
  17.   virtual inline const char* type() const { return "ImageData"; }  
  18.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  19.   virtual inline int ExactNumTopBlobs() const { return 2; }  
  20.   
  21.  protected:  
  22.   shared_ptr<Caffe::RNG> prefetch_rng_;  
  23.   // 对图像索引进行打乱  
  24.   virtual void ShuffleImages();  
  25.   virtual void load_batch(Batch<Dtype>* batch);  
  26.   
  27.   // 图像路径和类标的vector  
  28.   vector<std::pair<std::string, int> > lines_;  
  29.   // 随机跳过的图像的个数,也就是调过之后的一开始的图像的id  
  30.   int lines_id_;  
  31. };  
下面给出具体的实现细节:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3.   
  4. #include <fstream>  // NOLINT(readability/streams)  
  5. #include <iostream>  // NOLINT(readability/streams)  
  6. #include <string>  
  7. #include <utility>  
  8. #include <vector>  
  9.   
  10. #include "caffe/data_layers.hpp"  
  11. #include "caffe/layer.hpp"  
  12. #include "caffe/util/benchmark.hpp"  
  13. #include "caffe/util/io.hpp"  
  14. #include "caffe/util/math_functions.hpp"  
  15. #include "caffe/util/rng.hpp"  
  16.   
  17. namespace caffe {  
  18.   
  19. template <typename Dtype>  
  20. ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() {  
  21.   this->StopInternalThread();  
  22. }  
  23.   
  24. template <typename Dtype>  
  25. void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  26.       const vector<Blob<Dtype>*>& top) {  
  27.   // 根据参数文件设置参数  
  28.   // 图像的高度、宽度、是否彩色图像、图像目录  
  29.   const int new_height = this->layer_param_.image_data_param().new_height();  
  30.   const int new_width  = this->layer_param_.image_data_param().new_width();  
  31.   const bool is_color  = this->layer_param_.image_data_param().is_color();  
  32.   string root_folder = this->layer_param_.image_data_param().root_folder();  
  33.   
  34.   // 当前只支持读取高度和宽度同样大小的图像  
  35.   CHECK((new_height == 0 && new_width == 0) ||  
  36.       (new_height > 0 && new_width > 0)) << "Current implementation requires "  
  37.       "new_height and new_width to be set at the same time.";  
  38.   
  39.   // Read the file with filenames and labels  
  40.   // 读取存放图像文件名和类标的列表文件  
  41.   const string& source = this->layer_param_.image_data_param().source();  
  42.   LOG(INFO) << "Opening file " << source;  
  43.   std::ifstream infile(source.c_str());  
  44.   string filename;  
  45.   int label;  
  46.   // lines_存放文件名和类标的pair  
  47.   while (infile >> filename >> label) {  
  48.     lines_.push_back(std::make_pair(filename, label));  
  49.   }  
  50.   
  51.   // 是否需要打乱文件的顺序  
  52.   if (this->layer_param_.image_data_param().shuffle()) {  
  53.     // randomly shuffle data  
  54.     LOG(INFO) << "Shuffling data";  
  55.     const unsigned int prefetch_rng_seed = caffe_rng_rand();  
  56.     prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));  
  57.     ShuffleImages();  
  58.   }  
  59.   LOG(INFO) << "A total of " << lines_.size() << " images.";  
  60.   
  61.   // 随机跳过的图像,调过的图像个数在[0, rand_skip-1]之间  
  62.   lines_id_ = 0;  
  63.   // Check if we would need to randomly skip a few data points  
  64.   // 如果参数中的rand_skip大于1,则随机跳过[0,rand_skip-1]个图片  
  65.   //  
  66.   if (this->layer_param_.image_data_param().rand_skip()) {  
  67.     unsigned int skip = caffe_rng_rand() %  
  68.         this->layer_param_.image_data_param().rand_skip();  
  69.     LOG(INFO) << "Skipping first " << skip << " data points.";  
  70.     CHECK_GT(lines_.size(), skip) << "Not enough points to skip";  
  71.     lines_id_ = skip;  
  72.   }  
  73.   // Read an image, and use it to initialize the top blob.  
  74.   // 读取文件名到Mat  
  75.   cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,  
  76.                                     new_height, new_width, is_color);  
  77.   CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;  
  78.   // Use data_transformer to infer the expected blob shape from a cv_image.  
  79.   // 对数据的形状进行推断  
  80.   vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  
  81.   // 设置transformed_data_的形状  
  82.   this->transformed_data_.Reshape(top_shape);  
  83.   // Reshape prefetch_data and top[0] according to the batch_size.  
  84.   // 设置batch_size  
  85.   const int batch_size = this->layer_param_.image_data_param().batch_size();  
  86.   CHECK_GT(batch_size, 0) << "Positive batch size required";  
  87.   top_shape[0] = batch_size;  
  88.   // 设置预取数组中数据的形状  
  89.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  90.     this->prefetch_[i].data_.Reshape(top_shape);  
  91.   }  
  92.   // 设置输出的数据的形状  
  93.   top[0]->Reshape(top_shape);  
  94.   
  95.   LOG(INFO) << "output data size: " << top[0]->num() << ","  
  96.       << top[0]->channels() << "," << top[0]->height() << ","  
  97.       << top[0]->width();  
  98.   // label  
  99.   // 设置输出的类标的形状  
  100.   vector<int> label_shape(1, batch_size);  
  101.   top[1]->Reshape(label_shape);  
  102.   // 设置预取数组中类标的形状  
  103.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  104.     this->prefetch_[i].label_.Reshape(label_shape);  
  105.   }  
  106. }  
  107.   
  108. // 产生打乱图像顺序的数组  
  109. template <typename Dtype>  
  110. void ImageDataLayer<Dtype>::ShuffleImages() {  
  111.   caffe::rng_t* prefetch_rng =  
  112.       static_cast<caffe::rng_t*>(prefetch_rng_->generator());  
  113.   shuffle(lines_.begin(), lines_.end(), prefetch_rng);  
  114. }  
  115.   
  116. // This function is called on prefetch thread  
  117. // 该函数会被内部的线程调用  
  118. template <typename Dtype>  
  119. void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  120.   CPUTimer batch_timer;  
  121.   batch_timer.Start();  
  122.   double read_time = 0;  
  123.   double trans_time = 0;  
  124.   CPUTimer timer;  
  125.   CHECK(batch->data_.count());  
  126.   CHECK(this->transformed_data_.count());  
  127.   // 获取层参数,具体参见层参数的定义的解释  
  128.   ImageDataParameter image_data_param = this->layer_param_.image_data_param();  
  129.   const int batch_size = image_data_param.batch_size();  
  130.   const int new_height = image_data_param.new_height();  
  131.   const int new_width = image_data_param.new_width();  
  132.   const bool is_color = image_data_param.is_color();  
  133.   string root_folder = image_data_param.root_folder();  
  134.   
  135.   // Reshape according to the first image of each batch  
  136.   // on single input batches allows for inputs of varying dimension.  
  137.   // 读取跳过之后的第一幅图像,然后根据该图像设置相撞  
  138.   cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,  
  139.       new_height, new_width, is_color);  
  140.   CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;  
  141.   // Use data_transformer to infer the expected blob shape from a cv_img.  
  142.   // 推断图像形状  
  143.   vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  
  144.   // 设置transformed_data_形状  
  145.   this->transformed_data_.Reshape(top_shape);  
  146.   // Reshape batch according to the batch_size.  
  147.   // 设置batch_size  
  148.   top_shape[0] = batch_size;  
  149.   batch->data_.Reshape(top_shape);  
  150.   
  151.   Dtype* prefetch_data = batch->data_.mutable_cpu_data();  
  152.   Dtype* prefetch_label = batch->label_.mutable_cpu_data();  
  153.   
  154.   // datum scales  
  155.   // 读取一批图像,并进行预处理  
  156.   const int lines_size = lines_.size();  
  157.   for (int item_id = 0; item_id < batch_size; ++item_id) {  
  158.     // get a blob  
  159.     timer.Start();  
  160.     CHECK_GT(lines_size, lines_id_);  
  161.     cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,  
  162.         new_height, new_width, is_color);  
  163.     CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;  
  164.     read_time += timer.MicroSeconds();  
  165.     timer.Start();  
  166.     // Apply transformations (mirror, crop...) to the image  
  167.     // 进行预处理  
  168.   
  169.     // 根据图像的批次获得图像数据的偏移量  
  170.     int offset = batch->data_.offset(item_id);  
  171.     // 设置图像数据的指针到transformed_data_  
  172.     this->transformed_data_.set_cpu_data(prefetch_data + offset);  
  173.     // 进行预处理  
  174.     this->data_transformer_->Transform(cv_img, &(this->transformed_data_));  
  175.     trans_time += timer.MicroSeconds();//统计预处理时间  
  176.   
  177.     // 复制类标到prefetch_label  
  178.     prefetch_label[item_id] = lines_[lines_id_].second;  
  179.     // go to the next iter  
  180.     lines_id_++;  
  181.     // 是否是图像目录中的最后一个图像  
  182.     if (lines_id_ >= lines_size) {  
  183.       // We have reached the end. Restart from the first.  
  184.       DLOG(INFO) << "Restarting data prefetching from start.";  
  185.       lines_id_ = 0;  
  186.       // 打乱图像索引的顺序  
  187.       if (this->layer_param_.image_data_param().shuffle()) {  
  188.         ShuffleImages();  
  189.       }  
  190.     }  
  191.   }  
  192.   batch_timer.Stop();  
  193.   DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  
  194.   DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";  
  195.   // 预处理时间  
  196.   DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";  
  197. }  
  198.   
  199. INSTANTIATE_CLASS(ImageDataLayer);  
  200. REGISTER_LAYER_CLASS(ImageData);  
  201.   
  202. }  // namespace caffe  
  203. #endif  // USE_OPENCV  

(8)MemoryDataLayer 类的定义以及实现如下:

该类主要就是对于读取好的Datum或者OpenCV读取的Mat的Vector进行预处理(图像的crop、scale等),然后前传。
首先给出该类的定义
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. /** 
  2.  * @brief Provides data to the Net from memory. 
  3.  * 从内存中读取数据,这里指已经从数据文件或者图像文件中读取到了数据,然后输入到该层 
  4.  * TODO(dox): thorough documentation for Forward and proto params. 
  5.  */  
  6. template <typename Dtype>  
  7. class MemoryDataLayer : public BaseDataLayer<Dtype> {  
  8.  public:  
  9.   explicit MemoryDataLayer(const LayerParameter& param)  
  10.       : BaseDataLayer<Dtype>(param), has_new_data_(false) {}  
  11.   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  12.       const vector<Blob<Dtype>*>& top);  
  13.   
  14.   virtual inline const char* type() const { return "MemoryData"; }  
  15.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  16.   virtual inline int ExactNumTopBlobs() const { return 2; }  
  17.   
  18.   // 将内存中的数据加入added_data_和added_label_(数据和类标)  
  19.   virtual void AddDatumVector(const vector<Datum>& datum_vector);  
  20. #ifdef USE_OPENCV  
  21.   // 如果有opencv则将opencv读取到的Mat,并且将labels加入added_data_和added_label_(数据和类标)  
  22.   virtual void AddMatVector(const vector<cv::Mat>& mat_vector,  
  23.       const vector<int>& labels);  
  24. #endif  // USE_OPENCV  
  25.   
  26.   // Reset should accept const pointers, but can't, because the memory  
  27.   //  will be given to Blob, which is mutable  
  28.   // Reset函数实际上是将data、label、以及batchsize(n)设置到内部的变量里面去  
  29.   void Reset(Dtype* data, Dtype* label, int n);  
  30.   void set_batch_size(int new_size);  
  31.   
  32.   int batch_size() { return batch_size_; }  
  33.   int channels() { return channels_; }  
  34.   int height() { return height_; }  
  35.   int width() { return width_; }  
  36.   
  37.  protected:  
  38.   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  39.       const vector<Blob<Dtype>*>& top);  
  40.   
  41.   int batch_size_, channels_, height_, width_, size_;  
  42.   Dtype* data_;  
  43.   Dtype* labels_;  
  44.   // batch_size  
  45.   int n_;  
  46.   size_t pos_;  
  47.   // 内部的数据和类标  
  48.   Blob<Dtype> added_data_;  
  49.   Blob<Dtype> added_label_;  
  50.   // 是否有新的数据  
  51.   bool has_new_data_;  
  52. };  
下面给出具体的实现细节:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3. #endif  // USE_OPENCV  
  4.   
  5. #include <vector>  
  6.   
  7. #include "caffe/data_layers.hpp"  
  8. #include "caffe/layer.hpp"  
  9. #include "caffe/util/io.hpp"  
  10.   
  11. namespace caffe {  
  12.   
  13. template <typename Dtype>  
  14. void MemoryDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  15.      const vector<Blob<Dtype>*>& top) {  
  16.   // 从参数文件获取参数  
  17.   batch_size_ = this->layer_param_.memory_data_param().batch_size();  
  18.   channels_ = this->layer_param_.memory_data_param().channels();  
  19.   height_ = this->layer_param_.memory_data_param().height();  
  20.   width_ = this->layer_param_.memory_data_param().width();  
  21.   size_ = channels_ * height_ * width_;  
  22.   CHECK_GT(batch_size_ * size_, 0) <<  
  23.       "batch_size, channels, height, and width must be specified and"  
  24.       " positive in memory_data_param";  
  25.   // 设置top的形状  
  26.   vector<int> label_shape(1, batch_size_);  
  27.   top[0]->Reshape(batch_size_, channels_, height_, width_);  
  28.   top[1]->Reshape(label_shape);  
  29.   // 设置内部变量added_data_和added_label_的形状  
  30.   added_data_.Reshape(batch_size_, channels_, height_, width_);  
  31.   added_label_.Reshape(label_shape);  
  32.   data_ = NULL;  
  33.   labels_ = NULL;  
  34.   added_data_.cpu_data();  
  35.   added_label_.cpu_data();  
  36. }  
  37.   
  38. // 将Datum的vector放入到added_data_和added_label_  
  39. // 并进行预处理  
  40. template <typename Dtype>  
  41. void MemoryDataLayer<Dtype>::AddDatumVector(const vector<Datum>& datum_vector) {  
  42.   CHECK(!has_new_data_) <<  
  43.       "Can't add data until current data has been consumed.";  
  44.   size_t num = datum_vector.size();  
  45.   CHECK_GT(num, 0) << "There is no datum to add.";  
  46.   CHECK_EQ(num % batch_size_, 0) <<  
  47.       "The added data must be a multiple of the batch size.";  
  48.   // 改变形状  
  49.   added_data_.Reshape(num, channels_, height_, width_);  
  50.   added_label_.Reshape(num, 1, 1, 1);  
  51.   // Apply data transformations (mirror, scale, crop...)  
  52.   // 对数据进行预处理  
  53.   this->data_transformer_->Transform(datum_vector, &added_data_);  
  54.   // Copy Labels  
  55.   // 复制类标到top_label  
  56.   Dtype* top_label = added_label_.mutable_cpu_data();  
  57.   for (int item_id = 0; item_id < num; ++item_id) {  
  58.     top_label[item_id] = datum_vector[item_id].label();  
  59.   }  
  60.   // num_images == batch_size_  
  61.   Dtype* top_data = added_data_.mutable_cpu_data();  
  62.   // 将数据、类标以及数据个数设置到该类的内部变量  
  63.   Reset(top_data, top_label, num);  
  64.   // 设置标记为true  
  65.   has_new_data_ = true;  
  66. }  
  67.   
  68. // 如果定义OPENCV,则对数据进行处理存放到added_data_和added_label_  
  69. #ifdef USE_OPENCV  
  70. template <typename Dtype>  
  71. void MemoryDataLayer<Dtype>::AddMatVector(const vector<cv::Mat>& mat_vector,  
  72.     const vector<int>& labels) {  
  73.   size_t num = mat_vector.size();  
  74.   CHECK(!has_new_data_) <<  
  75.       "Can't add mat until current data has been consumed.";  
  76.   CHECK_GT(num, 0) << "There is no mat to add";  
  77.   CHECK_EQ(num % batch_size_, 0) <<  
  78.       "The added data must be a multiple of the batch size.";  
  79.   added_data_.Reshape(num, channels_, height_, width_);  
  80.   added_label_.Reshape(num, 1, 1, 1);  
  81.   // Apply data transformations (mirror, scale, crop...)  
  82.   // 预处理  
  83.   this->data_transformer_->Transform(mat_vector, &added_data_);  
  84.   // Copy Labels  
  85.   Dtype* top_label = added_label_.mutable_cpu_data();  
  86.   for (int item_id = 0; item_id < num; ++item_id) {  
  87.     top_label[item_id] = labels[item_id];  
  88.   }  
  89.   // num_images == batch_size_  
  90.   Dtype* top_data = added_data_.mutable_cpu_data();  
  91.   Reset(top_data, top_label, num);  
  92.   has_new_data_ = true;  
  93. }  
  94. #endif  // USE_OPENCV  
  95.   
  96. // 将数据和类标设置到内部的变量  
  97. // data_、labels_、n_  
  98. // 并且设置位置pos_=0  
  99. template <typename Dtype>  
  100. void MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n) {  
  101.   CHECK(data);  
  102.   CHECK(labels);  
  103.   CHECK_EQ(n % batch_size_, 0) << "n must be a multiple of batch size";  
  104.   // Warn with transformation parameters since a memory array is meant to  
  105.   // be generic and no transformations are done with Reset().  
  106.   if (this->layer_param_.has_transform_param()) {  
  107.     LOG(WARNING) << this->type() << " does not transform array data on Reset()";  
  108.   }  
  109.   data_ = data;  
  110.   labels_ = labels;  
  111.   n_ = n;// batch_size  
  112.   pos_ = 0;  
  113. }  
  114.   
  115. // 设置内内部变量added_data_和added_label_的批数  
  116. template <typename Dtype>  
  117. void MemoryDataLayer<Dtype>::set_batch_size(int new_size) {  
  118.   CHECK(!has_new_data_) <<  
  119.       "Can't change batch_size until current data has been consumed.";  
  120.   batch_size_ = new_size;  
  121.   added_data_.Reshape(batch_size_, channels_, height_, width_);  
  122.   added_label_.Reshape(batch_size_, 1, 1, 1);  
  123. }  
  124.   
  125. // 将内部变量added_data_和added_label_复制到top传递给下一层  
  126. template <typename Dtype>  
  127. void MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  128.       const vector<Blob<Dtype>*>& top) {  
  129.   CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset";  
  130.   // 这里直接使用内部变量将数据复制到top[0]、将类标复制到top[1]  
  131.   top[0]->Reshape(batch_size_, channels_, height_, width_);  
  132.   top[1]->Reshape(batch_size_, 1, 1, 1);  
  133.   top[0]->set_cpu_data(data_ + pos_ * size_);  
  134.   top[1]->set_cpu_data(labels_ + pos_);  
  135.   pos_ = (pos_ + batch_size_) % n_;  
  136.   if (pos_ == 0)  
  137.     has_new_data_ = false;// 传过一次之后,就没有新数据啦  
  138. }  
  139.   
  140. INSTANTIATE_CLASS(MemoryDataLayer);  
  141. REGISTER_LAYER_CLASS(MemoryData);  
  142.   
  143. }  // namespace caffe  

(9)WindowDataLayer类的定义以及实现如下:

该类主要就是对于读取好的Datum或者OpenCV读取的Mat的Vector进行预处理(图像的crop、scale等),然后前传。
首先给出窗口数据文件的格式,便于自己训练
窗口文件的格式如下:
# 图像索引(举例:# 1就表示第一个图像,注意#号与数字之间有空格)
图像的路径
图像通道数
图像高度
图像宽度
窗口数目
类标,与前景目标的重叠率,x1,y1,x2,y2
注:x1,y1,x2,y2是窗口的左上和右下的坐标

为了理解的更清楚我这里举个例子:
# 1 /1.jpg 3 720 480 100 1 1 0 0 100 100 2 30 100 1500 1500
上述的例子表示一个编号为1的图像相对路径为/1.jpg,通道为3,高度为720
宽度为480,窗口数目为100,类标为1,与前景目标的重叠率为0.8,类标为1窗口的左上坐标为(0,0),右下坐标为(100,100)
类标为2的窗口的左上角坐标为(30,100),右下角的坐标为(1500,1500)。有多少窗口后面就这么继续写下去

接下来给出该层的参数:
[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. message WindowDataParameter {  
  2.   // Specify the data source.  
  3.   // 装有窗口数据的列表文件  
  4.   optional string source = 1;  
  5.   // For data pre-processing, we can do simple scaling and subtracting the  
  6.   // data mean, if provided. Note that the mean subtraction is always carried  
  7.   // out before scaling.  
  8.   // 是否需要缩放图像中的像素值,注意哈这不是缩放图像的大小,是拿图像的像素值乘以这个  
  9.   optional float scale = 2 [default = 1];  
  10.   // 平均值文件路径  
  11.   optional string mean_file = 3;  
  12.   // Specify the batch size.  
  13.   // 批大小  
  14.   optional uint32 batch_size = 4;  
  15.   // Specify if we would like to randomly crop an image.  
  16.   // 随机crop的图像块的大小  
  17.   optional uint32 crop_size = 5 [default = 0];  
  18.   // Specify if we want to randomly mirror data.  
  19.   // 是否随机镜像图像  
  20.   optional bool mirror = 6 [default = false];  
  21.   // Foreground (object) overlap threshold  
  22.   // 前景重叠阈值  
  23.   optional float fg_threshold = 7 [default = 0.5];  
  24.   // Background (non-object) overlap threshold  
  25.   // 背景重叠阈值  
  26.   optional float bg_threshold = 8 [default = 0.5];  
  27.   // Fraction of batch that should be foreground objects  
  28.   // 每一批中有多少比例应该是前景(也就是是你要检测的物体)  
  29.   optional float fg_fraction = 9 [default = 0.25];  
  30.   // Amount of contextual padding to add around a window  
  31.   // (used only by the window_data_layer)  
  32.   // 是否需要在窗口周围padding  
  33.   optional uint32 context_pad = 10 [default = 0];  
  34.   // Mode for cropping out a detection window  
  35.   // warp: cropped window is warped to a fixed size and aspect ratio  
  36.   // square: the tightest square around the window is cropped  
  37.   // crop的模式,square还是warp  
  38.   optional string crop_mode = 11 [default = "warp"];  
  39.   // cache_images: will load all images in memory for faster access  
  40.   // 是否将文件缓冲到内存  
  41.   optional bool cache_images = 12 [default = false];  
  42.   // append root_folder to locate images  
  43.   // 图像文件根目录  
  44.   optional string root_folder = 13 [default = ""];  
  45. }  
我们给出该类的定义:
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. /** 
  2.  * @brief Provides data to the Net from windows of images files, specified 
  3.  *        by a window data file. 
  4.  *  从图像文件的窗口获取数据,需要指定窗口数据文件 
  5.  * TODO(dox): thorough documentation for Forward and proto params. 
  6.  */  
  7. template <typename Dtype>  
  8. class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {  
  9.  public:  
  10.   explicit WindowDataLayer(const LayerParameter& param)  
  11.       : BasePrefetchingDataLayer<Dtype>(param) {}  
  12.   virtual ~WindowDataLayer();  
  13.   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  14.       const vector<Blob<Dtype>*>& top);  
  15.   
  16.   virtual inline const char* type() const { return "WindowData"; }  
  17.   virtual inline int ExactNumBottomBlobs() const { return 0; }  
  18.   virtual inline int ExactNumTopBlobs() const { return 2; }  
  19.   
  20.  protected:  
  21.   virtual unsigned int PrefetchRand();  
  22.   virtual void load_batch(Batch<Dtype>* batch);  
  23.   
  24.   shared_ptr<Caffe::RNG> prefetch_rng_;  
  25.   vector<std::pair<std::string, vector<int> > > image_database_;  
  26.   // 窗口类中所使用的窗口数据的枚举  
  27.   // 就是定义个vector<float>,然后里面按顺序存放下面这些类型的数据  
  28.   enum WindowField { IMAGE_INDEX, LABEL, OVERLAP, X1, Y1, X2, Y2, NUM };  
  29.   vector<vector<float> > fg_windows_;  
  30.   vector<vector<float> > bg_windows_;  
  31.   Blob<Dtype> data_mean_;  
  32.   vector<Dtype> mean_values_;  
  33.   bool has_mean_file_;  
  34.   bool has_mean_values_;  
  35.   bool cache_images_;  
  36.   vector<std::pair<std::string, Datum > > image_database_cache_;  
  37. };  
然后给出该类的实现
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. #ifdef USE_OPENCV  
  2. #include <opencv2/highgui/highgui_c.h>  
  3. #include <stdint.h>  
  4.   
  5. #include <algorithm>  
  6. #include <map>  
  7. #include <string>  
  8. #include <utility>  
  9. #include <vector>  
  10.   
  11. #include "opencv2/core/core.hpp"  
  12. #include "opencv2/highgui/highgui.hpp"  
  13. #include "opencv2/imgproc/imgproc.hpp"  
  14.   
  15. #include "caffe/common.hpp"  
  16. #include "caffe/data_layers.hpp"  
  17. #include "caffe/layer.hpp"  
  18. #include "caffe/util/benchmark.hpp"  
  19. #include "caffe/util/io.hpp"  
  20. #include "caffe/util/math_functions.hpp"  
  21. #include "caffe/util/rng.hpp"  
  22.   
  23. // caffe.proto > LayerParameter > WindowDataParameter  
  24. //   'source' field specifies the window_file  
  25. //   'crop_size' indicates the desired warped size  
  26.   
  27. namespace caffe {  
  28.   
  29. template <typename Dtype>  
  30. WindowDataLayer<Dtype>::~WindowDataLayer<Dtype>() {  
  31.   this->StopInternalThread();  
  32. }  
  33.   
  34. // 读取窗口数据文件的信息,并设置各个数据结构的形状  
  35. template <typename Dtype>  
  36. void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  37.       const vector<Blob<Dtype>*>& top) {  
  38.   // LayerSetUp runs through the window_file and creates two structures  
  39.   // that hold windows: one for foreground (object) windows and one  
  40.   // for background (non-object) windows. We use an overlap threshold  
  41.   // to decide which is which.  
  42.   
  43.   // window_file format  
  44.   // repeated:  
  45.   //    # image_index  
  46.   //    img_path (abs path)  
  47.   //    channels  
  48.   //    height  
  49.   //    width  
  50.   //    num_windows  
  51.   //    class_index overlap x1 y1 x2 y2  
  52.   
  53.   // 窗口文件的格式如下:  
  54.   // # 图像索引(举例:# 1就表示第一个图像,注意#号与数字之间有空格)  
  55.   // 图像的路径  
  56.   // 图像通道数  
  57.   // 图像高度  
  58.   // 图像宽度  
  59.   // 窗口数目  
  60.   // 类标,overlap,x1,y1,x2,y2  
  61.   // 注:x1,y1,x2,y2是窗口的左上和右下的坐标  
  62.   // 我这里举个例子  
  63.   // # 1 /1.jpg 3 720 480 100 1 1 0 0 100 100  
  64.   // 上述的例子即使表示一个编号为1的图像相对路径为/1.jpg,通道为3,高度为720  
  65.   // 宽度为480,窗口数目为100,类标为1,overlap为1,窗口的左上坐标为(0,0),右下坐标为(100,100)  
  66.   
  67.   
  68.   LOG(INFO) << "Window data layer:" << std::endl  
  69.       << "  foreground (object) overlap threshold: "  
  70.       << this->layer_param_.window_data_param().fg_threshold() << std::endl  
  71.       << "  background (non-object) overlap threshold: "  
  72.       << this->layer_param_.window_data_param().bg_threshold() << std::endl  
  73.       << "  foreground sampling fraction: "  
  74.       << this->layer_param_.window_data_param().fg_fraction() << std::endl  
  75.       << "  cache_images: "  
  76.       << this->layer_param_.window_data_param().cache_images() << std::endl  
  77.       << "  root_folder: "  
  78.       << this->layer_param_.window_data_param().root_folder();  
  79.   
  80.   cache_images_ = this->layer_param_.window_data_param().cache_images();  
  81.   string root_folder = this->layer_param_.window_data_param().root_folder();  
  82.   
  83.   // 根据参数文件中是否需要进行左右mirror,或者是否进行crop,  
  84.   // 来判断是否需要初始化随机数种子  
  85.   const bool prefetch_needs_rand =  
  86.       this->transform_param_.mirror() ||  
  87.       this->transform_param_.crop_size();  
  88.   if (prefetch_needs_rand) {  
  89.     const unsigned int prefetch_rng_seed = caffe_rng_rand();  
  90.     prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));  
  91.   } else {  
  92.     prefetch_rng_.reset();  
  93.   }  
  94.   
  95.   // 打开窗口文件  
  96.   std::ifstream infile(this->layer_param_.window_data_param().source().c_str());  
  97.   CHECK(infile.good()) << "Failed to open window file "  
  98.       << this->layer_param_.window_data_param().source() << std::endl;  
  99.   
  100.   // 这个是类标与类标出现的次数之间的映射  
  101.   // 这里称之为类标直方图  
  102.   map<intint> label_hist;  
  103.   label_hist.insert(std::make_pair(0, 0));  
  104.   
  105.   string hashtag;  
  106.   int image_index, channels;  
  107.   // 先从窗口文件中读取一个图像索引测试一下是否为空  
  108.   if (!(infile >> hashtag >> image_index)) {  
  109.     LOG(FATAL) << "Window file is empty";  
  110.   }  
  111.   do {  
  112.       // 检查是否# 开头  
  113.     CHECK_EQ(hashtag, "#");  
  114.     // read image path  
  115.     string image_path;  
  116.     // 接下来读取图像的相对路径  
  117.     // 将该路径与根目录路径拼接  
  118.     infile >> image_path;  
  119.     image_path = root_folder + image_path;  
  120.     // read image dimensions  
  121.     vector<int> image_size(3);  
  122.     // 读取图像的维度信息,分别为channel,height , width  
  123.     infile >> image_size[0] >> image_size[1] >> image_size[2];  
  124.     channels = image_size[0];  
  125.     // 将图像路径和图像大小压入到image_database_中  
  126.     image_database_.push_back(std::make_pair(image_path, image_size));  
  127.   
  128.     // 如果需要缓存图像到内存的话,则用image_database_cache_进行存储  
  129.     if (cache_images_) {  
  130.       Datum datum;  
  131.       // 将图像数据读取到Datum这个结构  
  132.       if (!ReadFileToDatum(image_path, &datum)) {  
  133.         LOG(ERROR) << "Could not open or find file " << image_path;  
  134.         return;  
  135.       }  
  136.       // 将Datum结构的图像缓存到到image_database_cache_  
  137.       image_database_cache_.push_back(std::make_pair(image_path, datum));  
  138.     }  
  139.     // read each box  
  140.     int num_windows;  
  141.     // 读取窗口个数  
  142.     infile >> num_windows;  
  143.     // 从参数文件获取前景和背景阈值  
  144.     const float fg_threshold =  
  145.         this->layer_param_.window_data_param().fg_threshold();  
  146.     const float bg_threshold =  
  147.         this->layer_param_.window_data_param().bg_threshold();  
  148.     for (int i = 0; i < num_windows; ++i) {  
  149.       int label, x1, y1, x2, y2;  
  150.       float overlap;  
  151.       // 读取  类标,与前景目标的重叠率,x1,y1,x2,y2  
  152.       infile >> label >> overlap >> x1 >> y1 >> x2 >> y2;  
  153.   
  154.       // 按照顺序放在window这个数据结构里头  
  155.       vector<float> window(WindowDataLayer::NUM);  
  156.       window[WindowDataLayer::IMAGE_INDEX] = image_index;  
  157.       window[WindowDataLayer::LABEL] = label;  
  158.       window[WindowDataLayer::OVERLAP] = overlap;  
  159.       window[WindowDataLayer::X1] = x1;  
  160.       window[WindowDataLayer::Y1] = y1;  
  161.       window[WindowDataLayer::X2] = x2;  
  162.       window[WindowDataLayer::Y2] = y2;  
  163.   
  164.       // add window to foreground list or background list  
  165.       // 下面是将窗口的前景和背景都装入到fg_windows_和bg_windows_中去  
  166.       // 如果重叠的比例大于前景阈值,那么就认为是前景  
  167.       if (overlap >= fg_threshold) {  
  168.         int label = window[WindowDataLayer::LABEL];  
  169.         // 类标必须大于0,因为重叠区域已经大于前景阈值了  
  170.         // 此时如果类标不大于0,表明数据有误!  
  171.         CHECK_GT(label, 0);  
  172.         fg_windows_.push_back(window);  
  173.         // 该类的直方图+1  
  174.         label_hist.insert(std::make_pair(label, 0));  
  175.         label_hist[label]++;  
  176.       } else if (overlap < bg_threshold) {  
  177.       // 如果重叠阈值小于背景阈值则认为是背景  
  178.         // background window, force label and overlap to 0  
  179.         window[WindowDataLayer::LABEL] = 0;  
  180.         window[WindowDataLayer::OVERLAP] = 0;  
  181.         bg_windows_.push_back(window);  
  182.         // 0类的直方图(也就是背景的直方图)+1  
  183.         label_hist[0]++;  
  184.       }  
  185.     }  
  186.   
  187.     // 每处理100个就显示一瞎  
  188.     if (image_index % 100 == 0) {  
  189.       LOG(INFO) << "num: " << image_index << " "  
  190.           << image_path << " "  
  191.           << image_size[0] << " "  
  192.           << image_size[1] << " "  
  193.           << image_size[2] << " "  
  194.           << "windows to process: " << num_windows;  
  195.     }  
  196.   } while (infile >> hashtag >> image_index);  
  197.   
  198.   // 读取完毕后输出图像的个数  
  199.   LOG(INFO) << "Number of images: " << image_index+1;  
  200.   
  201.   // 输出统计的每个类别的个数  
  202.   for (map<intint>::iterator it = label_hist.begin();  
  203.       it != label_hist.end(); ++it) {  
  204.     LOG(INFO) << "class " << it->first << " has " << label_hist[it->first]  
  205.               << " samples";  
  206.   }  
  207.   
  208.   LOG(INFO) << "Amount of context padding: "  
  209.       << this->layer_param_.window_data_param().context_pad();  
  210.   
  211.   LOG(INFO) << "Crop mode: "  
  212.       << this->layer_param_.window_data_param().crop_mode();  
  213.   
  214.   // image  
  215.   // 获取crop_size  
  216.   const int crop_size = this->transform_param_.crop_size();  
  217.   CHECK_GT(crop_size, 0);  
  218.   // 获取batch_size  
  219.   const int batch_size = this->layer_param_.window_data_param().batch_size();  
  220.   // 将top[0]设置为batch_size,channels, crop_size, crop_size大小的  
  221.   top[0]->Reshape(batch_size, channels, crop_size, crop_size);  
  222.   // 将prefetch_中的数据形状也这么设置  
  223.   for (int i = 0; i < this->PREFETCH_COUNT; ++i)  
  224.     this->prefetch_[i].data_.Reshape(  
  225.         batch_size, channels, crop_size, crop_size);  
  226.   
  227.   LOG(INFO) << "output data size: " << top[0]->num() << ","  
  228.       << top[0]->channels() << "," << top[0]->height() << ","  
  229.       << top[0]->width();  
  230.   // label  
  231.   // 将top[1]设置为类标大小  
  232.   vector<int> label_shape(1, batch_size);  
  233.   top[1]->Reshape(label_shape);  
  234.   // 将prefetch_中的类标形状也这么设置  
  235.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  236.     this->prefetch_[i].label_.Reshape(label_shape);  
  237.   }  
  238.   
  239.   // data mean  
  240.   // 是否有均值文件或者有均值  
  241.   has_mean_file_ = this->transform_param_.has_mean_file();  
  242.   has_mean_values_ = this->transform_param_.mean_value_size() > 0;  
  243.   if (has_mean_file_) {// 有均值文件就读  
  244.     const string& mean_file =  
  245.           this->transform_param_.mean_file();  
  246.     LOG(INFO) << "Loading mean file from: " << mean_file;  
  247.     BlobProto blob_proto;  
  248.     ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);  
  249.     data_mean_.FromProto(blob_proto);  
  250.   }  
  251.   if (has_mean_values_) {// 有均值就直接从参数中获取  
  252.     CHECK(has_mean_file_ == false) <<  
  253.       "Cannot specify mean_file and mean_value at the same time";  
  254.     for (int c = 0; c < this->transform_param_.mean_value_size(); ++c) {  
  255.       mean_values_.push_back(this->transform_param_.mean_value(c));  
  256.     }  
  257.   
  258.     // 检查均值是不是等于1,或者等于图像的通道数  
  259.     // 也就是要么所有通道都使用同一个均值  
  260.     // 要么每个通道用一个均值  
  261.     CHECK(mean_values_.size() == 1 || mean_values_.size() == channels) <<  
  262.      "Specify either 1 mean_value or as many as channels: " << channels;  
  263.     if (channels > 1 && mean_values_.size() == 1) {  
  264.       // Replicate the mean_value for simplicity  
  265.       for (int c = 1; c < channels; ++c) {  
  266.         mean_values_.push_back(mean_values_[0]);  
  267.       }  
  268.     }  
  269.   }  
  270. }  
  271.   
  272. // 随机数生成器进行初始化并生成随机数  
  273. template <typename Dtype>  
  274. unsigned int WindowDataLayer<Dtype>::PrefetchRand() {  
  275.   CHECK(prefetch_rng_);  
  276.   caffe::rng_t* prefetch_rng =  
  277.       static_cast<caffe::rng_t*>(prefetch_rng_->generator());  
  278.   return (*prefetch_rng)();  
  279. }  
  280.   
  281. // 因为继承BasePrefetchingDataLayer所以要实现load_batch  
  282. // 以供线程调用  
  283. // This function is called on prefetch thread  
  284. template <typename Dtype>  
  285. void WindowDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  286.   // At each iteration, sample N windows where N*p are foreground (object)  
  287.   // windows and N*(1-p) are background (non-object) windows  
  288.   CPUTimer batch_timer;  
  289.   batch_timer.Start();  
  290.   double read_time = 0;  
  291.   double trans_time = 0;  
  292.   CPUTimer timer;  
  293.   // top数据和类标  
  294.   Dtype* top_data = batch->data_.mutable_cpu_data();  
  295.   Dtype* top_label = batch->label_.mutable_cpu_data();  
  296.   // 缩放尺度  
  297.   const Dtype scale = this->layer_param_.window_data_param().scale();  
  298.   // batch_size  
  299.   const int batch_size = this->layer_param_.window_data_param().batch_size();  
  300.   // 上下文填充  
  301.   const int context_pad = this->layer_param_.window_data_param().context_pad();  
  302.   // crop_size  
  303.   const int crop_size = this->transform_param_.crop_size();  
  304.   // 是否镜像  
  305.   const bool mirror = this->transform_param_.mirror();  
  306.   // 前景比例  
  307.   const float fg_fraction =  
  308.       this->layer_param_.window_data_param().fg_fraction();  
  309.   Dtype* mean = NULL;  
  310.   int mean_off = 0;  
  311.   int mean_width = 0;  
  312.   int mean_height = 0;  
  313.   // 如果有平均值文件则  
  314.   if (this->has_mean_file_) {  
  315.     mean = this->data_mean_.mutable_cpu_data();  
  316.     // 经过crop之后的平均值图像的中心  
  317.     mean_off = (this->data_mean_.width() - crop_size) / 2;  
  318.     mean_width = this->data_mean_.width();  
  319.     mean_height = this->data_mean_.height();  
  320.   }  
  321.   cv::Size cv_crop_size(crop_size, crop_size);  
  322.   // 获取crop的模式,是warp还是square  
  323.   const string& crop_mode = this->layer_param_.window_data_param().crop_mode();  
  324.   
  325.   bool use_square = (crop_mode == "square") ? true : false;  
  326.   
  327.   // zero out batch  
  328.   caffe_set(batch->data_.count(), Dtype(0), top_data);  
  329.   
  330.   // 根据前景比例获得前景图像的数目  
  331.   const int num_fg = static_cast<int>(static_cast<float>(batch_size)  
  332.       * fg_fraction);  
  333.   // 样本数量,是前景还是背景?[0]是背景[1]是前景  
  334.   const int num_samples[2] = { batch_size - num_fg, num_fg };  
  335.   
  336.   int item_id = 0;  
  337.   // sample from bg set then fg set  
  338.   // 先对背景进行采样  
  339.   // 再对前景进行采样  
  340.   for (int is_fg = 0; is_fg < 2; ++is_fg) {  
  341.     for (int dummy = 0; dummy < num_samples[is_fg]; ++dummy) {  
  342.       // sample a window  
  343.       timer.Start();  
  344.       // 生成一个随机数  
  345.       const unsigned int rand_index = PrefetchRand();  
  346.       // fg_windows_和bg_windows_存储的是对应的窗口信息  
  347.       // 在SetUp中读取的窗口数据文件的时候获得的  
  348.       // 从该图像的若干窗口中去随机选择一个窗口  
  349.       vector<float> window = (is_fg) ?  
  350.           fg_windows_[rand_index % fg_windows_.size()] :  
  351.           bg_windows_[rand_index % bg_windows_.size()];  
  352.   
  353.       // 随机选择是否需要镜像  
  354.       bool do_mirror = mirror && PrefetchRand() % 2;  
  355.   
  356.       // load the image containing the window  
  357.       // 载入图像的路径以及类标  
  358.       pair<std::string, vector<int> > image =  
  359.           image_database_[window[WindowDataLayer<Dtype>::IMAGE_INDEX]];  
  360.   
  361.       // 读取图像  
  362.       cv::Mat cv_img;  
  363.       if (this->cache_images_) {  
  364.           // 如果图像缓冲到内存则获得对应图像的Datum  
  365.         pair<std::string, Datum> image_cached =  
  366.           image_database_cache_[window[WindowDataLayer<Dtype>::IMAGE_INDEX]];  
  367.         // 将图像的Datum解码为OpenCV的Mat  
  368.         cv_img = DecodeDatumToCVMat(image_cached.second, true);  
  369.       } else {  
  370.         // 否则直接读取  
  371.         cv_img = cv::imread(image.first, CV_LOAD_IMAGE_COLOR);  
  372.         if (!cv_img.data) {  
  373.           LOG(ERROR) << "Could not open or find file " << image.first;  
  374.           return;  
  375.         }  
  376.       }  
  377.       read_time += timer.MicroSeconds();  
  378.       timer.Start();  
  379.       const int channels = cv_img.channels();  
  380.   
  381.       // crop window out of image and warp it  
  382.       // 窗口坐标  
  383.       int x1 = window[WindowDataLayer<Dtype>::X1];  
  384.       int y1 = window[WindowDataLayer<Dtype>::Y1];  
  385.       int x2 = window[WindowDataLayer<Dtype>::X2];  
  386.       int y2 = window[WindowDataLayer<Dtype>::Y2];  
  387.   
  388.       int pad_w = 0;  
  389.       int pad_h = 0;  
  390.       // context_pad也是个大小,具体什么含义,我没有具体研究  
  391.       // 毕竟不是搞检测的  
  392.       // context_scale = crop_size / (crop_size - 2*context_pad)  
  393.       if (context_pad > 0 || use_square) {  
  394.         // scale factor by which to expand the original region  
  395.         // such that after warping the expanded region to crop_size x crop_size  
  396.         // there's exactly context_pad amount of padding on each side  
  397.         Dtype context_scale = static_cast<Dtype>(crop_size) /  
  398.             static_cast<Dtype>(crop_size - 2*context_pad);  
  399.   
  400.         // compute the expanded region  
  401.         // 高度的一半  
  402.         Dtype half_height = static_cast<Dtype>(y2-y1+1)/2.0;  
  403.         // 宽度的一半  
  404.         Dtype half_width = static_cast<Dtype>(x2-x1+1)/2.0;  
  405.         // x中心  
  406.         Dtype center_x = static_cast<Dtype>(x1) + half_width;  
  407.         // y中心  
  408.         Dtype center_y = static_cast<Dtype>(y1) + half_height;  
  409.         if (use_square) {// 如果使用正方形形状则将较大的那个赋值给小的  
  410.           if (half_height > half_width) {  
  411.             half_width = half_height;  
  412.           } else {  
  413.             half_height = half_width;  
  414.           }  
  415.         }  
  416.   
  417.         // 获取经过处理之后的x1,y1,x2,y2  
  418.         x1 = static_cast<int>(round(center_x - half_width*context_scale));  
  419.         x2 = static_cast<int>(round(center_x + half_width*context_scale));  
  420.         y1 = static_cast<int>(round(center_y - half_height*context_scale));  
  421.         y2 = static_cast<int>(round(center_y + half_height*context_scale));  
  422.   
  423.         // the expanded region may go outside of the image  
  424.         // so we compute the clipped (expanded) region and keep track of  
  425.         // the extent beyond the image  
  426.         // 经过处理之后的窗口如果不在图像内部是有问题的  
  427.         // 这里对窗口的坐标进行处理  
  428.         // 使得窗口的左上角不超过图像的左上角  
  429.         // 窗口的右下角不超过图像的右下角  
  430.         // 所以这里叫clip bounds嘛  
  431.         int unclipped_height = y2-y1+1;  
  432.         int unclipped_width = x2-x1+1;  
  433.         int pad_x1 = std::max(0, -x1);  
  434.         int pad_y1 = std::max(0, -y1);  
  435.         int pad_x2 = std::max(0, x2 - cv_img.cols + 1);  
  436.         int pad_y2 = std::max(0, y2 - cv_img.rows + 1);  
  437.         // clip bounds  
  438.         x1 = x1 + pad_x1;  
  439.         x2 = x2 - pad_x2;  
  440.         y1 = y1 + pad_y1;  
  441.         y2 = y2 - pad_y2;  
  442.         CHECK_GT(x1, -1);  
  443.         CHECK_GT(y1, -1);  
  444.         CHECK_LT(x2, cv_img.cols);  
  445.         CHECK_LT(y2, cv_img.rows);  
  446.   
  447.         // 经过clip之后的高度和宽度  
  448.         int clipped_height = y2-y1+1;  
  449.         int clipped_width = x2-x1+1;  
  450.   
  451.         // scale factors that would be used to warp the unclipped  
  452.         // expanded region  
  453.         // scale_x/scale_y=crop_size除以未经clip之后的宽度/高度  
  454.         Dtype scale_x =  
  455.             static_cast<Dtype>(crop_size)/static_cast<Dtype>(unclipped_width);  
  456.         Dtype scale_y =  
  457.             static_cast<Dtype>(crop_size)/static_cast<Dtype>(unclipped_height);  
  458.   
  459.         // size to warp the clipped expanded region to  
  460.         // 用clip的宽度和高度乘以scale_x或者scale_y得到crop_size中的宽度和高度  
  461.         cv_crop_size.width =  
  462.             static_cast<int>(round(static_cast<Dtype>(clipped_width)*scale_x));  
  463.         cv_crop_size.height =  
  464.             static_cast<int>(round(static_cast<Dtype>(clipped_height)*scale_y));  
  465.         // 再对pad的边界进行处理  
  466.         pad_x1 = static_cast<int>(round(static_cast<Dtype>(pad_x1)*scale_x));  
  467.         pad_x2 = static_cast<int>(round(static_cast<Dtype>(pad_x2)*scale_x));  
  468.         pad_y1 = static_cast<int>(round(static_cast<Dtype>(pad_y1)*scale_y));  
  469.         pad_y2 = static_cast<int>(round(static_cast<Dtype>(pad_y2)*scale_y));  
  470.   
  471.         pad_h = pad_y1;  
  472.         // if we're mirroring, we mirror the padding too (to be pedantic)  
  473.         // 如果需要镜像填充的部分也要镜像  
  474.         if (do_mirror) {  
  475.           pad_w = pad_x2;  
  476.         } else {  
  477.           pad_w = pad_x1;  
  478.         }  
  479.   
  480.         // ensure that the warped, clipped region plus the padding fits in the  
  481.         // crop_size x crop_size image (it might not due to rounding)  
  482.         // 确保大小是在crop_size x crop_size以内的  
  483.         if (pad_h + cv_crop_size.height > crop_size) {  
  484.           cv_crop_size.height = crop_size - pad_h;  
  485.         }  
  486.         if (pad_w + cv_crop_size.width > crop_size) {  
  487.           cv_crop_size.width = crop_size - pad_w;  
  488.         }  
  489.       }  
  490.   
  491.       cv::Rect roi(x1, y1, x2-x1+1, y2-y1+1);  
  492.       // 进行crop  
  493.       cv::Mat cv_cropped_img = cv_img(roi);  
  494.       // 使用线性插值进行缩放,缩放到cv_crop_size  
  495.       cv::resize(cv_cropped_img, cv_cropped_img,  
  496.           cv_crop_size, 0, 0, cv::INTER_LINEAR);  
  497.   
  498.       // horizontal flip at random  
  499.       if (do_mirror) {  
  500.           // 对图像进行镜像  
  501.         cv::flip(cv_cropped_img, cv_cropped_img, 1);  
  502.       }  
  503.   
  504.       // copy the warped window into top_data  
  505.       for (int h = 0; h < cv_cropped_img.rows; ++h) {  
  506.         const uchar* ptr = cv_cropped_img.ptr<uchar>(h);  
  507.         int img_index = 0;  
  508.         for (int w = 0; w < cv_cropped_img.cols; ++w) {  
  509.           for (int c = 0; c < channels; ++c) {  
  510.             int top_index = ((item_id * channels + c) * crop_size + h + pad_h)  
  511.                      * crop_size + w + pad_w;  
  512.             // int top_index = (c * height + h) * width + w;  
  513.             Dtype pixel = static_cast<Dtype>(ptr[img_index++]);  
  514.             if (this->has_mean_file_) {// 有均值文件减去均值文件中对应的数值  
  515.               int mean_index = (c * mean_height + h + mean_off + pad_h)  
  516.                            * mean_width + w + mean_off + pad_w;  
  517.               top_data[top_index] = (pixel - mean[mean_index]) * scale;  
  518.             } else {  
  519.               if (this->has_mean_values_) {// 有均值则减去均值  
  520.                 top_data[top_index] = (pixel - this->mean_values_[c]) * scale;  
  521.               } else {  
  522.                 top_data[top_index] = pixel * scale;// 像素值进行缩放  
  523.               }  
  524.             }  
  525.           }  
  526.         }  
  527.       }  
  528.       trans_time += timer.MicroSeconds();  
  529.       // get window label  
  530.       top_label[item_id] = window[WindowDataLayer<Dtype>::LABEL];  
  531.   
  532.       #if 0  
  533.       // useful debugging code for dumping transformed windows to disk  
  534.       string file_id;  
  535.       std::stringstream ss;  
  536.       ss << PrefetchRand();  
  537.       ss >> file_id;  
  538.       std::ofstream inf((string("dump/") + file_id +  
  539.           string("_info.txt")).c_str(), std::ofstream::out);  
  540.       inf << image.first << std::endl  
  541.           << window[WindowDataLayer<Dtype>::X1]+1 << std::endl  
  542.           << window[WindowDataLayer<Dtype>::Y1]+1 << std::endl  
  543.           << window[WindowDataLayer<Dtype>::X2]+1 << std::endl  
  544.           << window[WindowDataLayer<Dtype>::Y2]+1 << std::endl  
  545.           << do_mirror << std::endl  
  546.           << top_label[item_id] << std::endl  
  547.           << is_fg << std::endl;  
  548.       inf.close();  
  549.       std::ofstream top_data_file((string("dump/") + file_id +  
  550.           string("_data.txt")).c_str(),  
  551.           std::ofstream::out | std::ofstream::binary);  
  552.       for (int c = 0; c < channels; ++c) {  
  553.         for (int h = 0; h < crop_size; ++h) {  
  554.           for (int w = 0; w < crop_size; ++w) {  
  555.             top_data_file.write(reinterpret_cast<char*>(  
  556.                 &top_data[((item_id * channels + c) * crop_size + h)  
  557.                           * crop_size + w]),  
  558.                 sizeof(Dtype));  
  559.           }  
  560.         }  
  561.       }  
  562.       top_data_file.close();  
  563.       #endif  
  564.   
  565.       item_id++;  
  566.     }  
  567.   }  
  568.   batch_timer.Stop();  
  569.   DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  
  570.   DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";  
  571.   DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";  
  572. }  
  573.   
  574. INSTANTIATE_CLASS(WindowDataLayer);  
  575. REGISTER_LAYER_CLASS(WindowData);  
  576.   
  577. }  // namespace caffe  
  578. #endif  // USE_OPENCV 
最后提醒一下该类并没有重载前传函数,而是调用了基类的前传,我把对应的代码贴出来便于你整体进行理解
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. template <typename Dtype>  
  2. void BasePrefetchingDataLayer<Dtype>::Forward_cpu(  
  3.     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {  
  4.     // 传递的时候是从full队列中弹出一个数据  
  5.   Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");  
  6.   // Reshape to loaded data.  
  7.   // 根据batch的形状改变数据形状  
  8.   top[0]->ReshapeLike(batch->data_);  
  9.   // Copy the data  
  10.   // 将batch数据复制到top[0]  
  11.   caffe_copy(batch->data_.count(), batch->data_.cpu_data(),  
  12.              top[0]->mutable_cpu_data());  
  13.   DLOG(INFO) << "Prefetch copied";  
  14.   if (this->output_labels_) {  
  15.       // 输出类标的话  
  16.     // Reshape to loaded labels.  
  17.     // 根据batch中类标的形状改变top[1]的形状  
  18.     top[1]->ReshapeLike(batch->label_);  
  19.     // Copy the labels.  
  20.     // 复制类标到top[1]  
  21.     caffe_copy(batch->label_.count(), batch->label_.cpu_data(),  
  22.         top[1]->mutable_cpu_data());  
  23.   }  
  24.   // 将该batch压入free队列  
  25.   prefetch_free_.push(batch);  
  26. }  

三、总结

首先理顺类与类之间的关系:
Layer类是所有神经网络层的基类,BaseDataLayer继承自该类,BasePrefetchingDataLayer继承自BaseDataLayer,DataLayer继承自BasePrefetchingDataLayer。
有了上述几个基础的类之后,其他的类都是从这几个类进行派生。

比如DummyDataLayer,HDF5Layer和HDF5OutputLayer都是直接继承自Layer。
MemoryDataLayer则是继承自BaseDataLayer

凡是涉及到直接读取数据文件的一般都是继承自BasePrefetchingDataLayer,这样可以有效地读数据进行预取。
比如:ImageDataLayer、WindowDataLayer
继承自BasePrefetchingDataLayer需要实现load_batch函数以供内部的线程进行调用,实现数据预取。
此外每一个网络层的类(因为所有的网络层都继承自Layer类嘛)都需要实现SetUp,这个是必须的。

这一次的量还真有点大。。。

注释的代码可以从以下位置下载:
http://download.youkuaiyun.com/detail/xizero00/9474806

参考:

[1]HDF5格式的介绍

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值