// This blog post is kept for my own future reference.
#include <iostream>
#include <string>
#include <vector>
#include <fstream>
#include <eigen3/Eigen/Core>
/**
 * @brief Reader for the four MNIST data files (train/test images and labels).
 *
 * The constructor only records the file paths; each getter opens and parses
 * its file on every call and writes the result into the caller's matrix.
 */
class MnistData
{
public:
    /**
     * @brief Remember the paths of the four data files.
     * @param train_img_filename   path of the training image file
     * @param train_label_filename path of the training label file
     * @param test_img_filename    path of the test image file
     * @param test_label_filename  path of the test label file
     */
    MnistData(std::string train_img_filename,
              std::string train_label_filename,
              std::string test_img_filename,
              std::string test_label_filename);

    /// Fill _data with the training images: one column per image, pixels divided by 255.
    void getMnistTrainData(Eigen::MatrixXd &_data);

    /// Fill _data with the training labels: 10 rows, one column per label, a single 1 per column.
    void getMnistTrainLabel(Eigen::MatrixXd &_data);

    /// Fill _data with the test images: one column per image, pixels divided by 255.
    void getMnistTestData(Eigen::MatrixXd &_data);

    /// Fill _data with the test labels: 10 rows, one column per label, a single 1 per column.
    void getMnistTestLabel(Eigen::MatrixXd &_data);

private:
    /// Parse an image file into _data; _data is left untouched when the file cannot be opened.
    void read_Mnist_Images(const std::string &filename, Eigen::MatrixXd &_data);

    /// Parse a label file into _data; _data is left untouched when the file cannot be opened.
    void read_Mnist_Label(const std::string &filename, Eigen::MatrixXd &_data);

    /// Return i with the order of its four low bytes reversed.
    int ReverseInt (int i);

    /// File paths in fixed order: {train image, train label, test image, test label}.
    std::vector<std::string> imgDataPath;
};
/**
 * @brief Record the four data file paths; no file is opened here.
 *
 * The paths are stored in imgDataPath in the fixed order
 * {train image, train label, test image, test label}, which is the order the
 * getters index into.
 */
MnistData::MnistData(std::string train_img_filename,
                     std::string train_label_filename,
                     std::string test_img_filename,
                     std::string test_label_filename)
{
    // Table of the arguments in storage order, so the fill is a single loop.
    const std::string *const orderedPaths[] = {
        &train_img_filename,
        &train_label_filename,
        &test_img_filename,
        &test_label_filename};
    const std::vector<std::string>::size_type pathCount =
        sizeof(orderedPaths) / sizeof(orderedPaths[0]);

    imgDataPath.reserve(pathCount);
    for (std::vector<std::string>::size_type slot = 0; slot < pathCount; ++slot)
    {
        imgDataPath.push_back(*orderedPaths[slot]);
    }
}
/// Read the training image file (first stored path) into _data.
void MnistData::getMnistTrainData(Eigen::MatrixXd &_data)
{
    const std::string &trainImagePath = imgDataPath.front();
    this->read_Mnist_Images(trainImagePath, _data);
}
/// Read the training label file (second stored path) into _data.
void MnistData::getMnistTrainLabel(Eigen::MatrixXd &_data)
{
    const std::string &trainLabelPath = imgDataPath[1];
    this->read_Mnist_Label(trainLabelPath, _data);
}
/// Read the test image file (third stored path) into _data.
void MnistData::getMnistTestData(Eigen::MatrixXd &_data)
{
    const std::string &testImagePath = imgDataPath[2];
    this->read_Mnist_Images(testImagePath, _data);
}
/// Read the test label file (fourth stored path) into _data.
void MnistData::getMnistTestLabel(Eigen::MatrixXd &_data)
{
    const std::string &testLabelPath = imgDataPath[3];
    this->read_Mnist_Label(testLabelPath, _data);
}
void MnistData::read_Mnist_Images(const std::string &filename, Eigen::MatrixXd &_data)
{
std::ifstream file;
file.open(filename.c_str(), std::ios::in | std::ios::binary);
if (file.is_open())
{
int magic_number = 0;
int number_of_images = 0;
int n_rows = 0;
int n_cols = 0;
unsigned char label;
file.read((char*)&magic_number, sizeof(magic_number));
file.read((char*)&number_of_images, sizeof(number_of_images));
file.read((char*)&n_rows, sizeof(n_rows));
file.read((char*)&n_cols, sizeof(n_cols));
magic_number = ReverseInt(magic_number);
number_of_images = ReverseInt(number_of_images);
n_rows = ReverseInt(n_rows);
n_cols = ReverseInt(n_cols);
std::cout << "magic number = " << magic_number << std::endl;
std::cout << "number of images = " << number_of_images << std::endl;
std::cout << "rows = " << n_rows << std::endl;
std::cout << "cols = " << n_cols << std::endl;
_data.resize(n_rows * n_cols, number_of_images);
for (int i = 0; i < number_of_images; i++)
{
Eigen::MatrixXd img(n_rows * n_cols, 1);
for (int r = 0; r < n_rows; r++)
{
for (int c = 0; c < n_cols; c++)
{
unsigned char image = 0;
file.read((char*)&image, sizeof(image));
img(n_rows * r + c, 0) = (double)image / 255.0;
}
}
_data.col(i) = img;
}
}
if (file.is_open())
file.close();
}
/**
 * @brief Load an MNIST (IDX1) label file into a one-hot matrix.
 *
 * File layout consumed here: two 32-bit header fields (magic number, label
 * count), each byte-swapped with ReverseInt, followed by one unsigned byte per
 * label.
 *
 * On success _data is resized to 10 x count, zero-filled, and for label value
 * v at position i the entry _data(v, i) is set to 1.
 *
 * Failure handling (all failures are reported on std::cerr, nothing is thrown):
 *  - file cannot be opened, header truncated, wrong magic number, or negative
 *    count: _data is left untouched;
 *  - label data ends early: loading stops and the remaining columns stay zero;
 *  - a label byte outside 0..9: that column stays zero (writing it would index
 *    past the 10 rows of the matrix).
 *
 * @param filename path of the label file
 * @param _data    output matrix, overwritten on success
 */
void MnistData::read_Mnist_Label(const std::string &filename, Eigen::MatrixXd &_data)
{
    // Magic number of an IDX file holding a 1-dimensional unsigned-byte array.
    const int kLabelFileMagic = 2049;
    // Number of rows of the one-hot matrix, i.e. the number of label classes.
    const int kNumClasses = 10;

    std::ifstream file(filename.c_str(), std::ios::in | std::ios::binary);
    if (!file.is_open())
    {
        std::cerr << "read_Mnist_Label: cannot open file " << filename << std::endl;
        return;
    }

    int magic_number = 0;
    int number_of_images = 0;
    file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));
    file.read(reinterpret_cast<char *>(&number_of_images), sizeof(number_of_images));
    if (!file)
    {
        std::cerr << "read_Mnist_Label: truncated header in " << filename << std::endl;
        return;
    }

    // Header fields are stored most-significant byte first; swap them before use.
    magic_number = ReverseInt(magic_number);
    number_of_images = ReverseInt(number_of_images);
    std::cout << "magic number = " << magic_number << std::endl;
    std::cout << "number of images = " << number_of_images << std::endl;

    if (magic_number != kLabelFileMagic)
    {
        std::cerr << "read_Mnist_Label: unexpected magic number in " << filename
                  << " (expected " << kLabelFileMagic << ")" << std::endl;
        return;
    }
    if (number_of_images < 0)
    {
        // A negative count would be passed straight to resize(); reject it here.
        std::cerr << "read_Mnist_Label: invalid label count in " << filename << std::endl;
        return;
    }

    _data.resize(kNumClasses, number_of_images);
    _data.setZero();

    for (int i = 0; i < number_of_images; ++i)
    {
        unsigned char label = 0;
        file.read(reinterpret_cast<char *>(&label), sizeof(label));
        if (!file)
        {
            // Without this check a failed read would leave label == 0 and
            // silently mark class 0 for every remaining column.
            std::cerr << "read_Mnist_Label: label data ends early at index " << i
                      << " in " << filename << std::endl;
            break;
        }
        if (static_cast<int>(label) >= kNumClasses)
        {
            std::cerr << "read_Mnist_Label: label " << static_cast<int>(label)
                      << " at index " << i << " is out of range, column left zero" << std::endl;
            continue;
        }
        _data(static_cast<int>(label), i) = 1;
    }
}
/**
 * @brief Reverse the order of the four low bytes of an integer.
 *
 * Byte 0 (least significant) is swapped with byte 3 and byte 1 with byte 2.
 * The readers in this class apply it to header fields read raw from the file.
 *
 * The swap is done on an unsigned copy so that no signed value is shifted
 * into or out of the sign bit; only the conversion back to int at the end
 * involves the sign.
 *
 * @param i value whose low four bytes are to be reversed
 * @return the byte-reversed value
 */
int MnistData::ReverseInt (int i)
{
    const unsigned int value = static_cast<unsigned int>(i);
    const unsigned int swapped = ((value & 0x000000FFu) << 24) |
                                 ((value & 0x0000FF00u) << 8) |
                                 ((value & 0x00FF0000u) >> 8) |
                                 ((value & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}