Caffe code walkthrough: data_transform

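This article walks through the data preprocessing pipeline in the Caffe framework: the Datum definition used by the data layer, how data is written and read, the implementation of the preprocessing class DataTransformer, and how transformations such as scaling, cropping, and mirroring are applied.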


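// DataTransformer operates on Blobs, so it helps to look at the Blob parameters first. The relevant proto definitions from the newer version of Caffe are reproduced here.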
/* 
// Specifies the shape (dimensions) of a Blob. 
message BlobShape { 
  repeated int64 dim = 1 [packed = true]; 
} 
 
message BlobProto { 
  optional BlobShape shape = 7; 
  repeated float data = 5 [packed = true]; 
  repeated float diff = 6 [packed = true]; 
  repeated double double_data = 8 [packed = true]; 
  repeated double double_diff = 9 [packed = true]; 
 
  // 4D dimensions -- deprecated.  Use "shape" instead. 
  optional int32 num = 1 [default = 0]; 
  optional int32 channels = 2 [default = 0]; 
  optional int32 height = 3 [default = 0]; 
  optional int32 width = 4 [default = 0]; 
} 
*/  
// Caffe message definition of TransformationParameter
/* 
// Message that stores parameters used to apply transformation 
// to the data layer's data 
message TransformationParameter { 
  // For data pre-processing, we can do simple scaling and subtracting the 
  // data mean, if provided. Note that the mean subtraction is always carried 
  // out before scaling. 
  optional float scale = 1 [default = 1]; 
  // Specify if we want to randomly mirror data. 
  optional bool mirror = 2 [default = false]; 
  // Specify if we would like to randomly crop an image. 
  optional uint32 crop_size = 3 [default = 0]; 
  // mean_file and mean_value cannot be specified at the same time 
  optional string mean_file = 4; 
  // if specified can be repeated once (would subtract it from all the channels)
  // or can be repeated the same number of times as channels 
  // (would subtract them from the corresponding channel) 
  repeated float mean_value = 5; 
  // Force the decoded image to have 3 color channels. 
  optional bool force_color = 6 [default = false]; 
  // Force the decoded image to have 1 color channels. 
  optional bool force_gray = 7 [default = false]; 
} 
*/  
#ifdef USE_OPENCV  
#include <opencv2/core/core.hpp>  
#endif  // USE_OPENCV  
  
#include <string>  
#include <vector>  
  
#include "caffe/data_transformer.hpp"  
#include "caffe/util/io.hpp"  
#include "caffe/util/math_functions.hpp"  
#include "caffe/util/rng.hpp"  
  
namespace caffe {  
// Constructor
template<typename Dtype>  
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,  
    Phase phase)  
    : param_(param), phase_(phase) {  
  // check if we want to use mean_file  
  // Check whether a mean file was specified
  if (param_.has_mean_file()) {  
    CHECK_EQ(param_.mean_value_size(), 0) <<  
      "Cannot specify mean_file and mean_value at the same time";  
    // Path to the mean file
    const string& mean_file = param.mean_file();  
    if (Caffe::root_solver()) {  
      LOG(INFO) << "Loading mean file from: " << mean_file;  
    }  
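    // BlobProto is the protobuf message class generated from caffe.proto; it is the serialized form of a Blob.
    BlobProto blob_proto;
    // ReadProtoFromBinaryFileOrDie is a Caffe helper declared in caffe/util/io.hpp; it parses the binary
    // protobuf file into blob_proto and aborts on failure.
    ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
    // Load the mean into data_mean_ through Blob's member function FromProto.
    data_mean_.FromProto(blob_proto);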
  }  
  // check if we want to use mean_value  
  if (param_.mean_value_size() > 0) {  
    CHECK(param_.has_mean_file() == false) <<  
      "Cannot specify mean_file and mean_value at the same time";  
    for (int c = 0; c < param_.mean_value_size(); ++c) {  
      mean_values_.push_back(param_.mean_value(c));  // append param_.mean_value(c) to the end of mean_values_
    }  
  }  
}  
  
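/* A description of the data layer's Datum first.
The Datum data structure: Caffe does not put vectors and matrices into the database directly. It wraps the data in a Datum class defined in caffe.proto, and the database stores Datum objects, each serialized to a string. The definition of Datum is excerpted below: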
message Datum { 
  optional int32 channels = 1; 
  optional int32 height = 2; 
  optional int32 width = 3; 
  // the actual image data, in bytes 
  optional bytes data = 4; 
  optional int32 label = 5; 
  // Optionally, the datum could also hold float data. 
  repeated float float_data = 6; 
  // If true data contains an encoded image that need to be decoded 
  optional bool encoded = 7 [default = false]; 
} 
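A Datum has three dimensions, channels, height, and width, and can be viewed as a Blob without the num dimension. Data can be stored in two places: data (a bytes field) and float_data, holding integer and floating-point data respectively. Image data is usually integer-valued and goes in data; feature vectors are usually floating-point and go in float_data. label holds the class label, an integer. encoded indicates whether the data needs to be decoded (it may hold encoded data such as JPEG or PNG). The Datum structure packages data and label together and accommodates both integer and floating-point data. Once compiled by Protobuf, it offers efficient access from both Python and C++. Protobuf also provides serialization and deserialization for it. What is stored in LMDB is the string produced by serializing a Datum.
Caffe's LMDB-related code falls into three categories: generating datasets, reading datasets, and generating feature vectors. The first two are analyzed below.
Generating a dataset:
The code that generates a dataset lives in examples and is provided along with the dataset, for example MNIST.
First, create the variables needed to access LMDB: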
MDB_env *mdb_env; 
MDB_dbi mdb_dbi; 
MDB_val mdb_key, mdb_data; 
MDB_txn *mdb_txn; 
... 
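mdb_env is the handle for the whole database environment, mdb_dbi is the handle for one database within the environment, and mdb_key and mdb_data hold the key and the value of a record written to the database. mdb_txn is the handle for a database transaction; "txn" is short for "transaction".
Next, create the database environment, then create and open the database: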
if (db_backend == "lmdb") {  // lmdb 
  LOG(INFO) << "Opening lmdb " << db_path; 
  CHECK_EQ(mkdir(db_path, 0744), 0) 
      << "mkdir " << db_path << "failed"; 
  CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed"; 
  CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB 
      << "mdb_env_set_mapsize failed"; 
  CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS) 
      << "mdb_env_open failed"; 
  CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) 
      << "mdb_txn_begin failed"; 
  CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS) 
      << "mdb_open failed. Does the lmdb already exist? "; 
} else { 
  LOG(FATAL) << "Unknown db backend " << db_backend; 
} 
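mkdir(db_path, 0744) creates the directory for the database; if the directory already exists, the program reports an error and exits. In other words, the program never overwrites an existing database, and a database that is no longer wanted has to be deleted by hand. Note that a single LMDB environment can contain several databases, distinguished by name. The second argument of mdb_open() is the database name (char *); when the environment holds only one database, it can be NULL.
Finally, create a Datum object for each image, write the data into it, serialize it to a string, and put the string into the database: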
Datum datum; 
datum.set_channels(1); 
datum.set_height(rows); 
datum.set_width(cols); 
for (int item_id = 0; item_id < num_items; ++item_id) { 
  image_file.read(pixels, rows * cols); 
  label_file.read(&label, 1); 
  datum.set_data(pixels, rows*cols); 
  datum.set_label(label); 
  snprintf(key_cstr, kMaxKeyLength, "%08d", item_id); 
  datum.SerializeToString(&value); 
  string keystr(key_cstr); 
 
  // Put in db 
  if (db_backend == "lmdb") {  // lmdb 
    mdb_data.mv_size = value.size(); 
    mdb_data.mv_data = reinterpret_cast<void*>(&value[0]); 
    mdb_key.mv_size = keystr.size(); 
    mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]); 
    CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS) 
        << "mdb_put failed"; 
  } else { 
    LOG(FATAL) << "Unknown db backend " << db_backend; 
  } 
 
  if (++count % 1000 == 0) { 
    // Commit txn 
    if (db_backend == "lmdb") {  // lmdb 
      CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) 
          << "mdb_txn_commit failed"; 
      CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) 
          << "mdb_txn_begin failed"; 
    } else { 
      LOG(FATAL) << "Unknown db backend " << db_backend; 
    } 
  } 
} 
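The key for each record is the image index, zero-padded to 8 digits. The MDB_val variables mdb_data and mdb_key hold a pointer to the source data and its length. mdb_put() stores the record in the database. The transaction is committed once every 1000 images; data is actually written to disk only after a commit.
Reading a dataset:
In Caffe the code that reads an LMDB dataset is DataLayer, which sits at the bottom of the network and supplies the data. DataLayer reads records by sequential traversal; it cannot shuffle the data and can only randomly skip the first few records.
First, the DataLayerSetUp method of DataLayer opens the database and obtains the cursor cursor_: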
db_.reset(db::GetDB(this->layer_param_.data_param().backend())); 
db_->Open(this->layer_param_.data_param().source(), db::READ); 
cursor_.reset(db_->NewCursor()); 
Then, on every prefetch, the InternalThreadEntry() method reads a string from the database, deserializes it into a Datum object, and takes the data out of the Datum:
Datum datum; 
datum.ParseFromString(cursor_->value()); 
Here cursor_->value() returns the serialized string, and datum.ParseFromString() deserializes it.
Finally, advance cursor_:
cursor_->Next(); 
if (!cursor_->valid()) { 
  DLOG(INFO) << "Restarting data prefetching from start.";
      cursor_->SeekToFirst(); 
} 
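If cursor_->valid() returns false, the traversal has reached the end of the database, and cursor_ has to be reset to the beginning. The lack of random sample ordering is probably DataLayer's most serious weakness. If the keys in the database follow a uniform scheme, shuffling could in fact be implemented by enumerating the keys in random order. Caffe does define a random number generator, RNG, which DataTransformer uses below for random cropping and mirroring.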
*/  
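The key scheme described in the comment above (item index zero-padded to 8 digits) can be sketched in isolation. This is a minimal illustration and not Caffe code; MakeKey is a hypothetical helper name. The padding matters because LMDB compares keys lexicographically by default, so fixed-width keys come back from a cursor in numeric order, while unpadded decimal strings would not.

```cpp
#include <cassert>
#include <cstdio>
#include <string>

// Hypothetical helper: format an item index as a fixed-width, zero-padded key,
// following the snprintf(key_cstr, kMaxKeyLength, "%08d", item_id) call above.
std::string MakeKey(int item_id) {
  assert(item_id >= 0);
  char buf[16];  // large enough for any non-negative 32-bit int plus the terminator
  std::snprintf(buf, sizeof(buf), "%08d", item_id);
  return std::string(buf);
}
```

With this scheme MakeKey(2) sorts before MakeKey(10), whereas the plain strings "2" and "10" sort the other way round.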
template<typename Dtype>  
void DataTransformer<Dtype>::Transform(const Datum& datum,  
                                       Dtype* transformed_data) {  
  // See the definition of TransformationParameter above
  const string& data = datum.data();
  const int datum_channels = datum.channels();  // number of channels
  const int datum_height = datum.height();  // number of rows
  const int datum_width = datum.width();  // number of columns

  const int crop_size = param_.crop_size();  // crop size
  const Dtype scale = param_.scale();  // scale factor
  const bool do_mirror = param_.mirror() && Rand(2);  // mirror only if enabled and Rand(2) returns 1
  const bool has_mean_file = param_.has_mean_file();  // whether a mean file is given
  const bool has_uint8 = data.size() > 0;  // true if the data is stored as bytes (uint8) rather than in float_data
  const bool has_mean_values = mean_values_.size() > 0;  // whether per-channel mean values are given

  // Sanity checks
  CHECK_GT(datum_channels, 0);  
  CHECK_GE(datum_height, crop_size);  
  CHECK_GE(datum_width, crop_size);  
  
  Dtype* mean = NULL;  
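/*
The CHECK macros were introduced earlier; glog provides several convenience macros for asserting specific relations:
1. Comparisons
CHECK_EQ, CHECK_NE, CHECK_LE, CHECK_LT, CHECK_GE, CHECK_GT. The operand types should match; if they do not, convert with static_cast.
2. Non-null pointers
CHECK_NOTNULL(some_ptr), useful when initializing objects.
3. String equality
CHECK_STREQ, CHECK_STRNE, CHECK_STRCASEEQ, CHECK_STRCASENE, for case-sensitive and case-insensitive comparison of C strings.
4. Floating-point equality or closeness
CHECK_DOUBLE_EQ and CHECK_NEAR. CHECK_NEAR takes an explicit tolerance as its third argument; CHECK_DOUBLE_EQ takes only the two values.
*/
  if (has_mean_file) {  // check that the mean file's dimensions match the datum's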
    CHECK_EQ(datum_channels, data_mean_.channels());  
    CHECK_EQ(datum_height, data_mean_.height());  
    CHECK_EQ(datum_width, data_mean_.width());  
    mean = data_mean_.mutable_cpu_data();  
  }  
  if (has_mean_values) {  
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<  
     "Specify either 1 mean_value or as many as channels: " << datum_channels;  
    if (datum_channels > 1 && mean_values_.size() == 1) {  
      // Replicate the mean_value for simplicity  
      for (int c = 1; c < datum_channels; ++c) {  
        mean_values_.push_back(mean_values_[0]);  
      }  
    }  
  }  
  
  int height = datum_height;  
  int width = datum_width;  
  
  // Compute h_off and w_off depending on whether cropping is requested
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {  // crop_size is non-zero
    height = crop_size;
    width = crop_size;
    // We only do random crop when we do training.
    // During training the crop position is random; Rand is the member function defined near the end of this file
    if (phase_ == TRAIN) {
      h_off = Rand(datum_height - crop_size + 1);  // random integer in [0, datum_height - crop_size]
      w_off = Rand(datum_width - crop_size + 1);
    } else {  // at test time no randomness: take the center crop
      h_off = (datum_height - crop_size) / 2;  
      w_off = (datum_width - crop_size) / 2;  
    }  
  }  
  
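  // Transform the data: subtract the mean from each pixel value, then multiply by scale
  // With cropping, the output for this datum is channels x crop_size x crop_size
  // Without cropping, it is channels x datum_height x datum_width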
  Dtype datum_element;  
  int top_index, data_index;  
  for (int c = 0; c < datum_channels; ++c) {  
    for (int h = 0; h < height; ++h) {  
      for (int w = 0; w < width; ++w) {  
        // Row-major CHW offset into the un-cropped datum: skip c planes of
        // datum_height rows, then (h_off + h) rows of datum_width, then (w_off + w) columns
        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        if (do_mirror) {  // mirror horizontally
          top_index = (c * height + h) * width + (width - 1 - w);  // flip the column index to mirror
        } else {
          top_index = (c * height + h) * width + w;
        }
        if (has_uint8) {  // byte data: convert to Dtype
          datum_element =
            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        } else {  // otherwise the data is float
          datum_element = datum.float_data(data_index);
        }
        if (has_mean_file) {  // with a mean file: subtract the mean at the same position, then multiply by scale
          transformed_data[top_index] =
            (datum_element - mean[data_index]) * scale;
        } else {
          if (has_mean_values) {  // otherwise subtract the channel's mean value, then multiply by scale
            transformed_data[top_index] =
              (datum_element - mean_values_[c]) * scale;
          } else {  // no mean at all: just multiply by scale
            transformed_data[top_index] = datum_element * scale;  
          }  
        }  
      }  
    }  
  }  
}  
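The index arithmetic in the loop above is easier to follow on a plain buffer. The sketch below is an illustration under simplifying assumptions, not the Caffe implementation: it uses std::vector instead of Datum and Blob, float instead of Dtype, a per-channel mean only, and a hypothetical function name, CropMirrorScale.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Crop a C x H x W byte image (row-major, CHW) to crop x crop starting at
// (h_off, w_off), optionally mirror it horizontally, subtract a per-channel
// mean, then multiply by scale.
std::vector<float> CropMirrorScale(const std::vector<uint8_t>& src,
                                   int channels, int height, int width,
                                   int crop, int h_off, int w_off,
                                   bool mirror,
                                   const std::vector<float>& mean,
                                   float scale) {
  assert(static_cast<int>(src.size()) == channels * height * width);
  assert(static_cast<int>(mean.size()) == channels);
  assert(h_off >= 0 && h_off + crop <= height);
  assert(w_off >= 0 && w_off + crop <= width);
  std::vector<float> dst(channels * crop * crop);
  for (int c = 0; c < channels; ++c) {
    for (int h = 0; h < crop; ++h) {
      for (int w = 0; w < crop; ++w) {
        // Source offset: skip c full planes, then (h_off + h) full rows,
        // then (w_off + w) columns.
        const int data_index = (c * height + h_off + h) * width + w_off + w;
        // Destination offset: same layout in the cropped image; mirroring
        // writes source column w to destination column (crop - 1 - w).
        const int out_w = mirror ? (crop - 1 - w) : w;
        const int top_index = (c * crop + h) * crop + out_w;
        dst[top_index] =
            (static_cast<float>(src[data_index]) - mean[c]) * scale;
      }
    }
  }
  return dst;
}
```

For a single-channel 3 x 3 image holding 0 through 8, a 2 x 2 crop at offset (1, 1) selects 4, 5, 7, 8; with mirroring the two columns of each row are swapped.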
  
  
template<typename Dtype>  
void DataTransformer<Dtype>::Transform(const Datum& datum,  
                                       Blob<Dtype>* transformed_blob) {  
  // If datum is encoded, decoded and transform the cv::image.  
  if (datum.encoded()) {  // if the datum is encoded, decode it first
#ifdef USE_OPENCV
    // Setting both options at once is a configuration error
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";
    cv::Mat cv_img;
    if (param_.force_color() || param_.force_gray()) {
      // If either force_color or force_gray is set, decode with DecodeDatumToCVMat
      // If force_color then decode in color otherwise decode in gray.
      cv_img = DecodeDatumToCVMat(datum, param_.force_color());
    } else {  // otherwise decode with DecodeDatumToCVMatNative
      cv_img = DecodeDatumToCVMatNative(datum);
    }
    // Transform the cv::image into blob.
    return Transform(cv_img, transformed_blob);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  } else {  // not encoded: force_color and force_gray apply only to encoded data, so log an error if either is set
    if (param_.force_color() || param_.force_gray()) {  
      LOG(ERROR) << "force_color and force_gray only for encoded datum";  
    }  
  }  
  
  const int crop_size = param_.crop_size();  
  const int datum_channels = datum.channels();  
  const int datum_height = datum.height();  
  const int datum_width = datum.width();  
  
  // Check dimensions.  
  const int channels = transformed_blob->channels();  
  const int height = transformed_blob->height();  
  const int width = transformed_blob->width();  
  const int num = transformed_blob->num();  
  
  CHECK_EQ(channels, datum_channels);  
  CHECK_LE(height, datum_height);  
  CHECK_LE(width, datum_width);  
  CHECK_GE(num, 1);  
  
  if (crop_size) {  
    CHECK_EQ(crop_size, height);  
    CHECK_EQ(crop_size, width);  
  } else {  
    CHECK_EQ(datum_height, height);  
    CHECK_EQ(datum_width, width);  
  }  
  // Hand over to the pointer-based Transform overload
  Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  Transform(datum, transformed_data);  
}  
  
template<typename Dtype>  
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,  
                                       Blob<Dtype>* transformed_blob) {  
  const int datum_num = datum_vector.size();  
  // Shape of the destination blob
  const int num = transformed_blob->num();  
  const int channels = transformed_blob->channels();  
  const int height = transformed_blob->height();  
  const int width = transformed_blob->width();  
  
  CHECK_GT(datum_num, 0) << "There is no datum to add";  
  CHECK_LE(datum_num, num) <<  
    "The size of datum_vector must be no greater than transformed_blob->num()";  
  // Create uni_blob, a single-item blob that is pointed at one item of the batch at a time
  Blob<Dtype> uni_blob(1, channels, height, width);  
  for (int item_id = 0; item_id < datum_num; ++item_id) {  
    int offset = transformed_blob->offset(item_id);  
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);  
    Transform(datum_vector[item_id], &uni_blob);  
  }  
}  
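The uni_blob trick above relies on each item occupying a contiguous slice of the batch buffer, so a single-item view can be aimed at the slice for item n and written through. A minimal sketch of that addressing, assuming a row-major N x C x H x W layout; Offset and FillItem are hypothetical names, and a std::vector stands in for the Blob's storage:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Offset of element (n, c, h, w) in a row-major N x C x H x W buffer.
// With c = h = w = 0 this is the start of item n.
std::size_t Offset(int n, int c, int h, int w, int C, int H, int W) {
  return ((static_cast<std::size_t>(n) * C + c) * H + h) * W + w;
}

// Fill item n of the batch by writing through a pointer into the shared
// buffer, the same way uni_blob writes into transformed_blob's memory.
void FillItem(std::vector<float>* batch, int n, int C, int H, int W,
              float value) {
  const std::size_t item_size = static_cast<std::size_t>(C) * H * W;
  const std::size_t start = Offset(n, 0, 0, 0, C, H, W);
  assert(start + item_size <= batch->size());
  float* item = batch->data() + start;
  for (std::size_t i = 0; i < item_size; ++i) {
    item[i] = value;
  }
}
```

Because the view shares memory with the batch, no copy back is needed after each item is transformed.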
  
#ifdef USE_OPENCV  
template<typename Dtype>  
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,  
                                       Blob<Dtype>* transformed_blob) {  
  // Number of input images and shape of the destination blob
  const int mat_num = mat_vector.size();  
  const int num = transformed_blob->num();  
  const int channels = transformed_blob->channels();  
  const int height = transformed_blob->height();  
  const int width = transformed_blob->width();  
  
  CHECK_GT(mat_num, 0) << "There is no MAT to add";  
  CHECK_EQ(mat_num, num) <<  
    "The size of mat_vector must be equals to transformed_blob->num()";  
  // Same approach as the Datum vector overload above
  Blob<Dtype> uni_blob(1, channels, height, width);  
  for (int item_id = 0; item_id < mat_num; ++item_id) {  
    int offset = transformed_blob->offset(item_id);  
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);  
    Transform(mat_vector[item_id], &uni_blob);  
  }  
}  
  
// For images: subtract the mean, multiply by scale, and optionally mirror
// The logic follows the Datum overload above
template<typename Dtype>  
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,  
                                       Blob<Dtype>* transformed_blob) {  
  const int crop_size = param_.crop_size();  
  const int img_channels = cv_img.channels();  
  const int img_height = cv_img.rows;  
  const int img_width = cv_img.cols;  
  
  // Check dimensions.  
  const int channels = transformed_blob->channels();  
  const int height = transformed_blob->height();  
  const int width = transformed_blob->width();  
  const int num = transformed_blob->num();  
  
  CHECK_EQ(channels, img_channels);  
  CHECK_LE(height, img_height);  
  CHECK_LE(width, img_width);  
  CHECK_GE(num, 1);  
  
  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";  
  
  const Dtype scale = param_.scale();  
  const bool do_mirror = param_.mirror() && Rand(2);  
  const bool has_mean_file = param_.has_mean_file();  
  const bool has_mean_values = mean_values_.size() > 0;  
  
  CHECK_GT(img_channels, 0);  
  CHECK_GE(img_height, crop_size);  
  CHECK_GE(img_width, crop_size);  
  
  Dtype* mean = NULL;  
  if (has_mean_file) {  
    CHECK_EQ(img_channels, data_mean_.channels());  
    CHECK_EQ(img_height, data_mean_.height());  
    CHECK_EQ(img_width, data_mean_.width());  
    mean = data_mean_.mutable_cpu_data();  
  }  
  if (has_mean_values) {  
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<  
     "Specify either 1 mean_value or as many as channels: " << img_channels;  
    if (img_channels > 1 && mean_values_.size() == 1) {  
      // Replicate the mean_value for simplicity  
      for (int c = 1; c < img_channels; ++c) {  
        mean_values_.push_back(mean_values_[0]);  
      }  
    }  
  }  
  
  int h_off = 0;  
  int w_off = 0;  
  cv::Mat cv_cropped_img = cv_img;  
  if (crop_size) {  
    CHECK_EQ(crop_size, height);  
    CHECK_EQ(crop_size, width);  
    // We only do random crop when we do training.  
    if (phase_ == TRAIN) {  
      h_off = Rand(img_height - crop_size + 1);  
      w_off = Rand(img_width - crop_size + 1);  
    } else {  
      h_off = (img_height - crop_size) / 2;  
      w_off = (img_width - crop_size) / 2;  
    }  
    cv::Rect roi(w_off, h_off, crop_size, crop_size);  
    cv_cropped_img = cv_img(roi);  
  } else {  
    CHECK_EQ(img_height, height);  
    CHECK_EQ(img_width, width);  
  }  
  
  CHECK(cv_cropped_img.data);  
  
  Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  int top_index;  
  for (int h = 0; h < height; ++h) {  
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);  
    int img_index = 0;  
    for (int w = 0; w < width; ++w) {  
      for (int c = 0; c < img_channels; ++c) {  
        if (do_mirror) {  
          top_index = (c * height + h) * width + (width - 1 - w);  
        } else {  
          top_index = (c * height + h) * width + w;  
        }  
        // int top_index = (c * height + h) * width + w;  
        Dtype pixel = static_cast<Dtype>(ptr[img_index++]);  
        if (has_mean_file) {  
          int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;  
          transformed_data[top_index] =  
            (pixel - mean[mean_index]) * scale;  
        } else {  
          if (has_mean_values) {  
            transformed_data[top_index] =  
              (pixel - mean_values_[c]) * scale;  
          } else {  
            transformed_data[top_index] = pixel * scale;  
          }  
        }  
      }  
    }  
  }  
}  
#endif  // USE_OPENCV  
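The cv::Mat overload differs from the Datum overload mainly in memory layout: an 8-bit cv::Mat stores pixels interleaved (H x W x C), while a Blob is planar (C x H x W). That is why its loops run h, w, c and the source is read sequentially through img_index. A minimal sketch of just that reordering, with no OpenCV dependency; InterleavedToPlanar is a hypothetical name, and the input is assumed to be one contiguous buffer (the real code fetches each row with ptr<uchar>(h) because Mat rows need not be contiguous):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Convert an interleaved H x W x C byte image into planar C x H x W floats.
std::vector<float> InterleavedToPlanar(const std::vector<uint8_t>& hwc,
                                       int height, int width, int channels) {
  assert(static_cast<int>(hwc.size()) == height * width * channels);
  std::vector<float> chw(hwc.size());
  int img_index = 0;  // walks the interleaved source sequentially
  for (int h = 0; h < height; ++h) {
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < channels; ++c) {
        // Same destination formula as top_index in the listing (no mirror).
        const int top_index = (c * height + h) * width + w;
        chw[top_index] = static_cast<float>(hwc[img_index++]);
      }
    }
  }
  return chw;
}
```

For a 1 x 2 image with three channels stored as {1, 2, 3, 4, 5, 6}, the planar result is {1, 4, 2, 5, 3, 6}: each channel's values end up adjacent.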
  
template<typename Dtype>  
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,  
                                       Blob<Dtype>* transformed_blob) {  
  const int crop_size = param_.crop_size();  
  const int input_num = input_blob->num();  
  const int input_channels = input_blob->channels();  
  const int input_height = input_blob->height();  
  const int input_width = input_blob->width();  
  
  if (transformed_blob->count() == 0) {  
    // Initialize transformed_blob with the right shape.  
    if (crop_size) {  
      transformed_blob->Reshape(input_num, input_channels,  
                                crop_size, crop_size);  
    } else {  
      transformed_blob->Reshape(input_num, input_channels,  
                                input_height, input_width);  
    }  
  }  
  
  const int num = transformed_blob->num();  
  const int channels = transformed_blob->channels();  
  const int height = transformed_blob->height();  
  const int width = transformed_blob->width();  
  const int size = transformed_blob->count();  
  
  CHECK_LE(input_num, num);  
  CHECK_EQ(input_channels, channels);  
  CHECK_GE(input_height, height);  
  CHECK_GE(input_width, width);  
  
  
  const Dtype scale = param_.scale();  
  const bool do_mirror = param_.mirror() && Rand(2);  
  const bool has_mean_file = param_.has_mean_file();  
  const bool has_mean_values = mean_values_.size() > 0;  
  
  int h_off = 0;  
  int w_off = 0;  
  if (crop_size) {  
    CHECK_EQ(crop_size, height);  
    CHECK_EQ(crop_size, width);  
    // We only do random crop when we do training.  
    if (phase_ == TRAIN) {  
      h_off = Rand(input_height - crop_size + 1);  
      w_off = Rand(input_width - crop_size + 1);  
    } else {  
      h_off = (input_height - crop_size) / 2;  
      w_off = (input_width - crop_size) / 2;  
    }  
  } else {  
    CHECK_EQ(input_height, height);  
    CHECK_EQ(input_width, width);  
  }  
  
  // If a mean file is given, subtract it in place from every item
  Dtype* input_data = input_blob->mutable_cpu_data();  
  if (has_mean_file) {  
    CHECK_EQ(input_channels, data_mean_.channels());  
    CHECK_EQ(input_height, data_mean_.height());  
    CHECK_EQ(input_width, data_mean_.width());  
    for (int n = 0; n < input_num; ++n) {  
      int offset = input_blob->offset(n);  
      /*
         template <typename Dtype>
       void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
       caffe_sub, defined in math_functions, computes the element-wise difference y = a - b; here
       input_data (starting at offset) = input_data (starting at offset) - data_mean_
    */
      caffe_sub(data_mean_.count(), input_data + offset,  
            data_mean_.cpu_data(), input_data + offset);  
    }  
  }  
  // If per-channel mean values are given, subtract them in place
  if (has_mean_values) {  
    CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<  
     "Specify either 1 mean_value or as many as channels: " << input_channels;  
    if (mean_values_.size() == 1) {  
      caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);  
    } else {  
      for (int n = 0; n < input_num; ++n) {  
        for (int c = 0; c < input_channels; ++c) {  
          int offset = input_blob->offset(n, c);  
          // Add -mean_values_[c] to every element of this channel, starting at input_data[offset]
          caffe_add_scalar(input_height * input_width, -(mean_values_[c]),  
            input_data + offset);  
        }  
      }  
    }  
  }  
  
  // Copy into transformed_blob with crop and optional mirror; any mean was already subtracted in place above
  Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  
  for (int n = 0; n < input_num; ++n) {  
    int top_index_n = n * channels;  
    int data_index_n = n * channels;  
    for (int c = 0; c < channels; ++c) {  
      int top_index_c = (top_index_n + c) * height;  
      int data_index_c = (data_index_n + c) * input_height + h_off;  
      for (int h = 0; h < height; ++h) {  
        int top_index_h = (top_index_c + h) * width;  
        int data_index_h = (data_index_c + h) * input_width + w_off;  
        if (do_mirror) {  
          int top_index_w = top_index_h + width - 1;  
          for (int w = 0; w < width; ++w) {  
            transformed_data[top_index_w-w] = input_data[data_index_h + w];  
          }  
        } else {  
          for (int w = 0; w < width; ++w) {  
            transformed_data[top_index_h + w] = input_data[data_index_h + w];  
          }  
        }  
      }  
    }  
  }  
  if (scale != Dtype(1)) {  
    DLOG(INFO) << "Scale: " << scale;  
    caffe_scal(size, scale, transformed_data);  
  }  
}  
  
template<typename Dtype>  
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {  
  if (datum.encoded()) {  
#ifdef USE_OPENCV  // with OpenCV, decode to a cv::Mat first and infer the blob shape from it
    CHECK(!(param_.force_color() && param_.force_gray()))  
        << "cannot set both force_color and force_gray";  
    cv::Mat cv_img;  
    if (param_.force_color() || param_.force_gray()) {  
    // If force_color then decode in color otherwise decode in gray.  
      cv_img = DecodeDatumToCVMat(datum, param_.force_color());  
    } else {  
      cv_img = DecodeDatumToCVMatNative(datum);  
    }  
    // InferBlobShape using the cv::image.  
    return InferBlobShape(cv_img);  
#else  
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";  
#endif  // USE_OPENCV  
  }  
  
  // Otherwise read the shape directly from the datum
  const int crop_size = param_.crop_size();  
  const int datum_channels = datum.channels();  
  const int datum_height = datum.height();  
  const int datum_width = datum.width();  
  // Check dimensions.  
  CHECK_GT(datum_channels, 0);  
  CHECK_GE(datum_height, crop_size);  
  CHECK_GE(datum_width, crop_size);  
  // Build BlobShape.  
  vector<int> shape(4);  
  shape[0] = 1;  
  shape[1] = datum_channels;  
  shape[2] = (crop_size)? crop_size: datum_height;  
  shape[3] = (crop_size)? crop_size: datum_width;  
  return shape;  
}  
  
template<typename Dtype>  
vector<int> DataTransformer<Dtype>::InferBlobShape(  
    const vector<Datum> & datum_vector) {  
  const int num = datum_vector.size();  
  CHECK_GT(num, 0) << "There is no datum to in the vector";  
  // Use first datum in the vector to InferBlobShape.  
  // (the first datum is taken as representative of the whole vector)
  vector<int> shape = InferBlobShape(datum_vector[0]);  
  // Adjust num to the size of the vector.  
  shape[0] = num;  
  return shape;  
}  
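The shape rule used by the InferBlobShape overloads is small enough to state as a single function. The sketch below is an illustration with a hypothetical name, InferShape, and plain integers in place of a Datum: the result is {num, channels, out_height, out_width}, where a non-zero crop_size overrides both spatial dimensions.

```cpp
#include <cassert>
#include <vector>

// Output shape for a batch of `num` items of size channels x height x width.
// A crop_size of 0 means "no cropping".
std::vector<int> InferShape(int num, int channels, int height, int width,
                            int crop_size) {
  assert(num > 0 && channels > 0);
  assert(crop_size >= 0);
  assert(height >= crop_size && width >= crop_size);
  return {num, channels,
          crop_size ? crop_size : height,
          crop_size ? crop_size : width};
}
```

A 3 x 256 x 256 input with crop_size 227 therefore yields a 1 x 3 x 227 x 227 blob, while crop_size 0 leaves the spatial size unchanged.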
  
#ifdef USE_OPENCV  
// With OpenCV enabled:
// infer the shape from the information in the cv::Mat
template<typename Dtype>  
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {  
  const int crop_size = param_.crop_size();  
  const int img_channels = cv_img.channels();  
  const int img_height = cv_img.rows;  
  const int img_width = cv_img.cols;  
  // Check dimensions.  
  CHECK_GT(img_channels, 0);  
  CHECK_GE(img_height, crop_size);  
  CHECK_GE(img_width, crop_size);  
  // Build BlobShape.  
  vector<int> shape(4);  
  shape[0] = 1;  
  shape[1] = img_channels;  
  shape[2] = (crop_size)? crop_size: img_height;  
  shape[3] = (crop_size)? crop_size: img_width;  
  return shape;  
}  
  
template<typename Dtype>  
vector<int> DataTransformer<Dtype>::InferBlobShape(  
    const vector<cv::Mat> & mat_vector) {  
  const int num = mat_vector.size();  
  CHECK_GT(num, 0) << "There is no cv_img to in the vector";  
  // Use first cv_img in the vector to InferBlobShape.  
  // (the first image is taken as representative of the whole vector)
  vector<int> shape = InferBlobShape(mat_vector[0]);  
  // Adjust num to the size of the vector.  
  shape[0] = num;  
  return shape;  
}  
#endif  // USE_OPENCV  
  
// Initialize the random number generator
template <typename Dtype>  
void DataTransformer<Dtype>::InitRand() {  
  // A generator is needed only if mirroring is enabled, or if cropping is enabled during training
  const bool needs_rand = param_.mirror() ||  
      (phase_ == TRAIN && param_.crop_size());  
  if (needs_rand) {  
    const unsigned int rng_seed = caffe_rng_rand();  // obtain a seed (from the entropy pool or the clock)
    rng_.reset(new Caffe::RNG(rng_seed));  // construct the generator with that seed
  } else {
    rng_.reset();  // otherwise leave the generator empty
  }  
}  
  
// Return a random integer in [0, n - 1]
template <typename Dtype>  
int DataTransformer<Dtype>::Rand(int n) {  
  CHECK(rng_);  
  CHECK_GT(n, 0);  
  caffe::rng_t* rng =  
      static_cast<caffe::rng_t*>(rng_->generator());  
  return ((*rng)() % n);  
}  
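Rand(n) reduces one raw generator output modulo n, so its result lies in [0, n - 1] and never equals n. That is why the crop offset is drawn as Rand(datum_height - crop_size + 1): the largest valid offset, datum_height - crop_size, must be reachable. A minimal stand-alone sketch, with std::mt19937 standing in for caffe::rng_t (an assumption made for illustration) and RandBelow as a hypothetical name:

```cpp
#include <cassert>
#include <random>

// Draw one value from the generator and reduce it modulo n.
// The result is always in [0, n - 1].
int RandBelow(std::mt19937* rng, int n) {
  assert(rng != nullptr);
  assert(n > 0);
  return static_cast<int>((*rng)() % static_cast<unsigned int>(n));
}
```

RandBelow(&rng, 2) acts as the coin flip behind do_mirror. The modulo reduction is very slightly biased when n does not divide the generator's range, which is negligible for the small n used here.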
  
INSTANTIATE_CLASS(DataTransformer);  
/* 
The class-instantiation macro is defined as follows; it was covered earlier and is repeated here
#define INSTANTIATE_CLASS(classname) \ 
  char gInstantiationGuard##classname; \ 
  template class classname<float>; \ 
  template class classname<double> 
*/  
  
}  // namespace caffe  

