caffe添加HeatmapData层 (二)

经过加caffe添加PrecisionRecallLosslayer层(一) 的学习,再继续进行学习:

本文以https://github.com/tpfister/caffe-heatmap中所实现的data_heatma.cpp和data_heatmap.hpp为例介绍如何写自己的层。

=================================================================================================================================

1、老规矩,我们现在caffe.proto中添加参数及消息:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. message LayerParameter {  
  2.   optional string name = 1; // the layer name  
  3.   optional string type = 2; // the layer type  
  4.   repeated string bottom = 3; // the name of each bottom blob  
  5.   repeated string top = 4; // the name of each top blob  
  6.   
  7.   // The train / test phase for computation.  
  8.   optional Phase phase = 10;  
  9.   
  10.   // The amount of weight to assign each top blob in the objective.  
  11.   // Each layer assigns a default value, usually of either 0 or 1,  
  12.   // to each top blob.  
  13.   repeated float loss_weight = 5;  
  14.   
  15.   // Specifies training parameters (multipliers on global learning constants,  
  16.   // and the name and other settings used for weight sharing).  
  17.   repeated ParamSpec param = 6;  
  18.   
  19.   // The blobs containing the numeric parameters of the layer.  
  20.   repeated BlobProto blobs = 7;  
  21.   
  22.   // Specifies on which bottoms the backpropagation should be skipped.  
  23.   // The size must be either 0 or equal to the number of bottoms.  
  24.   repeated bool propagate_down = 11;  
  25.   
  26.   // Rules controlling whether and when a layer is included in the network,  
  27.   // based on the current NetState.  You may specify a non-zero number of rules  
  28.   // to include OR exclude, but not both.  If no include or exclude rules are  
  29.   // specified, the layer is always included.  If the current NetState meets  
  30.   // ANY (i.e., one or more) of the specified rules, the layer is  
  31.   // included/excluded.  
  32.   repeated NetStateRule include = 8;  
  33.   repeated NetStateRule exclude = 9;  
  34.   
  35.   // Parameters for data pre-processing.  
  36.   optional TransformationParameter transform_param = 100;  
  37.   
  38.   // Parameters shared by loss layers.  
  39.   optional LossParameter loss_param = 101;  
  40.   
  41.   
  42.   // Options to allow visualisation可视化层的参数,就这两货哈    
  43.   optional bool visualise = 200 [ default = false ];    
  44.   optional uint32 visualise_channel = 201 [ default = 0 ];    
  45.   // Layer type-specific parameters.  
  46.   //  
  47.   // Note: certain layers may have more than one computational engine  
  48.   // for their implementation. These layers include an Engine type and  
  49.   // engine parameter for selecting the implementation.  
  50.   // The default for the engine is set by the ENGINE switch at compile-time.  
  51.   optional AccuracyParameter accuracy_param = 102;  
  52.   optional ArgMaxParameter argmax_param = 103;  
  53.   optional BatchNormParameter batch_norm_param = 139;  
  54.   optional BiasParameter bias_param = 141;  
  55.   optional ConcatParameter concat_param = 104;  
  56.   optional ContrastiveLossParameter contrastive_loss_param = 105;  
  57.   optional ConvolutionParameter convolution_param = 106;  
  58.   optional CropParameter crop_param = 144;  
  59.   optional DataParameter data_param = 107;  
  60.   optional DropoutParameter dropout_param = 108;  
  61.   optional DummyDataParameter dummy_data_param = 109;  
  62.   optional EltwiseParameter eltwise_param = 110;  
  63.   optional ELUParameter elu_param = 140;  
  64.   optional EmbedParameter embed_param = 137;  
  65.   optional ExpParameter exp_param = 111;  
  66.   optional FlattenParameter flatten_param = 135;  
  67.   optional HeatmapDataParameter heatmap_data_param = 145;// 加入自己层的参数   
  68.   optional HDF5DataParameter hdf5_data_param = 112;  
  69.   optional HDF5OutputParameter hdf5_output_param = 113;  
  70.   optional HingeLossParameter hinge_loss_param = 114;  
  71.   optional ImageDataParameter image_data_param = 115;  
  72.   optional InfogainLossParameter infogain_loss_param = 116;  
  73.   optional InnerProductParameter inner_product_param = 117;  
  74.   optional InputParameter input_param = 143;  
  75.   optional LogParameter log_param = 134;  
  76.   optional LRNParameter lrn_param = 118;  
  77.   optional MemoryDataParameter memory_data_param = 119;  
  78.   optional MVNParameter mvn_param = 120;  
  79.   optional PoolingParameter pooling_param = 121;  
  80.   optional PowerParameter power_param = 122;  
  81.   optional PReLUParameter prelu_param = 131;  
  82.   optional PythonParameter python_param = 130;  
  83.   optional ReductionParameter reduction_param = 136;  
  84.   optional ReLUParameter relu_param = 123;  
  85.   optional ReshapeParameter reshape_param = 133;  
  86.   optional ScaleParameter scale_param = 142;  
  87.   optional SigmoidParameter sigmoid_param = 124;  
  88.   optional SoftmaxParameter softmax_param = 125;  
  89.   optional SPPParameter spp_param = 132;  
  90.   optional SliceParameter slice_param = 126;  
  91.   optional TanHParameter tanh_param = 127;  
  92.   optional ThresholdParameter threshold_param = 128;  
  93.   optional TileParameter tile_param = 138;  
  94.   optional WindowDataParameter window_data_param = 129;  
  95. }  
顺便在这个layer参数后面添加HeatmapDataParameter消息:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // VGG heatmap params 自己层的参数  
  2. message HeatmapDataParameter {  
  3.   optional bool segmentation = 1000 [default = false];   
  4.   optional uint32 multfact = 1001 [default = 1];  
  5.   optional uint32 num_channels = 1002 [default = 3];  
  6.   optional uint32 batchsize = 1003;  
  7.   optional string root_img_dir = 1004;  
  8.   optional bool random_crop = 1005;   // image augmentation type  
  9.   optional bool sample_per_cluster = 1006;   // image sampling type  
  10.   optional string labelinds = 1007 [default = ''];   // if specified, only use these regression variables  
  11.   optional string source = 1008;  
  12.   optional string meanfile = 1009;  
  13.   optional string crop_meanfile = 1010;  
  14.   optional uint32 cropsize = 1011 [default = 0];  
  15.   optional uint32 outsize = 1012 [default = 0];  
  16.   optional float scale = 1013 [ default = 1 ];  
  17.   optional uint32 label_width = 1014 [ default = 1 ];  
  18.   optional uint32 label_height = 1015 [ default = 1 ];  
  19.   optional bool dont_flip_first = 1016 [ default = true ];  
  20.   optional float angle_max = 1017 [ default = 0 ];   
  21.   optional bool flip_joint_labels = 1018 [ default = true ];  
  22. }  

对各个参数进行解释:
segmentation            是否分割,默认是否, 假设图像的分割模板在segs/目录
multfact                    将ground truth中的关节乘以这个multfact,就是图像中的位置,图像中的位置除以这个就是关节的位置,默认是1,也就是说关节的坐标与图像的坐标是一致大小的
num_channels           图像的channel数默认是3
batchsize                    batch大小
root_img_dir               存放图像文件的根目录
random_crop              是否需要随机crop图像(如果true则做随机crop,否则做中心crop)
sample_per_cluster     图像采样的类型(是否均匀地在clusters上采样)
labelinds                     类标索引(只使用回归变量才设置这个)
source                        存放打乱文件顺序之后的文件路径的txt文件
meanfile                    平均值文件路径
crop_meanfile           crop之后的平均值文件路径
cropsize                    crop的大小
outsize                      默认是0(就是crop出来之后的图像会缩放的因子,0表示不缩放)
scale                         默认是1,实际上就是一系列预处理(去均值、crop、缩放之后的像素值乘以该scale得到最终的图像的)
label_width               heatmap的宽
label_height               heatmap的高
dont_flip_first              不要对调第一个关节的位置,默认是true
angle_max              对图像进行旋转的最大角度,用于增强数据的,默认是0度
flip_joint_labels          默认是true(即水平翻转,将左右的关节对调)
还有可视化的测试参数设置:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. / NOTE  
  2. // Update the next available ID when you add a new LayerParameter field.  
  3. //  
  4. // LayerParameter next available layer-specific ID: 139 (last added: tile_param)  
  5. message LayerParameter {  
  6.   optional string name = 1; // the layer name  
  7.   optional string type = 2; // the layer type  
  8.   repeated string bottom = 3; // the name of each bottom blob  
  9.   repeated string top = 4; // the name of each top blob  
  10.   
  11.   // The train / test phase for computation.  
  12.   optional Phase phase = 10;  
  13.   
  14.   // The amount of weight to assign each top blob in the objective.  
  15.   // Each layer assigns a default value, usually of either 0 or 1,  
  16.   // to each top blob.  
  17.   repeated float loss_weight = 5;  
  18.   
  19.   // Specifies training parameters (multipliers on global learning constants,  
  20.   // and the name and other settings used for weight sharing).  
  21.   repeated ParamSpec param = 6;  
  22.   
  23.   // The blobs containing the numeric parameters of the layer.  
  24.   repeated BlobProto blobs = 7;  
  25.   
  26.   // Specifies on which bottoms the backpropagation should be skipped.  
  27.   // The size must be either 0 or equal to the number of bottoms.  
  28.   repeated bool propagate_down = 11;  
  29.   
  30.   // Rules controlling whether and when a layer is included in the network,  
  31.   // based on the current NetState.  You may specify a non-zero number of rules  
  32.   // to include OR exclude, but not both.  If no include or exclude rules are  
  33.   // specified, the layer is always included.  If the current NetState meets  
  34.   // ANY (i.e., one or more) of the specified rules, the layer is  
  35.   // included/excluded.  
  36.   repeated NetStateRule include = 8;  
  37.   repeated NetStateRule exclude = 9;  
  38.   
  39.   // Parameters for data pre-processing.  
  40.   optional TransformationParameter transform_param = 100;  
  41.   
  42.   // Parameters shared by loss layers.  
  43.   optional LossParameter loss_param = 101;  
  44.   
  45.   // Options to allow visualisation可视化层的参数,  
  46.   optional bool visualise = 200 [ default = false ];  
  47.   optional uint32 visualise_channel = 201 [ default = 0 ];  
还有一部分前面没有提到的部分就是V1LayerParameter,在这个里面添加两个我注释内容,这部分是为caffe的扩展提供了很好的帮助,但是作者在实现更新的upgrade_proto文件中,写的风格有点不符合前面风格了,全是if。。。。。

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // DEPRECATED: use LayerParameter.  
  2. message V1LayerParameter {  
  3.   repeated string bottom = 2;  
  4.   repeated string top = 3;  
  5.   optional string name = 4;  
  6.   repeated NetStateRule include = 32;  
  7.   repeated NetStateRule exclude = 33;  
  8.   enum LayerType {  
  9.     NONE = 0;  
  10.     ABSVAL = 35;  
  11.     ACCURACY = 1;  
  12.     ARGMAX = 30;  
  13.     BNLL = 2;  
  14.     CONCAT = 3;  
  15.     CONTRASTIVE_LOSS = 37;  
  16.     CONVOLUTION = 4;  
  17.     DATA = 5;  
  18.     DATA_HEATMAP=40;///////////自己添加  
  19.     DECONVOLUTION = 39;  
  20.     DROPOUT = 6;  
  21.     DUMMY_DATA = 32;  
  22.     EUCLIDEAN_LOSS = 7;  
  23.     ELTWISE = 25;  
  24.     EXP = 38;  
  25.     FLATTEN = 8;  
  26.     HDF5_DATA = 9;  
  27.     HDF5_OUTPUT = 10;  
  28.     HINGE_LOSS = 28;  
  29.     IM2COL = 11;  
  30.     IMAGE_DATA = 12;  
  31.     INFOGAIN_LOSS = 13;  
  32.     INNER_PRODUCT = 14;  
  33.     LRN = 15;  
  34.     MEMORY_DATA = 29;  
  35.     MULTINOMIAL_LOGISTIC_LOSS = 16;  
  36.     MVN = 34;  
  37.     POOLING = 17;  
  38.     POWER = 26;  
  39.     RELU = 18;  
  40.     SIGMOID = 19;  
  41.     SIGMOID_CROSS_ENTROPY_LOSS = 27;  
  42.     SILENCE = 36;  
  43.     SOFTMAX = 20;  
  44.     SOFTMAX_LOSS = 21;  
  45.     SPLIT = 22;  
  46.     SLICE = 33;  
  47.     TANH = 23;  
  48.     WINDOW_DATA = 24;  
  49.     THRESHOLD = 31;  
  50.   }  
  51.   optional LayerType type = 5;  
  52.   repeated BlobProto blobs = 6;  
  53.   repeated string param = 1001;  
  54.   repeated DimCheckMode blob_share_mode = 1002;  
  55.   enum DimCheckMode {  
  56.     STRICT = 0;  
  57.     PERMISSIVE = 1;  
  58.   }  
  59.   repeated float blobs_lr = 7;  
  60.   repeated float weight_decay = 8;  
  61.   repeated float loss_weight = 35;  
  62.   optional AccuracyParameter accuracy_param = 27;  
  63.   optional ArgMaxParameter argmax_param = 23;  
  64.   optional ConcatParameter concat_param = 9;  
  65.   optional ContrastiveLossParameter contrastive_loss_param = 40;  
  66.   optional ConvolutionParameter convolution_param = 10;  
  67.   optional DataParameter data_param = 11;  
  68.   optional HeatmapDataParameter heatmap_data_param = 43;// 加入自己层的参数  
  69.   optional DropoutParameter dropout_param = 12;  
  70.   optional DummyDataParameter dummy_data_param = 26;  
  71.   optional EltwiseParameter eltwise_param = 24;  
  72.   optional ExpParameter exp_param = 41;  
  73.   optional HDF5DataParameter hdf5_data_param = 13;  
  74.   optional HDF5OutputParameter hdf5_output_param = 14;  
  75.   optional HingeLossParameter hinge_loss_param = 29;  
  76.   optional ImageDataParameter image_data_param = 15;  
  77.   optional InfogainLossParameter infogain_loss_param = 16;  
  78.   optional InnerProductParameter inner_product_param = 17;  
  79.   optional LRNParameter lrn_param = 18;  
  80.   optional MemoryDataParameter memory_data_param = 22;  
  81.   optional MVNParameter mvn_param = 34;  
  82.   optional PoolingParameter pooling_param = 19;  
  83.   optional PowerParameter power_param = 21;  
  84.   optional ReLUParameter relu_param = 30;  
  85.   optional SigmoidParameter sigmoid_param = 38;  
  86.   optional SoftmaxParameter softmax_param = 39;  
  87.   optional SliceParameter slice_param = 31;  
  88.   optional TanHParameter tanh_param = 37;  
  89.   optional ThresholdParameter threshold_param = 25;  
  90.   optional WindowDataParameter window_data_param = 20;  
  91.   optional TransformationParameter transform_param = 36;  
  92.   optional LossParameter loss_param = 42;  
  93.   optional V0LayerParameter layer = 1;  
  94. }  
2、参数添加好之后就是heatmapdata层 声明和实现部分:

在介绍实现之前需要给出我们的训练数据的样子,看完参数,看一下训练的数据的格式理解一下:
下面给出一个样例:
train/FILE.jpg 123,144,165,123,66,22 372.296,720,1,480,0.53333 0
下面对样例做出解释
参数之间是以空格分隔
第一个参数是图像的路径:train/FILE.jpg
第二个参数是关节坐标:123,144,165,123,66,22
第三个参数是crop和scale的参数,分别为x_left,x_right,y_left,y_right,scaling_fact:372.296,720,1,480,0.53333
注意:第三个参数的crop的坐标其实上针对的是mean图像的,在mean图像中进行crop,然后放大到与原始图像一样大小,然后原始图像减去经过crop且放大之后的mean图像。这样在对原始图像进行crop的时候就不用担心了
第四个参数是是否cluster,是否均匀地在训练中采样图像: 0
crop在配置文件中的部分:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. transform_param {  
  2.    mirror: true  
  3.    crop_size: 227  
  4.    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"  
  5.  }  

上面是 caffeNet的 数据层的定义,看得出用了镜像和crop_size,还定义了 mean_file利用crop_size这种方式可以剪裁中心关注点和边角特征,mirror可以产生镜像,弥补小数据集的不足.git-issues里面有人问道这个crop_size和 mean_file的问题,一开始的时候是不能定义了crop,又用mean_file的,后来改进了.并且,这个mean_file和crop_size没什么大关系.只要你这个mean_file是根据你的训练集制作出来的就可以.应该是 先通过mean_file处理一遍数据集,再进行crop操作.用python接口去调用 python/caffe/ 下的 ilsvrc_2012_mean.npy这个文件,显示一下它的 shape,得到 3*256*256,说明,这个mean_file是根据 原数据集制作的,和crop_size 的 227 不一致,但是不影响训练.这样,就可以先根据 原数据集做出mean_file,再设计想要crop的尺寸,而不用担心 尺寸不一致的问题了 。

声明部分data_heatmap.hpp:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // Copyright 2014 Tomas Pfister  
  2.   
  3. #ifndef CAFFE_HEATMAP_HPP_  
  4. #define CAFFE_HEATMAP_HPP_  
  5.   
  6. #include "caffe/layer.hpp"  
  7. #include <vector>  
  8. #include <boost/timer/timer.hpp>  
  9. #include <opencv2/core/core.hpp>  
  10.   
  11. #include "caffe/common.hpp"  
  12. #include "caffe/data_transformer.hpp"  
  13. #include "caffe/filler.hpp"  
  14. #include "caffe/internal_thread.hpp"  
  15. #include "caffe/proto/caffe.pb.h"  
  16.   
  17. namespace caffe  
  18. {  
  19.   
  20. // 继承自PrefetchingDataLayer  
  21. template<typename Dtype>  
  22. class DataHeatmapLayer: public BasePrefetchingDataLayer<Dtype>  
  23. {  
  24.   
  25. public:  
  26.   
  27.     explicit DataHeatmapLayer(const LayerParameter& param)  
  28.         : BasePrefetchingDataLayer<Dtype>(param) {}  
  29.     virtual ~DataHeatmapLayer();  
  30.     virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  31.                                 const vector<Blob<Dtype>*>& top);  
  32.   
  33.     virtual inline const char* type() const { return "DataHeatmap"; }  
  34.   
  35.     virtual inline int ExactNumBottomBlobs() const { return 0; }  
  36.     virtual inline int ExactNumTopBlobs() const { return 2; }  
  37.   
  38.   
  39. protected:  
  40.     // 虚函数,就是实际读取一批数据到Batch中  
  41.     virtual void load_batch(Batch<Dtype>* batch);  
  42.     // 以下都是自己定义的要使用的函数,都在load_batch中被调用了  
  43.       
  44.     // Filename of current image  
  45.     inline void GetCurImg(string& img_name, std::vector<float>& img_class, std::vector<float>& crop_info, int& cur_class);  
  46.   
  47.     inline void AdvanceCurImg();  
  48.   
  49.     // Visualise point annotations  
  50.     inline void VisualiseAnnotations(cv::Mat img_annotation_vis, int numChannels, std::vector<float>& cur_label, int width);  
  51.   
  52.     // Random number generator  
  53.     inline float Uniform(const float min, const float max);  
  54.   
  55.     // Rotate image for augmentation  
  56.     inline cv::Mat RotateImage(cv::Mat src, float rotation_angle);  
  57.   
  58.     // Global vars  
  59.     shared_ptr<Caffe::RNG> rng_data_;  
  60.     shared_ptr<Caffe::RNG> prefetch_rng_;  
  61.     vector<std::pair<std::string, int> > lines_;  
  62.     int lines_id_;      
  63.     int datum_channels_;  
  64.     int datum_height_;  
  65.     int datum_width_;  
  66.     int datum_size_;  
  67.     int num_means_;  
  68.     int cur_class_;  
  69.     vector<int> labelinds_;  
  70.     // 图像均值的vector容器,其中存放的是每个视频的均值  
  71.     vector<cv::Mat> mean_img_;  
  72.     // 是否需要减去每个视频的均值  
  73.     bool sub_mean_;  // true if the mean should be subtracted  
  74.     // 是否对在每个类进行均匀采样  
  75.     bool sample_per_cluster_; // sample separately per cluster?  
  76.     string root_img_dir_;  
  77.     // 如果开启sample_per_cluster_则该vector中放的就是在该类别中随机采样的图像的索引  
  78.     // 举个例子,如果类别1的图像的个数是10个,那么就随机生成[0,9]之间的一个数作为采样的图像的索引  
  79.     // 从类别1中将该图像取出进行处理,就是sample_per_cluster_=true的含义  
  80.     // 这个数组实际上就是从类别到该类别的随机的一个图像编号的映射  
  81.     vector<float> cur_class_img_; // current class index  
  82.       
  83.     // 当前图像的索引,处理的时候用  
  84.     int cur_img_; // current image index  
  85.       
  86.     // 图像索引(图像的编号从0开始)到类别的映射  
  87.     vector<int> img_idx_map_; // current image indices for each class  
  88.   
  89.     // array of lists: one list of image names per class  
  90.     // 这么一长串这么吓人  
  91.     // 分解开来看,要访问的时候  
  92.     // 最外层首先要提供索引,因为第一个类型是vector  
  93.     // 第二层还是vector,所以还是需要索引才能访问  
  94.     //  第三层是pair,访问第一个可以用first,第二个用second  
  95.     // 如果第三层是first,则第四层直接就是string的值了  
  96.     // 如果第三层是second,则第四层就是pair,那么可以用first或者用second  
  97.     // 如果第四层是first,那么第五层就可以用索引访问  
  98.     // 如果第四层是second,那么第五层就直接是int值  
  99.     vector< vector< pair<string, pair<vector<float>, pair<vector<float>, int> > > > > img_list_;  
  100.   
  101.     // vector of (image, label) pairs  
  102.     // 外层是vector,所以用索引  
  103.     // 第二层是pair,所以用first或者second  
  104.     // 第三层是pair,所以继续用first或者second  
  105.     // 第四层是vector或者pair,如果第三层的是first,那么第四层就可以用索引访问  
  106.     // 如果第三层是second,那么第四层就直接得到值了  
  107.     vector< pair<string, pair<vector<float>, pair<vector<float>, int> > > > img_label_list_;      
  108. };  
  109.   
  110. }  
  111.   
  112. #endif /* CAFFE_HEATMAP_HPP_ */  

在介绍详细实现之前看看整体流程:
1)首先在SetUp该函数中读取,proto中的参数,从而获得一批数据的大小、heatmap的长和宽,对图像进行切割的大小,以及切割后的图像需要缩放到多大,还有就是是否需要对每个类别的图像进行采样、放置图像的根目录等信息。
此外还读取每个图像文件的路径、关节的坐标位置、crop的位置、是否进行采样。
如果在每个类上进行采样,还会生成一个数组,该数组对应的是图像的类别索引与图像的索引之间的映射。
此外还从文件中读取每个视频的mean,然后将所读取的mean放到vector容器中,便于在读取数据的时候从图像中取出mean。最后还会设置top的形状
2)在load_batch这个函数中就是真正地读取数据,并且对数据进行预处理,预处理主要是是否对图像进行分割,对平均值图像进行切割,并将切割的图像块放大到图像的大小,然后用图像减去该段视频切割并方法的平均值图像。减去均值大牛都说可以提升3个点,具体为什么我也不是很清楚。

实现.cpp文件部分:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. // Copyright 2015 Tomas Pfisterimg  
  2.   
  3. #include <fstream>  // NOLINT(readability/streams)  
  4. #include <iostream>  // NOLINT(readability/streams)  
  5. #include <string>  
  6. #include <utility>  
  7. #include <vector>  
  8.   
  9. #include "caffe/layers/data_layer.hpp"  
  10. #include "caffe/layer.hpp"  
  11. #include "caffe/util/io.hpp"  
  12. #include "caffe/util/math_functions.hpp"  
  13. #include "caffe/util/rng.hpp"  
  14.   
  15. #include <stdint.h>  
  16.   
  17. #include <cmath>  
  18.   
  19. #include <opencv2/core/core.hpp>  
  20. #include <opencv2/highgui/highgui.hpp>  
  21. #include <opencv2/highgui/highgui_c.h>  
  22. #include <opencv2/imgproc/imgproc.hpp>  
  23.   
  24. #include "caffe/layers/data_heatmap.hpp"  
  25. #include "caffe/util/benchmark.hpp"  
  26. #include <unistd.h>  
  27.   
  28.   
  29. namespace caffe  
  30. {  
  31.   
  32. template <typename Dtype>  
  33. DataHeatmapLayer<Dtype>::~DataHeatmapLayer<Dtype>() {  
  34.     this->StopInternalThread();  
  35. }  
  36.   
  37. // 读取参数文件中的一些数据什么的,然后初始化  
  38. template<typename Dtype>  
  39. void DataHeatmapLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  40.         const vector<Blob<Dtype>*>& top) {  
  41.     HeatmapDataParameter heatmap_data_param = this->layer_param_.heatmap_data_param();  
  42.   
  43.     // Shortcuts  
  44.     // 类标索引字符串(也就是关节类型?)  
  45.     const std::string labelindsStr = heatmap_data_param.labelinds();  
  46.     // batchsize  
  47.     const int batchsize = heatmap_data_param.batchsize();  
  48.     // heatmap的宽度  
  49.     const int label_width = heatmap_data_param.label_width();  
  50.     // heatmap的高度  
  51.     const int label_height = heatmap_data_param.label_height();  
  52.     // crop的大小  
  53.     const int size = heatmap_data_param.cropsize();  
  54.     // crop之后再次进行resize之后的大小  
  55.     const int outsize = heatmap_data_param.outsize();  
  56.     //  label的batchsize  
  57.     const int label_batchsize = batchsize;  
  58.     // 每个cluster都要进行采样  
  59.     sample_per_cluster_ = heatmap_data_param.sample_per_cluster();  
  60.     // 存放图像文件的根路径  
  61.     root_img_dir_ = heatmap_data_param.root_img_dir();  
  62.   
  63.   
  64.     // initialise rng seed  
  65.     const unsigned int rng_seed = caffe_rng_rand();  
  66.     srand(rng_seed);  
  67.   
  68.     // get label inds to be used for training  
  69.     // 载入类标索引  
  70.     std::istringstream labelss(labelindsStr);  
  71.     LOG(INFO) << "using joint inds:";  
  72.     while (labelss)  
  73.     {  
  74.         std::string s;  
  75.         if (!std::getline(labelss, s, ',')) break;  
  76.         labelinds_.push_back(atof(s.c_str()));  
  77.         LOG(INFO) << atof(s.c_str());  
  78.     }  
  79.   
  80.     // load GT  
  81.     // shuffle file  
  82.     // 载入ground truth文件,即关节坐标文件  
  83.     std::string gt_path = heatmap_data_param.source();  
  84.     LOG(INFO) << "Loading annotation from " << gt_path;  
  85.   
  86.     std::ifstream infile(gt_path.c_str());  
  87.     string img_name, labels, cropInfos, clusterClassStr;  
  88.     if (!sample_per_cluster_)// 是否根据你指定的类别随机取图像  
  89.     {  
  90.         // sequential sampling  
  91.         // 文件名,关节位置坐标,crop的位置,是否均匀地在clusters上采样  
  92.         while (infile >> img_name >> labels >> cropInfos >> clusterClassStr)  
  93.         {  
  94.             // read comma-separated list of regression labels  
  95.             // 读取关节位置坐标  
  96.             std::vector <float> label;  
  97.             std::istringstream ss(labels);  
  98.             int labelCounter = 1;  
  99.             while (ss)  
  100.             {  
  101.                 // 读取一个数字  
  102.                 std::string s;  
  103.                 if (!std::getline(ss, s, ',')) break;  
  104.                 // 是否是类标索引中的值  
  105.                 // 如果labelinds为空或者为不为空在其中找到  
  106.                 if (labelinds_.empty() || std::find(labelinds_.begin(), labelinds_.end(), labelCounter) != labelinds_.end())  
  107.                 {  
  108.                     label.push_back(atof(s.c_str()));  
  109.                 }  
  110.                 labelCounter++;// 个数  
  111.             }  
  112.   
  113.             // read cropping info  
  114.             // 读取crop的信息  
  115.             std::vector <float> cropInfo;  
  116.             std::istringstream ss2(cropInfos);  
  117.             while (ss2)  
  118.             {  
  119.                 std::string s;  
  120.                 if (!std::getline(ss2, s, ',')) break;  
  121.                 cropInfo.push_back(atof(s.c_str()));  
  122.             }  
  123.   
  124.             int clusterClass = atoi(clusterClassStr.c_str());  
  125.             // 图像路径,关节坐标,crop信息、类别  
  126.             img_label_list_.push_back(std::make_pair(img_name, std::make_pair(label, std::make_pair(cropInfo, clusterClass))));  
  127.         }  
  128.   
  129.         // initialise image counter to 0  
  130.         cur_img_ = 0;  
  131.     }  
  132.     else  
  133.     {  
  134.         // uniform sampling w.r.t. classes  
  135.         // 根据类别均匀采样  
  136.         // 也就是说图像有若干个类别,然后每个类别下有若干个图像  
  137.         // 随机取其中一个图像  
  138.         while (infile >> img_name >> labels >> cropInfos >> clusterClassStr)  
  139.         {  
  140.             // 获得你指定的类别  
  141.             // 如果你制定为0  
  142.             int clusterClass = atoi(clusterClassStr.c_str());  
  143.         // 那么  
  144.             if (clusterClass + 1 > img_list_.size())  
  145.             {  
  146.                 // expand the array  
  147.                 img_list_.resize(clusterClass + 1);  
  148.             }  
  149.   
  150.             // read comma-separated list of regression labels  
  151.             // 读取关节的坐标位置到label这个vector  
  152.             std::vector <float> label;  
  153.             std::istringstream ss(labels);  
  154.             int labelCounter = 1;  
  155.             while (ss)  
  156.             {  
  157.                 std::string s;  
  158.                 if (!std::getline(ss, s, ',')) break;  
  159.                 if (labelinds_.empty() || std::find(labelinds_.begin(), labelinds_.end(), labelCounter) != labelinds_.end())  
  160.                 {  
  161.                     label.push_back(atof(s.c_str()));  
  162.                 }  
  163.                 labelCounter++;  
  164.             }  
  165.   
  166.             // read cropping info  
  167.             // 读取crop信息到cropinfo这个vector  
  168.             std::vector <float> cropInfo;  
  169.             std::istringstream ss2(cropInfos);  
  170.             while (ss2)  
  171.             {  
  172.                 std::string s;  
  173.                 if (!std::getline(ss2, s, ',')) break;  
  174.                 cropInfo.push_back(atof(s.c_str()));  
  175.             }  
  176.         // 每个clusterClass下都是一个vector,用于装各种图像  
  177.             img_list_[clusterClass].push_back(std::make_pair(img_name, std::make_pair(label, std::make_pair(cropInfo, clusterClass))));  
  178.         }// while结尾  
  179.         
  180.       // 图像的类别个数  
  181.         const int num_classes = img_list_.size();  
  182.   
  183.         // init image sampling  
  184.         cur_class_ = 0;  
  185.         // cur_class_img_中存放的是某个类别中随机取到的图像的索引值  
  186.         cur_class_img_.resize(num_classes);  
  187.   
  188.         // init image indices for each class  
  189.         for (int idx_class = 0; idx_class < num_classes; idx_class++)  
  190.         {  
  191.             // 是否需要根据类别随机取某个类别中的一个图像  
  192.             if (sample_per_cluster_)  
  193.             {  
  194.                 // img_list_[idx_class].size()是该idx_class这个类中图像的个数  
  195.                 // 产生从0-该类中图像个数之间的一个随机数  
  196.                 cur_class_img_[idx_class] = rand() % img_list_[idx_class].size();  
  197.                 // 图像类别个数  
  198.                 LOG(INFO) << idx_class << " size: " << img_list_[idx_class].size();  
  199.             }  
  200.             else  
  201.             {  
  202.                 cur_class_img_[idx_class] = 0;  
  203.             }  
  204.         }  
  205.     }  
  206.   
  207.     if (!heatmap_data_param.has_meanfile())// 是否有meanfile  
  208.     {  
  209.         // if no mean, assume input images are RGB (3 channels)  
  210.         this->datum_channels_ = 3;  
  211.         sub_mean_ = false;  
  212.     } else {  
  213.         // Implementation of per-video mean removal  
  214.      // 下面整个一段代码是将每个视频mean文件读取到Mat结构  
  215.        
  216.        
  217.         sub_mean_ = true;  
  218.         // 从参数文件中获取mean文件的路径  
  219.         string mean_path = heatmap_data_param.meanfile();  
  220.   
  221.         LOG(INFO) << "Loading mean file from " << mean_path;  
  222.         BlobProto blob_proto, blob_proto2;  
  223.         Blob<Dtype> data_mean;  
  224.         // 读取到blob,然后blob数据转换到data_mean  
  225.         ReadProtoFromBinaryFile(mean_path.c_str(), &blob_proto);  
  226.         data_mean.FromProto(blob_proto);  
  227.         LOG(INFO) << "mean file loaded";  
  228.   
  229.         // read config  
  230.         this->datum_channels_ = data_mean.channels();  
  231.         // mean值的数目,有多少个视频,就有多少个mean啊  
  232.         num_means_ = data_mean.num();  
  233.         LOG(INFO) << "num_means: " << num_means_;  
  234.   
  235.         // copy the per-video mean images to an array of OpenCV structures  
  236.         const Dtype* mean_buf = data_mean.cpu_data();  
  237.   
  238.         // extract means from beginning of proto file  
  239.         // mean文件中的图像的高度  
  240.         const int mean_height = data_mean.height();  
  241.         // mean文件中图像的宽度  
  242.         const int mean_width = data_mean.width();  
  243.         // 高度数组  
  244.         int mean_heights[num_means_];  
  245.         // 宽度数组  
  246.         int mean_widths[num_means_];  
  247.   
  248.         // offset in memory to mean images  
  249.         //  在mean图像中的偏移量  
  250.         const int meanOffset = 2 * (num_means_);  
  251.         for (int n = 0; n < num_means_; n++)  
  252.         {  
  253.             mean_heights[n] = mean_buf[2 * n];  
  254.             mean_widths[n] = mean_buf[2 * n + 1];  
  255.         }  
  256.   
  257.         // save means as OpenCV-compatible files  
  258.         // 将从protobin文件读取的blob存放到Mat中  
  259.         // 获得mean_image容器,这其中包含了若干个视频的mean值  
  260.         // 下面是分配内存  
  261.         for (int n = 0; n < num_means_; n++)  
  262.         {  
  263.             cv::Mat mean_img_tmp_;  
  264.             mean_img_tmp_.create(mean_heights[n], mean_widths[n], CV_32FC3);  
  265.             mean_img_.push_back(mean_img_tmp_);  
  266.             LOG(INFO) << "per-video mean file array created: " << n << ": " << mean_heights[n] << "x" << mean_widths[n] << " (" << size << ")";  
  267.         }  
  268.   
  269.         LOG(INFO) << "mean: " << mean_height << "x" << mean_width << " (" << size << ")";  
  270.     // 下面是实际的赋值  
  271.         for (int n = 0; n < num_means_; n++)  
  272.         {  
  273.             for (int i = 0; i < mean_heights[n]; i++)  
  274.             {  
  275.                 for (int j = 0; j < mean_widths[n]; j++)  
  276.                 {  
  277.                     for (int c = 0; c < this->datum_channels_; c++)  
  278.                     {  
  279.                         mean_img_[n].at<cv::Vec3f>(i, j)[c] = mean_buf[meanOffset + ((n * this->datum_channels_ + c) * mean_height + i) * mean_width + j]; //[c * mean_height * mean_width + i * mean_width + j];  
  280.                     }  
  281.                 }  
  282.             }  
  283.         }  
  284.   
  285.         LOG(INFO) << "mean file converted to OpenCV structures";  
  286.     }  
  287.   
  288.   
  289.     // init data  
  290.     // 改变数据形状  
  291.     this->transformed_data_.Reshape(batchsize, this->datum_channels_, outsize, outsize);  
  292.     top[0]->Reshape(batchsize, this->datum_channels_, outsize, outsize);  
  293.     for (int i = 0; i < this->PREFETCH_COUNT; ++i)  
  294.         this->prefetch_[i].data_.Reshape(batchsize, this->datum_channels_, outsize, outsize);  
  295.     this->datum_size_ = this->datum_channels_ * outsize * outsize;  
  296.   
  297.     // init label  
  298.     int label_num_channels;  
  299.     if (!sample_per_cluster_)// 如果不按照类别进行均匀采样  
  300.         label_num_channels = img_label_list_[0].second.first.size();// 获取关节坐标的数字的个数(注意是数字的个数,并不是坐标的个数,要除以2才能是坐标的个数哈)  
  301.     else// 如果按照类别均匀采样  
  302.         label_num_channels = img_list_[0][0].second.first.size();// 第0类的第0个图像的关节数字的个数  
  303.     label_num_channels /= 2;// 获得关节个数  
  304.       
  305.     // 将输出设置为对应的大小  
  306.     // top[0]是batchsize个图像数据  
  307.     // top[1]是batchsize个heatmap(一个heatmap有关节个数个channel)  
  308.     // label的batchsize,关节个数作为channel,关节的heatmap的高、关节heatmap的宽度  
  309.     top[1]->Reshape(label_batchsize, label_num_channels, label_height, label_width);  
  310.     for (int i = 0; i < this->PREFETCH_COUNT; ++i)  
  311.         this->prefetch_[i].label_.Reshape(label_batchsize, label_num_channels, label_height, label_width);  
  312.   
  313.     LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width();  
  314.     LOG(INFO) << "output label size: " << top[1]->num() << "," << top[1]->channels() << "," << top[1]->height() << "," << top[1]->width();  
  315.     LOG(INFO) << "number of label channels: " << label_num_channels;  
  316.     LOG(INFO) << "datum channels: " << this->datum_channels_;  
  317.   
  318. }  
  319. // 根据初始化之后的信息读取实际的文件数据,以及关节的位置,并将关节位置转换为类标  
  320. template<typename Dtype>  
  321. void DataHeatmapLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  322.   
  323.     CPUTimer batch_timer;  
  324.     batch_timer.Start();  
  325.     CHECK(batch->data_.count());  
  326.     HeatmapDataParameter heatmap_data_param = this->layer_param_.heatmap_data_param();  
  327.   
  328.     // Pointers to blobs' float data  
  329.     // 指向数据和类标的指针  
  330.     Dtype* top_data = batch->data_.mutable_cpu_data();  
  331.     Dtype* top_label = batch->label_.mutable_cpu_data();  
  332.   
  333.     cv::Mat img, img_res, img_annotation_vis, img_mean_vis, img_vis, img_res_vis, mean_img_this, seg, segTmp;  
  334.   
  335.     // Shortcuts to params  
  336.     // 是否显示读取的图像啥的,用户调试  
  337.     const bool visualise = this->layer_param_.visualise();  
  338.     // 是否对图像进行缩放  
  339.     const Dtype scale = heatmap_data_param.scale();  
  340.     // 每次读多少个图像  
  341.     const int batchsize = heatmap_data_param.batchsize();  
  342.     // heatmap的高度  
  343.     const int label_height = heatmap_data_param.label_height();  
  344.     // heatmap的宽度  
  345.     const int label_width = heatmap_data_param.label_width();  
  346.     // 需要旋转多少度  
  347.     const float angle_max = heatmap_data_param.angle_max();  
  348.     // 是否不要翻转第一个图  
  349.     const bool dont_flip_first = heatmap_data_param.dont_flip_first();  
  350.     // 是否翻转关节的坐标  
  351.     const bool flip_joint_labels = heatmap_data_param.flip_joint_labels();  
  352.     // 关节的坐标数值需要乘以这个multfact  
  353.     const int multfact = heatmap_data_param.multfact();  
  354.     // 图像是否需要分割  
  355.     const bool segmentation = heatmap_data_param.segmentation();  
  356.     // 切割的图像的块的带下  
  357.     const int size = heatmap_data_param.cropsize();  
  358.     // 切割之后的图像块需要缩放到outsize大小  
  359.     const int outsize = heatmap_data_param.outsize();  
  360.     const int num_aug = 1;  
  361.     // 缩放因子  
  362.     const float resizeFact = (float)outsize / (float)size;  
  363.     // 是不是需要随机切图像块  
  364.     const bool random_crop = heatmap_data_param.random_crop();  
  365.   
  366.     // Shortcuts to global vars  
  367.     const bool sub_mean = this->sub_mean_;  
  368.     const int channels = this->datum_channels_;  
  369.   
  370.     // What coordinates should we flip when mirroring images?  
  371.     // For pose estimation with joints assumes i=0,1 are for head, and i=2,3 left wrist, i=4,5 right wrist etc  
  372.     //     in which case dont_flip_first should be set to true.  
  373.     int flip_start_ind;  
  374.     if (dont_flip_first) flip_start_ind = 2;  
  375.     else flip_start_ind = 0;  
  376.   
  377.     if (visualise)  
  378.     {  
  379.         cv::namedWindow("original image", cv::WINDOW_AUTOSIZE);  
  380.         cv::namedWindow("cropped image", cv::WINDOW_AUTOSIZE);  
  381.         cv::namedWindow("interim resize image", cv::WINDOW_AUTOSIZE);  
  382.         cv::namedWindow("resulting image", cv::WINDOW_AUTOSIZE);  
  383.     }  
  384.   
  385.     // collect "batchsize" images  
  386.     std::vector<float> cur_label, cur_cropinfo;  
  387.     std::string img_name;  
  388.     int cur_class;  
  389.   
  390.     // loop over non-augmented images  
  391.     // 获取batchsize个图像,然后进行预处理  
  392.     for (int idx_img = 0; idx_img < batchsize; idx_img++)  
  393.     {  
  394.         // get image name and class  
  395.         // 获取文件名、label、cropinfo、类标  
  396.         this->GetCurImg(img_name, cur_label, cur_cropinfo, cur_class);  
  397.   
  398.         // get number of channels for image label  
  399.         // 获取关节的数值的个数(并不是关节个数哈,关节个数乘以2就是该数)  
  400.         int label_num_channels = cur_label.size();  
  401.        
  402.      // 将根路径和文件名称拼接并读取数据到img  
  403.         std::string img_path = this->root_img_dir_ + img_name;  
  404.         DLOG(INFO) << "img: " << img_path;  
  405.         img = cv::imread(img_path, CV_LOAD_IMAGE_COLOR);  
  406.   
  407.         // show image  
  408.         // 显示读取的图像  
  409.         if (visualise)  
  410.         {  
  411.             img_annotation_vis = img.clone();  
  412.             this->VisualiseAnnotations(img_annotation_vis, label_num_channels, cur_label, multfact);  
  413.             cv::imshow("original image", img_annotation_vis);  
  414.         }  
  415.   
  416.         // use if seg exists  
  417.         // 是否对图像分割  
  418.         // 分割的模板存放在segs目录  
  419.         // 读取分割模板到seg  
  420.         if (segmentation)  
  421.         {  
  422.             std::string seg_path = this->root_img_dir_ + "segs/" + img_name;  
  423.             std::ifstream ifile(seg_path.c_str());  
  424.   
  425.             // Skip this file if segmentation doesn't exist  
  426.             if (!ifile.good())  
  427.             {  
  428.                 LOG(INFO) << "file " << seg_path << " does not exist!";  
  429.                 idx_img--;  
  430.                 this->AdvanceCurImg();  
  431.                 continue;  
  432.             }  
  433.             ifile.close();  
  434.             seg = cv::imread(seg_path, CV_LOAD_IMAGE_GRAYSCALE);  
  435.         }  
  436.   
  437.         int width = img.cols;  
  438.         int height = img.rows;  
  439.         // size是crop的大小  
  440.         // 如果crop的大小太大x_border会变成负数,下面会进行pad  
  441.         int x_border = width - size;  
  442.         int y_border = height - size;  
  443.        
  444.        
  445.      // 将读取的图像转换为RGB  
  446.         // convert from BGR to RGB  
  447.         cv::cvtColor(img, img, CV_BGR2RGB);  
  448.   
  449.         // to float  
  450.         // 转换数据类型到float  
  451.         img.convertTo(img, CV_32FC3);  
  452.   
  453.         if (segmentation)  
  454.         {  
  455.             segTmp = cv::Mat::zeros(img.rows, img.cols, CV_32FC3);  
  456.             int threshold = 40;// 阈值  
  457.             // 获取分割模板  
  458.             seg = (seg > threshold);  
  459.             // 对图像进行分割  
  460.             segTmp.copyTo(img, seg);  
  461.         }  
  462.   
  463.         if (visualise)  
  464.             img_vis = img.clone();  
  465.   
  466.         // subtract per-video mean if used  
  467.         // 减去每个视频的均值  
  468.         int meanInd = 0;  
  469.         if (sub_mean)  
  470.         {  
  471.             // 由此可以看到每个视频的命名规则,就是目录的名字嘛,而且还是数字  
  472.             // 比如0,1,2,3,4  
  473.             // 假设路径是images/1/xxx.jpg  
  474.             // 那么获取的平均值索引就是1,然后再到mean_img_中得到对应的均值图像  
  475.             std::string delimiter = "/";  
  476.             std::string img_name_subdirImg = img_name.substr(img_name.find(delimiter) + 1, img_name.length());  
  477.             std::string meanIndStr = img_name_subdirImg.substr(0, img_name_subdirImg.find(delimiter));  
  478.             meanInd = atoi(meanIndStr.c_str()) - 1;  
  479.   
  480.             // subtract the cropped mean  
  481.             mean_img_this = this->mean_img_[meanInd].clone();  
  482.   
  483.             DLOG(INFO) << "Image size: " << width << "x" << height;  
  484.             DLOG(INFO) << "Crop info: " << cur_cropinfo[0] << " " <<  cur_cropinfo[1] << " " << cur_cropinfo[2] << " " << cur_cropinfo[3] << " " << cur_cropinfo[4];  
  485.             DLOG(INFO) << "Crop info after: " << cur_cropinfo[0] << " " <<  cur_cropinfo[1] << " " << cur_cropinfo[2] << " " << cur_cropinfo[3] << " " << cur_cropinfo[4];  
  486.             DLOG(INFO) << "Mean image size: " << mean_img_this.cols << "x" << mean_img_this.rows;  
  487.             DLOG(INFO) << "Cropping: " << cur_cropinfo[0] - 1 << " " << cur_cropinfo[2] - 1 << " " << width << " " << height;  
  488.   
  489.             // crop and resize mean image  
  490.             // 对mean文件进行切割并且调整其大小为图像大小  
  491.             // cur_cropinfo中的数据分别为x_left,x_right,y_left,y_right  
  492.             // 而Rect则是x,y,w,h,所以需要转换  
  493.             cv::Rect crop(cur_cropinfo[0] - 1, cur_cropinfo[2] - 1, cur_cropinfo[1] - cur_cropinfo[0], cur_cropinfo[3] - cur_cropinfo[2]);  
  494.             mean_img_this = mean_img_this(crop);// 这样就crop了  
  495.             cv::resize(mean_img_this, mean_img_this, img.size());  
  496.   
  497.             DLOG(INFO) << "Cropped mean image.";  
  498.           
  499.         // 原图像减去crop之后并放大成与原图像一样大小的平均值图像  
  500.         // 这是什么原理?????  
  501.             img -= mean_img_this;  
  502.   
  503.             DLOG(INFO) << "Subtracted mean image.";  
  504.   
  505.             if (visualise)  
  506.             {  
  507.                 img_vis -= mean_img_this;  
  508.                 img_mean_vis = mean_img_this.clone() / 255;  
  509.                 cv::cvtColor(img_mean_vis, img_mean_vis, CV_RGB2BGR);  
  510.                 cv::imshow("mean image", img_mean_vis);  
  511.             }  
  512.         }  
  513.   
  514.         // pad images that aren't wide enough  
  515.         // 如果crop大小大于图像大小则padding,图像得右侧padding  
  516.         if (x_border < 0)  
  517.         {  
  518.             DLOG(INFO) << "padding " << img_path << " -- not wide enough.";  
  519.             // 函数原型如下  
  520.           // void copyMakeBorder( const Mat& src, Mat& dst,  
  521.           // int top, int bottom, int left, int right,  
  522.           // int borderType, const Scalar& value=Scalar() );  
  523.             cv::copyMakeBorder(img, img, 0, 0, 0, -x_border, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));  
  524.             width = img.cols;  
  525.             x_border = width - size;  
  526.   
  527.             // add border offset to joints  
  528.             // 因为pad过图像的右侧了所以需要调整关节的x坐标  
  529.             for (int i = 0; i < label_num_channels; i += 2)// 注意这里是i+=2哦!  
  530.                 cur_label[i] = cur_label[i] + x_border;  
  531.   
  532.             DLOG(INFO) << "new width: " << width << "   x_border: " << x_border;  
  533.             if (visualise)// 显示经过padding的图像  
  534.             {  
  535.                 img_vis = img.clone();  
  536.                 cv::copyMakeBorder(img_vis, img_vis, 0, 0, 0, -x_border, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));  
  537.             }  
  538.         }  
  539.   
  540.         DLOG(INFO) << "Entering jitter loop.";  
  541.   
  542.         // loop over the jittered versions  
  543.         // 将关节位置转换为heatmap  
  544.         for (int idx_aug = 0; idx_aug < num_aug; idx_aug++)  
  545.         {  
  546.             // augmented image index in the resulting batch  
  547.             const int idx_img_aug = idx_img * num_aug + idx_aug;  
  548.               
  549.             // 关节坐标,首先将从文件读取的关节坐标赋值给它  
  550.             // 接下来因为要对图像进行crop,crop之后的图像还要resize  
  551.             // 所以对应的关节坐标也要进行crop和缩放,经过这个处理的  
  552.             // 关节位置就存放在了 cur_label_aug  
  553.             std::vector<float> cur_label_aug = cur_label;  
  554.           
  555.         // 是否随机crop  
  556.             if (random_crop)  
  557.             {  
  558.                 // random sampling  
  559.                 DLOG(INFO) << "random crop sampling";  
  560.   
  561.                 // horizontal flip  
  562.                 // 随机旋转是否需要水平翻转  
  563.                 if (rand() % 2)  
  564.                 {  
  565.                     // flip,0表示水平  
  566.                     // 水平翻转  
  567.                     cv::flip(img, img, 1);  
  568.   
  569.                     if (visualise)  
  570.                         cv::flip(img_vis, img_vis, 1);  
  571.   
  572.                     // "flip" annotation coordinates  
  573.                     // 将图像的坐标也翻转了  
  574.                     for (int i = 0; i < label_num_channels; i += 2)  
  575.                         // width 是原始图像的宽度,原始图像的宽度除以multfact就是关节的图像宽度,关节图像的宽度减去关节的x坐标就是翻转过来的x坐标  
  576.                         cur_label_aug[i] = (float)width / (float)multfact - cur_label_aug[i];  
  577.   
  578.                     // "flip" annotation joint numbers  
  579.                     // assumes i=0,1 are for head, and i=2,3 left wrist, i=4,5 right wrist etc  
  580.                     // where coordinates are (x,y)  
  581.                     // 将索引位置也翻转了。。。  
  582.                     if (flip_joint_labels)  
  583.                     {  
  584.                         float tmp_x, tmp_y;  
  585.                         for (int i = flip_start_ind; i < label_num_channels; i += 4)  
  586.                         {  
  587.                             CHECK_LT(i + 3, label_num_channels);  
  588.                             tmp_x = cur_label_aug[i];  
  589.                             tmp_y = cur_label_aug[i + 1];  
  590.                             cur_label_aug[i] = cur_label_aug[i + 2];  
  591.                             cur_label_aug[i + 1] = cur_label_aug[i + 3];  
  592.                             cur_label_aug[i + 2] = tmp_x;  
  593.                             cur_label_aug[i + 3] = tmp_y;  
  594.                         }  
  595.                     }  
  596.                 }  
  597.   
  598.                 // left-top coordinates of the crop [0;x_border] x [0;y_border]  
  599.                 // 生成左上的坐标,用于切割图像  
  600.                 int x0 = 0, y0 = 0;  
  601.                 x0 = rand() % (x_border + 1);  
  602.                 y0 = rand() % (y_border + 1);  
  603.   
  604.                 // do crop  
  605.                 cv::Rect crop(x0, y0, size, size);  
  606.   
  607.                 // NOTE: no full copy performed, so the original image buffer is affected by the transformations below  
  608.                 // img_crop与img公用一个内存,所以在img_crop中所作的更改对img也会有  
  609.                 cv::Mat img_crop(img, crop);  
  610.   
  611.                 // "crop" annotations  
  612.                 // 万一关节的位置在crop的大小之外怎么办???疑问  
  613.                 for (int i = 0; i < label_num_channels; i += 2)  
  614.                 {  
  615.                     cur_label_aug[i] -= (float)x0 / (float) multfact;  
  616.                     cur_label_aug[i + 1] -= (float)y0 / (float) multfact;  
  617.                 }  
  618.   
  619.                 // show image  
  620.                 if (visualise)  
  621.                 {  
  622.                     DLOG(INFO) << "cropped image";  
  623.                     cv::Mat img_vis_crop(img_vis, crop);  
  624.                     cv::Mat img_res_vis = img_vis_crop / 255;  
  625.                     cv::cvtColor(img_res_vis, img_res_vis, CV_RGB2BGR);  
  626.                     this->VisualiseAnnotations(img_res_vis, label_num_channels, cur_label_aug, multfact);  
  627.                     cv::imshow("cropped image", img_res_vis);  
  628.                 }  
  629.   
  630.                 // rotations  
  631.                 // 旋转图像到一个均匀分布的角度  
  632.                 float angle = Uniform(-angle_max, angle_max);  
  633.                 cv::Mat M = this->RotateImage(img_crop, angle);  
  634.   
  635.                 // also flip & rotate labels  
  636.                 // 遍历所有关节坐标  
  637.                 for (int i = 0; i < label_num_channels; i += 2)  
  638.                 {  
  639.                     // convert to image space  
  640.                     // 将关节坐标转换到图像中的坐标  
  641.                     float x = cur_label_aug[i] * (float) multfact;  
  642.                     float y = cur_label_aug[i + 1] * (float) multfact;  
  643.   
  644.                     // rotate  
  645.                     // ?为啥  
  646.                     cur_label_aug[i] = M.at<double>(0, 0) * x + M.at<double>(0, 1) * y + M.at<double>(0, 2);  
  647.                     cur_label_aug[i + 1] = M.at<double>(1, 0) * x + M.at<double>(1, 1) * y + M.at<double>(1, 2);  
  648.   
  649.                     // convert back to joint space  
  650.                     // 转换回关节空间  
  651.                     cur_label_aug[i] /= (float) multfact;  
  652.                     cur_label_aug[i + 1] /= (float) multfact;  
  653.                 }  
  654.   
  655.                 img_res = img_crop;  
  656.             } else {// 中心crop(就是图像的中心crop啊)  
  657.                 // determinsitic sampling  
  658.                 DLOG(INFO) << "deterministic crop sampling (centre)";  
  659.   
  660.                 // centre crop  
  661.                 const int y0 = y_border / 2;  
  662.                 const int x0 = x_border / 2;  
  663.   
  664.                 DLOG(INFO) << "cropping image from " << x0 << "x" << y0;  
  665.   
  666.                 // do crop  
  667.                 cv::Rect crop(x0, y0, size, size);  
  668.                 cv::Mat img_crop(img, crop);  
  669.   
  670.                 DLOG(INFO) << "cropping annotations.";  
  671.   
  672.                 // "crop" annotations  
  673.                 // 长见识了,关节的annotation也是需要crop的  
  674.                 for (int i = 0; i < label_num_channels; i += 2)  
  675.                 {  
  676.                     // 除以multfact转换到关节坐标,然后再减去  
  677.                     // 不过我有疑问,万一crop之后的图像没有关节咋办  
  678.                     // 这样真的好吗  
  679.                     cur_label_aug[i] -= (float)x0 / (float) multfact;  
  680.                     cur_label_aug[i + 1] -= (float)y0 / (float) multfact;  
  681.                 }  
  682.   
  683.                 if (visualise)  
  684.                 {  
  685.                     cv::Mat img_vis_crop(img_vis, crop);  
  686.                     cv::Mat img_res_vis = img_vis_crop.clone() / 255;  
  687.                     cv::cvtColor(img_res_vis, img_res_vis, CV_RGB2BGR);  
  688.                     this->VisualiseAnnotations(img_res_vis, label_num_channels, cur_label_aug, multfact);  
  689.                     cv::imshow("cropped image", img_res_vis);  
  690.                 }  
  691.                 img_res = img_crop;  
  692.             }// end of else  
  693.   
  694.             // show image  
  695.             if (visualise)  
  696.             {  
  697.                 cv::Mat img_res_vis = img_res / 255;  
  698.                 cv::cvtColor(img_res_vis, img_res_vis, CV_RGB2BGR);  
  699.                 this->VisualiseAnnotations(img_res_vis, label_num_channels, cur_label_aug, multfact);  
  700.                 cv::imshow("interim resize image", img_res_vis);  
  701.             }  
  702.   
  703.             DLOG(INFO) << "Resizing output image.";  
  704.   
  705.             // resize to output image size  
  706.             // 将crop之后的图像弄到给定的大小  
  707.             cv::Size s(outsize, outsize);  
  708.             cv::resize(img_res, img_res, s);  
  709.   
  710.             // "resize" annotations  
  711.             // resize 标注的关节  
  712.             // 将图像进行缩放了,那么关节的坐标也要缩放  
  713.             for (int i = 0; i < label_num_channels; i++)  
  714.                 cur_label_aug[i] *= resizeFact;  
  715.   
  716.             // show image  
  717.             if (visualise)  
  718.             {  
  719.                 cv::Mat img_res_vis = img_res / 255;  
  720.                 cv::cvtColor(img_res_vis, img_res_vis, CV_RGB2BGR);  
  721.                 this->VisualiseAnnotations(img_res_vis, label_num_channels, cur_label_aug, multfact);  
  722.                 cv::imshow("resulting image", img_res_vis);  
  723.             }  
  724.   
  725.             // show image  
  726.             if (visualise && sub_mean)  
  727.             {  
  728.                 cv::Mat img_res_meansub_vis = img_res / 255;  
  729.                 cv::cvtColor(img_res_meansub_vis, img_res_meansub_vis, CV_RGB2BGR);  
  730.                 cv::imshow("mean-removed image", img_res_meansub_vis);  
  731.             }  
  732.   
  733.             // multiply by scale  
  734.             // 去均值、crop、缩放之后的像素值乘以该scale得到最终的图像的  
  735.             if (scale != 1.0)  
  736.                 img_res *= scale;  
  737.   
  738.             // resulting image dims  
  739.             const int channel_size = outsize * outsize;  
  740.             const int img_size = channel_size * channels;  
  741.   
  742.             // store image data  
  743.             // 将处理好的图像存放到top_data  
  744.             DLOG(INFO) << "storing image";  
  745.             for (int c = 0; c < channels; c++)  
  746.             {  
  747.                 for (int i = 0; i < outsize; i++)  
  748.                 {  
  749.                     for (int j = 0; j < outsize; j++)  
  750.                     {  
  751.                         top_data[idx_img_aug * img_size + c * channel_size + i * outsize + j] = img_res.at<cv::Vec3f>(i, j)[c];  
  752.                     }  
  753.                 }  
  754.             }  
  755.   
  756.             // store label as gaussian  
  757.             // 将关节转换为高斯图像  
  758.             DLOG(INFO) << "storing labels";  
  759.             const int label_channel_size = label_height * label_width;  
  760.             const int label_img_size = label_channel_size * label_num_channels / 2;  
  761.             cv::Mat dataMatrix = cv::Mat::zeros(label_height, label_width, CV_32FC1);  
  762.             float label_resize_fact = (float) label_height / (float) outsize;  
  763.             float sigma = 1.5;  
  764.   
  765.             for (int idx_ch = 0; idx_ch < label_num_channels / 2; idx_ch++)  
  766.             {  
  767.                 // 将经过缩放的关节转换到图像空间的坐标(也就是乘以multfact),再将缩小之后的图像空间坐标转换到缩小之前的图像空间坐标(也就是乘以label_resize_fact)  
  768.                 float x = label_resize_fact * cur_label_aug[2 * idx_ch] * multfact;  
  769.                 float y = label_resize_fact * cur_label_aug[2 * idx_ch + 1] * multfact;  
  770.                 for (int i = 0; i < label_height; i++)  
  771.                 {  
  772.                     for (int j = 0; j < label_width; j++)  
  773.                     {  
  774.                         // 计算索引  
  775.                         int label_idx = idx_img_aug * label_img_size + idx_ch * label_channel_size + i * label_height + j;  
  776.                         float gaussian = ( 1 / ( sigma * sqrt(2 * M_PI) ) ) * exp( -0.5 * ( pow(i - y, 2.0) + pow(j - x, 2.0) ) * pow(1 / sigma, 2.0) );  
  777.                         gaussian = 4 * gaussian;  
  778.                           
  779.                         // 存入到top_label  
  780.                         top_label[label_idx] = gaussian;  
  781.   
  782.                         if (idx_ch == 0)  
  783.                             dataMatrix.at<float>((int)j, (int)i) = gaussian;  
  784.                     }  
  785.                 }  
  786.             }  
  787.   
  788.         } // jittered versions loop  
  789.   
  790.         DLOG(INFO) << "next image";  
  791.   
  792.         // move to the next image  
  793.         // Advance是进行  
  794.         // Cur是表示当前  
  795.         // 那么就是移动到下一个图像  
  796.         this->AdvanceCurImg();  
  797.   
  798.         if (visualise)  
  799.             cv::waitKey(0);  
  800.   
  801.   
  802.     } // original image loop  
  803.   
  804.     batch_timer.Stop();  
  805.     DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  
  806. }  
  807.   
  808.   
  809. // 获取当前图像的路径、类标、crop信息、类别  
  810. template<typename Dtype>  
  811. void DataHeatmapLayer<Dtype>::GetCurImg(string& img_name, std::vector<float>& img_label, std::vector<float>& crop_info, int& img_class)  
  812. {  
  813.   
  814.     if (!sample_per_cluster_)  
  815.     {  
  816.         img_name = img_label_list_[cur_img_].first;  
  817.         img_label = img_label_list_[cur_img_].second.first;  
  818.         crop_info = img_label_list_[cur_img_].second.second.first;  
  819.         img_class = img_label_list_[cur_img_].second.second.second;  
  820.     }  
  821.     else  
  822.     {  
  823.         img_class = cur_class_;  
  824.         // 看见没,这里用到了cur_class_img_,这个在SetUp中生成的随机数作为该类别的图像索引,该随机数的范围在[0,该类别图像的个数-1]之间。  
  825.         img_name = img_list_[img_class][cur_class_img_[img_class]].first;  
  826.         img_label = img_list_[img_class][cur_class_img_[img_class]].second.first;  
  827.         crop_info = img_list_[img_class][cur_class_img_[img_class]].second.second.first;  
  828.     }  
  829. }  
  830.   
  831. // 实际上就是移动索引  
  832. template<typename Dtype>  
  833. void DataHeatmapLayer<Dtype>::AdvanceCurImg()  
  834. {  
  835.     if (!sample_per_cluster_)  
  836.     {  
  837.         if (cur_img_ < img_label_list_.size() - 1)  
  838.             cur_img_++;  
  839.         else  
  840.             cur_img_ = 0;  
  841.     }  
  842.     else  
  843.     {  
  844.         const int num_classes = img_list_.size();  
  845.   
  846.         if (cur_class_img_[cur_class_] < img_list_[cur_class_].size() - 1)  
  847.             cur_class_img_[cur_class_]++;  
  848.         else  
  849.             cur_class_img_[cur_class_] = 0;  
  850.   
  851.         // move to the next class  
  852.         if (cur_class_ < num_classes - 1)  
  853.             cur_class_++;  
  854.         else  
  855.             cur_class_ = 0;  
  856.     }  
  857.   
  858. }  
  859.   
  860. // 可视化关节点  
  861. template<typename Dtype>  
  862. void DataHeatmapLayer<Dtype>::VisualiseAnnotations(cv::Mat img_annotation_vis, int label_num_channels, std::vector<float>& img_class, int multfact)  
  863. {  
  864.     // colors  
  865.     const static cv::Scalar colors[] = {  
  866.         CV_RGB(0, 0, 255),  
  867.         CV_RGB(0, 128, 255),  
  868.         CV_RGB(0, 255, 255),  
  869.         CV_RGB(0, 255, 0),  
  870.         CV_RGB(255, 128, 0),  
  871.         CV_RGB(255, 255, 0),  
  872.         CV_RGB(255, 0, 0),  
  873.         CV_RGB(255, 0, 255)  
  874.     };  
  875.   
  876.     int numCoordinates = int(label_num_channels / 2);  
  877.   
  878.     // points  
  879.     // 将关节点放到centers数组中  
  880.     cv::Point centers[numCoordinates];  
  881.     for (int i = 0; i < label_num_channels; i += 2)  
  882.     {  
  883.         int coordInd = int(i / 2);  
  884.         centers[coordInd] = cv::Point(img_class[i] * multfact, img_class[i + 1] * multfact);  
  885.         // 给关节画圈圈  
  886.         cv::circle(img_annotation_vis, centers[coordInd], 1, colors[coordInd], 3);  
  887.     }  
  888.   
  889.     // connecting lines  
  890.     // 1,3,5是一条膀子  
  891.     // 2,4,6是一条膀子  
  892.     cv::line(img_annotation_vis, centers[1], centers[3], CV_RGB(0, 255, 0), 1, CV_AA);  
  893.     cv::line(img_annotation_vis, centers[2], centers[4], CV_RGB(255, 255, 0), 1, CV_AA);  
  894.     cv::line(img_annotation_vis, centers[3], centers[5], CV_RGB(0, 0, 255), 1, CV_AA);  
  895.     cv::line(img_annotation_vis, centers[4], centers[6], CV_RGB(0, 255, 255), 1, CV_AA);  
  896. }  
  897.   
  898. // [min,max]的均匀分布  
  899. template <typename Dtype>  
  900. float DataHeatmapLayer<Dtype>::Uniform(const float min, const float max) {  
  901.     float random = ((float) rand()) / (float) RAND_MAX;  
  902.     float diff = max - min;  
  903.     float r = random * diff;  
  904.     return min + r;  
  905. }  
  906.   
  907. // 旋转图像  
  908. template <typename Dtype>  
  909. cv::Mat DataHeatmapLayer<Dtype>::RotateImage(cv::Mat src, float rotation_angle)  
  910. {  
  911.     cv::Mat rot_mat(2, 3, CV_32FC1);  
  912.     cv::Point center = cv::Point(src.cols / 2, src.rows / 2);  
  913.     double scale = 1;  
  914.   
  915.     // Get the rotation matrix with the specifications above  
  916.     rot_mat = cv::getRotationMatrix2D(center, rotation_angle, scale);  
  917.   
  918.     // Rotate the warped image  
  919.     cv::warpAffine(src, src, rot_mat, src.size());  
  920.   
  921.     return rot_mat;  
  922. }  
  923.   
  924. INSTANTIATE_CLASS(DataHeatmapLayer);  
  925. REGISTER_LAYER_CLASS(DataHeatmap);  
  926.   
  927. // namespace caffe  
3、最后看看在配置文件中怎么使用该层?
[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. layer {  
  2.   name: "data"  
  3.   type: "DataHeatmap" // 层的类型是DataHeatmap  
  4.   top: "data"  
  5.   top: "label"  
  6.   visualise: false    // 是否可视化  
  7.   include: { phase: TRAIN }     
  8.   heatmap_data_param {  
  9.     source: "/data/tp/flic/train_shuffle.txt"  
  10.     root_img_dir: "/mnt/ramdisk/tp/flic/"     
  11.     batchsize: 14  
  12.     cropsize: 248  
  13.     outsize: 256  
  14.     sample_per_cluster: false  
  15.     random_crop: true  
  16.     label_width: 64  
  17.     label_height: 64  
  18.     segmentation: false  
  19.     flip_joint_labels: true  
  20.     dont_flip_first: true  
  21.     angle_max: 40     
  22.     multfact: 1  # set to 282 if using preprocessed data from website  
  23.   }  
  24. }  
先浏览一下作者原始的配置文件的代码:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. name: "HeatmapFusionNet"  
  2. layer {  
  3. name: "data"  
  4. type: "DataHeatmap"  
  5. top: "data"  
  6. top: "label"  
  7. visualise: false  
  8. include: { phase: TRAIN }  
  9. heatmap_data_param {  
  10. source: "/data/tp/flic/train_shuffle.txt"  
  11. root_img_dir: "/mnt/ramdisk/tp/flic/"  
  12. batchsize: 14  
  13. cropsize: 248  
  14. outsize: 256  
  15. sample_per_cluster: false  
  16. random_crop: true  
  17. label_width: 64  
  18. label_height: 64  
  19. segmentation: false  
  20. flip_joint_labels: true  
  21. dont_flip_first: true  
  22. angle_max: 40  
  23. multfact: 1 # set to 282 if using preprocessed data from website  
  24. }  
  25. }  
  26. layer {  
  27. name: "data"  
  28. type: "DataHeatmap"  
  29. top: "data"  
  30. top: "label"  
  31. visualise: false  
  32. include: { phase: TEST }  
  33. heatmap_data_param {  
  34. source: "/data/tp/flic/test_shuffle.txt"  
  35. root_img_dir: "/mnt/ramdisk/tp/flic/"  
  36. batchsize: 1  
  37. cropsize: 248  
  38. outsize: 256  
  39. sample_per_cluster: false  
  40. random_crop: false  
  41. label_width: 64  
  42. label_height: 64  
  43. segmentation: false  
  44. dont_flip_first: true  
  45. angle_max: 0  
  46. multfact: 1 # set to 282 if using preprocessed data from website  
  47. }  
  48. }  
  49. #########################################################  
  50. layer {  
  51. name: "conv1"  
  52. type: "Convolution"  
  53. bottom: "data"  
  54. top: "conv1"  
  55. param {  
  56. lr_mult: 1  
  57. decay_mult: 1  
  58. }  
  59. param {  
  60. lr_mult: 2  
  61. decay_mult: 0  
  62. }  
  63. convolution_param {  
  64. num_output: 128  
  65. kernel_size: 5  
  66. stride: 1  
  67. pad: 2  
  68. weight_filler {  
  69. type: "gaussian"  
  70. std: 0.01  
  71. }  
  72. bias_filler {  
  73. type: "constant"  
  74. value: 0  
  75. }  
  76. }  
  77. }  
  78. layer {  
  79. name: "relu1"  
  80. type: "ReLU"  
  81. bottom: "conv1"  
  82. top: "conv1"  
  83. }  
  84. layer {  
  85. name: "pool1"  
  86. type: "Pooling"  
  87. bottom: "conv1"  
  88. top: "pool1"  
  89. pooling_param {  
  90. pool: MAX  
  91. kernel_size: 2  
  92. stride: 2  
  93. }  
  94. }  
  95. #########################################################  
  96. layer {  
  97. name: "conv2"  
  98. type: "Convolution"  
  99. bottom: "pool1"  
  100. top: "conv2"  
  101. param {  
  102. lr_mult: 1  
  103. decay_mult: 1  
  104. }  
  105. param {  
  106. lr_mult: 2  
  107. decay_mult: 0  
  108. }  
  109. convolution_param {  
  110. num_output: 128  
  111. kernel_size: 5  
  112. pad: 2  
  113. weight_filler {  
  114. type: "gaussian"  
  115. std: 0.01  
  116. }  
  117. bias_filler {  
  118. type: "constant"  
  119. }  
  120. }  
  121. }  
  122. layer {  
  123. name: "relu2"  
  124. type: "ReLU"  
  125. bottom: "conv2"  
  126. top: "conv2"  
  127. }  
  128. layer {  
  129. name: "pool2"  
  130. type: "Pooling"  
  131. bottom: "conv2"  
  132. top: "pool2"  
  133. pooling_param {  
  134. pool: MAX  
  135. kernel_size: 2  
  136. stride: 2  
  137. }  
  138. }  
  139. #########################################################  
  140. layer {  
  141. name: "conv3"  
  142. type: "Convolution"  
  143. bottom: "pool2"  
  144. top: "conv3"  
  145. param {  
  146. lr_mult: 1  
  147. decay_mult: 1  
  148. }  
  149. param {  
  150. lr_mult: 2  
  151. decay_mult: 0  
  152. }  
  153. convolution_param {  
  154. num_output: 128  
  155. kernel_size: 5  
  156. pad: 2  
  157. weight_filler {  
  158. type: "gaussian"  
  159. std: 0.01  
  160. }  
  161. bias_filler {  
  162. type: "constant"  
  163. value: 0  
  164. }  
  165. }  
  166. }  
  167. layer {  
  168. name: "relu3"  
  169. type: "ReLU"  
  170. bottom: "conv3"  
  171. top: "conv3"  
  172. }  
  173. #########################################################  
  174. layer {  
  175. name: "conv4"  
  176. type: "Convolution"  
  177. bottom: "conv3"  
  178. top: "conv4"  
  179. param {  
  180. lr_mult: 1  
  181. decay_mult: 1  
  182. }  
  183. param {  
  184. lr_mult: 2  
  185. decay_mult: 0  
  186. }  
  187. convolution_param {  
  188. num_output: 256  
  189. kernel_size: 9  
  190. pad: 4  
  191. weight_filler {  
  192. type: "gaussian"  
  193. std: 0.01  
  194. }  
  195. bias_filler {  
  196. type: "constant"  
  197. }  
  198. }  
  199. }  
  200. layer {  
  201. name: "relu4"  
  202. type: "ReLU"  
  203. bottom: "conv4"  
  204. top: "conv4"  
  205. }  
  206. #########################################################  
  207. layer {  
  208. name: "conv5"  
  209. type: "Convolution"  
  210. bottom: "conv4"  
  211. top: "conv5"  
  212. param {  
  213. lr_mult: 1  
  214. decay_mult: 1  
  215. }  
  216. param {  
  217. lr_mult: 2  
  218. decay_mult: 0  
  219. }  
  220. convolution_param {  
  221. num_output: 512  
  222. kernel_size: 9  
  223. pad: 4  
  224. weight_filler {  
  225. type: "gaussian"  
  226. std: 0.01  
  227. }  
  228. bias_filler {  
  229. type: "constant"  
  230. }  
  231. }  
  232. }  
  233. layer {  
  234. name: "relu5"  
  235. type: "ReLU"  
  236. bottom: "conv5"  
  237. top: "conv5"  
  238. }  
  239. #########################################################  
  240. layer {  
  241. name: "conv6"  
  242. type: "Convolution"  
  243. bottom: "conv5"  
  244. top: "conv6"  
  245. param {  
  246. lr_mult: 1  
  247. decay_mult: 1  
  248. }  
  249. param {  
  250. lr_mult: 2  
  251. decay_mult: 0  
  252. }  
  253. convolution_param {  
  254. num_output: 256  
  255. # pad: 2  
  256. kernel_size: 1  
  257. weight_filler {  
  258. type: "gaussian"  
  259. std: 0.01  
  260. }  
  261. bias_filler {  
  262. type: "constant"  
  263. }  
  264. }  
  265. }  
  266. layer {  
  267. name: "relu6"  
  268. type: "ReLU"  
  269. bottom: "conv6"  
  270. top: "conv6"  
  271. }  
  272. #########################################################  
  273. layer {  
  274. name: "conv7"  
  275. type: "Convolution"  
  276. bottom: "conv6"  
  277. top: "conv7"  
  278. param {  
  279. lr_mult: 1  
  280. decay_mult: 1  
  281. }  
  282. param {  
  283. lr_mult: 2  
  284. decay_mult: 0  
  285. }  
  286. convolution_param {  
  287. num_output: 256  
  288. kernel_size: 1  
  289. weight_filler {  
  290. type: "gaussian"  
  291. std: 0.01  
  292. }  
  293. bias_filler {  
  294. type: "constant"  
  295. }  
  296. }  
  297. }  
  298. layer {  
  299. name: "relu7"  
  300. type: "ReLU"  
  301. bottom: "conv7"  
  302. top: "conv7"  
  303. }  
  304. #########################################################  
  305. layer {  
  306. name: "conv8"  
  307. type: "Convolution"  
  308. bottom: "conv7"  
  309. top: "conv8"  
  310. param {  
  311. lr_mult: 1  
  312. decay_mult: 1  
  313. }  
  314. param {  
  315. lr_mult: 2  
  316. decay_mult: 0  
  317. }  
  318. convolution_param {  
  319. num_output: 7  
  320. kernel_size: 1  
  321. weight_filler {  
  322. type: "gaussian"  
  323. std: 0.01  
  324. }  
  325. bias_filler {  
  326. type: "constant"  
  327. }  
  328. }  
  329. }  
  330. layer {  
  331. name: "relu8"  
  332. type: "ReLU"  
  333. bottom: "conv8"  
  334. top: "conv8"  
  335. }  
  336. #########################################################  
  337. layer {  
  338. name: "loss_heatmap"  
  339. type: "EuclideanLossHeatmap"  
  340. bottom: "conv8"  
  341. bottom: "label"  
  342. bottom: "data"  
  343. top: "loss_heatmap"  
  344. visualise: false  
  345. loss_weight: 1  
  346. }  
  347. #########################################################  
  348. layer {  
  349. name: "concat_fusion"  
  350. type: "Concat"  
  351. bottom: "conv3"  
  352. bottom: "conv7"  
  353. top: "concat_fusion"  
  354. concat_param {  
  355. concat_dim: 1  
  356. }  
  357. }  
  358. #########################################################  
  359. layer {  
  360. name: "conv1_fusion"  
  361. type: "Convolution"  
  362. bottom: "concat_fusion"  
  363. top: "conv1_fusion"  
  364. param {  
  365. lr_mult: 1  
  366. decay_mult: 1  
  367. }  
  368. param {  
  369. lr_mult: 2  
  370. decay_mult: 0  
  371. }  
  372. convolution_param {  
  373. num_output: 64  
  374. kernel_size: 7  
  375. stride: 1  
  376. pad: 3  
  377. weight_filler {  
  378. type: "gaussian"  
  379. std: 0.01  
  380. }  
  381. bias_filler {  
  382. type: "constant"  
  383. }  
  384. }  
  385. }  
  386. layer {  
  387. name: "relu1_fusion"  
  388. type: "ReLU"  
  389. bottom: "conv1_fusion"  
  390. top: "conv1_fusion"  
  391. }  
  392. #########################################################  
  393. layer {  
  394. name: "conv2_fusion"  
  395. type: "Convolution"  
  396. bottom: "conv1_fusion"  
  397. top: "conv2_fusion"  
  398. param {  
  399. lr_mult: 1  
  400. decay_mult: 1  
  401. }  
  402. param {  
  403. lr_mult: 2  
  404. decay_mult: 0  
  405. }  
  406. convolution_param {  
  407. num_output: 64  
  408. kernel_size: 13  
  409. stride: 1  
  410. pad: 6  
  411. weight_filler {  
  412. type: "gaussian"  
  413. std: 0.01  
  414. }  
  415. bias_filler {  
  416. type: "constant"  
  417. }  
  418. }  
  419. }  
  420. layer {  
  421. name: "relu2_fusion"  
  422. type: "ReLU"  
  423. bottom: "conv2_fusion"  
  424. top: "conv2_fusion"  
  425. }  
  426. #########################################################  
  427. layer {  
  428. name: "conv3_fusion"  
  429. type: "Convolution"  
  430. bottom: "conv2_fusion"  
  431. top: "conv3_fusion"  
  432. param {  
  433. lr_mult: 1  
  434. decay_mult: 1  
  435. }  
  436. param {  
  437. lr_mult: 2  
  438. decay_mult: 0  
  439. }  
  440. convolution_param {  
  441. num_output: 128  
  442. kernel_size: 13  
  443. stride: 1  
  444. pad: 6  
  445. weight_filler {  
  446. type: "gaussian"  
  447. std: 0.01  
  448. }  
  449. bias_filler {  
  450. type: "constant"  
  451. }  
  452. }  
  453. }  
  454. layer {  
  455. name: "relu3_fusion"  
  456. type: "ReLU"  
  457. bottom: "conv3_fusion"  
  458. top: "conv3_fusion"  
  459. }  
  460. #########################################################  
  461. layer {  
  462. name: "conv4_fusion"  
  463. type: "Convolution"  
  464. bottom: "conv3_fusion"  
  465. top: "conv4_fusion"  
  466. param {  
  467. lr_mult: 1  
  468. decay_mult: 1  
  469. }  
  470. param {  
  471. lr_mult: 2  
  472. decay_mult: 0  
  473. }  
  474. convolution_param {  
  475. num_output: 256  
  476. kernel_size: 1  
  477. stride: 1  
  478. pad: 0  
  479. weight_filler {  
  480. type: "gaussian"  
  481. std: 0.01  
  482. }  
  483. bias_filler {  
  484. type: "constant"  
  485. }  
  486. }  
  487. }  
  488. layer {  
  489. name: "relu4_fusion"  
  490. type: "ReLU"  
  491. bottom: "conv4_fusion"  
  492. top: "conv4_fusion"  
  493. }  
  494. #########################################################  
  495. layer {  
  496. name: "conv5_fusion"  
  497. type: "Convolution"  
  498. bottom: "conv4_fusion"  
  499. top: "conv5_fusion"  
  500. param {  
  501. lr_mult: 1  
  502. decay_mult: 1  
  503. }  
  504. param {  
  505. lr_mult: 2  
  506. decay_mult: 0  
  507. }  
  508. convolution_param {  
  509. num_output: 7  
  510. kernel_size: 1  
  511. stride: 1  
  512. pad: 0  
  513. weight_filler {  
  514. type: "gaussian"  
  515. std: 0.01  
  516. }  
  517. bias_filler {  
  518. type: "constant"  
  519. }  
  520. }  
  521. }  
  522. #########################################################  
  523. layer {  
  524. name: "loss_fusion"  
  525. type: "EuclideanLossHeatmap"  
  526. bottom: "conv5_fusion"  
  527. bottom: "label"  
  528. bottom: "data"  
  529. top: "loss_fusion"  
  530. visualise: false  
  531. loss_weight: 3  
  532. }  
下一步按照作者需要的层数逐步添加完成。

论文中还有其他的东西,我这一层就没法调用,只是编译完成了,最后说一下实现完之后要注册该层,还有caffe版本是最新的,如果中途出现一个关于opencv错误,修改makefile文件,解决如下图:


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值