MNIST 数据集的读取（C++ / Eigen 实现）

这篇博客留作自己日后备用。

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <eigen3/Eigen/Core>

/**
 * @brief Loads the four MNIST binary files (train/test images and labels)
 *        into Eigen matrices.
 *
 * Image getters produce one image per column, with each pixel byte divided
 * by 255.0. Label getters produce a one-hot matrix with 10 rows and one
 * column per sample. Each getter re-reads its file on every call.
 */
class MnistData
{
public:
    /// Stores the four file paths; no file is opened until a getter is called.
    MnistData(std::string train_img_filename, 
             std::string train_label_filename,
             std::string test_img_filename, 
             std::string test_label_filename);

    /// Reads the training images into @p _data (one image per column).
    void getMnistTrainData(Eigen::MatrixXd &_data);

    /// Reads the training labels into @p _data as a one-hot 10 x N matrix.
    void getMnistTrainLabel(Eigen::MatrixXd &_data);

    /// Reads the test images into @p _data (one image per column).
    void getMnistTestData(Eigen::MatrixXd &_data);

    /// Reads the test labels into @p _data as a one-hot 10 x N matrix.
    void getMnistTestLabel(Eigen::MatrixXd &_data);

private:
    /// Parses an MNIST image file at @p filename into @p _data.
    void read_Mnist_Images(const std::string &filename, Eigen::MatrixXd &_data);
    /// Parses an MNIST label file at @p filename into a one-hot @p _data.
    void read_Mnist_Label(const std::string &filename, Eigen::MatrixXd &_data);
    /// Returns @p i with its four bytes in reverse order.
    int ReverseInt (int i);

private:
    // File paths in fixed order: {train image, train label, test image, test label}.
    std::vector<std::string> imgDataPath;
};

/**
 * Records the four MNIST file paths in the fixed order
 * {train images, train labels, test images, test labels}.
 * The getters look paths up by this position; no file is opened here.
 */
MnistData::MnistData(std::string train_img_filename,
             std::string train_label_filename,
             std::string test_img_filename,
             std::string test_label_filename)
{
    imgDataPath = {train_img_filename, train_label_filename,
                   test_img_filename, test_label_filename};
}

/// Loads the training images (path slot 0) into @p _data.
void MnistData::getMnistTrainData(Eigen::MatrixXd &_data)
{
    const std::string &trainImagePath = imgDataPath.front();
    this->read_Mnist_Images(trainImagePath, _data);
}

/// Loads the training labels (path slot 1) into @p _data.
void MnistData::getMnistTrainLabel(Eigen::MatrixXd &_data)
{
    const std::string &trainLabelPath = imgDataPath[1];
    this->read_Mnist_Label(trainLabelPath, _data);
}

/// Loads the test images (path slot 2) into @p _data.
void MnistData::getMnistTestData(Eigen::MatrixXd &_data)
{
    const std::string &testImagePath = imgDataPath[2];
    this->read_Mnist_Images(testImagePath, _data);
}

/// Loads the test labels (path slot 3) into @p _data.
void MnistData::getMnistTestLabel(Eigen::MatrixXd &_data)
{
    const std::string &testLabelPath = imgDataPath[3];
    this->read_Mnist_Label(testLabelPath, _data);
}

/**
 * @brief Reads an MNIST image file (IDX format) into @p _data.
 *
 * Expected layout: four big-endian 32-bit integers -- magic number (2051),
 * image count, row count, column count -- followed by rows * cols unsigned
 * bytes per image, stored row by row.
 *
 * On success @p _data is resized to (rows * cols) x count. Column i holds
 * image i, with pixel (r, c) at row r * cols + c, divided by 255.0.
 * The header fields are echoed to std::cout.
 *
 * Errors are reported on std::cerr:
 *  - the file cannot be opened, or its header is truncated or invalid:
 *    @p _data is left unchanged;
 *  - the file ends before all pixels are read: every image that could not
 *    be read completely is left as all zeros.
 *
 * @param filename Path of the image file.
 * @param _data    Output matrix, one image per column.
 */
void MnistData::read_Mnist_Images(const std::string &filename, Eigen::MatrixXd &_data)
{
    // 0x00000803: IDX magic for unsigned-byte data with 3 dimensions.
    const int kImageMagic = 2051;

    std::ifstream file(filename.c_str(), std::ios::in | std::ios::binary);
    if (!file.is_open())
    {
        std::cerr << "cannot open image file: " << filename << std::endl;
        return;
    }

    // IDX header integers are big-endian. Assembling them byte by byte gives
    // the right value on any host, unlike reading a raw int and byte-swapping.
    auto read_be32 = [&file](int &value) -> bool
    {
        unsigned char b[4] = {0, 0, 0, 0};
        if (!file.read(reinterpret_cast<char*>(b), sizeof(b)))
            return false;
        value = static_cast<int>((static_cast<std::uint32_t>(b[0]) << 24) |
                                 (static_cast<std::uint32_t>(b[1]) << 16) |
                                 (static_cast<std::uint32_t>(b[2]) << 8) |
                                 static_cast<std::uint32_t>(b[3]));
        return true;
    };

    int magic_number = 0;
    int number_of_images = 0;
    int n_rows = 0;
    int n_cols = 0;
    if (!read_be32(magic_number) || !read_be32(number_of_images) ||
        !read_be32(n_rows) || !read_be32(n_cols))
    {
        std::cerr << "truncated header in image file: " << filename << std::endl;
        return;
    }

    std::cout << "magic number = " << magic_number << std::endl;
    std::cout << "number of images = " << number_of_images << std::endl;
    std::cout << "rows = " << n_rows << std::endl;
    std::cout << "cols = " << n_cols << std::endl;

    if (magic_number != kImageMagic || number_of_images < 0 || n_rows <= 0 || n_cols <= 0)
    {
        std::cerr << "invalid header in image file: " << filename << std::endl;
        return;
    }

    // Computed in 64 bits so that a corrupt header cannot overflow int.
    const long long pixels = static_cast<long long>(n_rows) * n_cols;
    _data.resize(pixels, number_of_images);

    // One image's pixels are contiguous and row-major in the file, so buffer
    // index p == r * n_cols + c is already the destination row in _data.
    std::vector<unsigned char> buffer(static_cast<std::size_t>(pixels));
    for (int i = 0; i < number_of_images; ++i)
    {
        // One read per image instead of one read per pixel.
        if (!file.read(reinterpret_cast<char*>(buffer.data()),
                       static_cast<std::streamsize>(buffer.size())))
        {
            std::cerr << "unexpected end of image file " << filename
                      << " at image " << i << std::endl;
            // Don't leave uninitialized values in the unread columns.
            _data.rightCols(number_of_images - i).setZero();
            return;
        }
        for (long long p = 0; p < pixels; ++p)
            _data(p, i) = buffer[static_cast<std::size_t>(p)] / 255.0;
    }
}

/**
 * @brief Reads an MNIST label file (IDX format) into a one-hot matrix.
 *
 * Expected layout: two big-endian 32-bit integers -- magic number (2049)
 * and label count -- followed by one unsigned byte per label.
 *
 * On success @p _data is resized to 10 x count and zero-filled. For each
 * sample i with a label byte in 0..9, entry (label, i) is set to 1.
 * The header fields are echoed to std::cout.
 *
 * Errors are reported on std::cerr:
 *  - the file cannot be opened, or its header is truncated or invalid:
 *    @p _data is left unchanged;
 *  - a label byte is >= 10, or the file ends early: the affected columns
 *    stay all zeros instead of being written out of bounds or defaulting
 *    to class 0.
 *
 * @param filename Path of the label file.
 * @param _data    Output one-hot matrix, one sample per column.
 */
void MnistData::read_Mnist_Label(const std::string &filename, Eigen::MatrixXd &_data)
{
    // 0x00000801: IDX magic for unsigned-byte data with 1 dimension.
    const int kLabelMagic = 2049;
    const int kNumClasses = 10;

    std::ifstream file(filename.c_str(), std::ios::in | std::ios::binary);
    if (!file.is_open())
    {
        std::cerr << "cannot open label file: " << filename << std::endl;
        return;
    }

    // IDX header integers are big-endian; decode independently of host order.
    auto read_be32 = [&file](int &value) -> bool
    {
        unsigned char b[4] = {0, 0, 0, 0};
        if (!file.read(reinterpret_cast<char*>(b), sizeof(b)))
            return false;
        value = static_cast<int>((static_cast<std::uint32_t>(b[0]) << 24) |
                                 (static_cast<std::uint32_t>(b[1]) << 16) |
                                 (static_cast<std::uint32_t>(b[2]) << 8) |
                                 static_cast<std::uint32_t>(b[3]));
        return true;
    };

    int magic_number = 0;
    int number_of_images = 0;
    if (!read_be32(magic_number) || !read_be32(number_of_images))
    {
        std::cerr << "truncated header in label file: " << filename << std::endl;
        return;
    }

    std::cout << "magic number = " << magic_number << std::endl;
    std::cout << "number of images = " << number_of_images << std::endl;

    if (magic_number != kLabelMagic || number_of_images < 0)
    {
        std::cerr << "invalid header in label file: " << filename << std::endl;
        return;
    }

    _data.resize(kNumClasses, number_of_images);
    _data.setZero();

    // Read all label bytes at once; gcount() reports how many actually arrived.
    std::vector<unsigned char> labels(static_cast<std::size_t>(number_of_images));
    std::streamsize available = 0;
    if (!labels.empty())
    {
        file.read(reinterpret_cast<char*>(labels.data()),
                  static_cast<std::streamsize>(labels.size()));
        available = file.gcount();
    }

    long long invalid = 0;
    for (std::streamsize i = 0; i < available; ++i)
    {
        const int label = labels[static_cast<std::size_t>(i)];
        if (label >= kNumClasses)
        {
            // Row `label` does not exist in a kNumClasses-row matrix.
            ++invalid;
            continue;
        }
        _data(label, i) = 1.0;
    }

    if (invalid > 0)
        std::cerr << invalid << " label(s) out of range 0.." << (kNumClasses - 1)
                  << " in " << filename << std::endl;
    if (available < number_of_images)
        std::cerr << "unexpected end of label file " << filename << ": read "
                  << available << " of " << number_of_images << " labels" << std::endl;
}

/**
 * Returns @p i with its four bytes in reverse order
 * (byte 0 <-> byte 3, byte 1 <-> byte 2). Assumes a 32-bit int.
 */
int MnistData::ReverseInt (int i)
{
    const unsigned int bits = static_cast<unsigned int>(i);
    unsigned int swapped = 0;
    // Take bytes from least to most significant, pushing each one in at the
    // low end so the first byte taken ends up most significant.
    for (int k = 0; k < 4; ++k)
        swapped = (swapped << 8) | ((bits >> (8 * k)) & 0xFFu);
    return static_cast<int>(swapped);
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值