libtorch学习历程(四):数据加载模块

本章将详细介绍如何使用libtorch自带的数据加载模块。
自定义数据集

使用自定义数据集

简介

要自定义数据加载模块,需要继承torch::data::Dataset这个基类实现派生类。
与pytorch中需要实现初始化函数init获取函数getitem以及数据集大小函数len类似的是,在libtorch中同样需要处理好初始化函数get()函数size()函数

例程的代码结构

例程中使用了一个图像分类任务来进行介绍,使用pytorch官网提供的昆虫分类数据集

遍历图像文件

例程中使用了io.h来遍历文件夹。
首先实现遍历文件夹的函数:
接受数据集文件夹路径image_dir图片类型image_type,将遍历到的图片路径和其类别分别存储到list_images和list_labels,最后lable变量用于表示类别计数。

通过该函数,会得到所有图像的绝对地址,通过这些地址就可以获得图像。

#include <io.h>
void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);

void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label)
{
    /*
     * path:文件夹地址
     * type:图片类型
     * list_images:所有图片的名称
     * list_label:各个图片的标签,也就是所属的类
     * label:类别的个数
    */
    long long hFile = 0; //句柄
    struct _finddata_t fileInfo;// 记录读取到文件的信息
    std::string pathName;

    // 调用_findfirst函数,其第一个参数为遍历的文件夹路径,*代表任意文件。注意路径最后,需要添加通配符
    // 如果失败,返回-1,否则,就会返回文件句柄,并且将找到的第一个文件信息放在_finddata_t结构体变量中
    if ((hFile = _findfirst(pathName.assign(path).append("\\*.*").c_str(), &fileInfo)) == -1)
    {
        return;
    }
    // 通过do{}while循环,遍历所有文件
    do
    {
        const char* filename = fileInfo.name;// 获得文件名
        const char* t = type.data();

        if (fileInfo.attrib&_A_SUBDIR) //是子文件夹
        {
            //遍历子文件夹中的文件(夹)
            if (strcmp(filename, ".") == 0 || strcmp(filename, "..") == 0) //子文件夹目录是.或者..
                continue;
            std::string sub_path = path + "\\" + fileInfo.name;// 增加多一级
            label++;
            load_data_from_folder(sub_path, type, list_images, list_labels, label);// 读取子文件夹的文件

        }
        else //判断是不是后缀为type文件
        {
            if (strstr(filename, t))
            {
                std::string image_path = path + "\\" + fileInfo.name;// 构造图像的地址
                list_images.push_back(image_path);
                list_labels.push_back(label);
            }
        }
      //其第一个参数就是_findfirst函数的返回值,第二个参数同样是文件信息结构体
    } while (_findnext(hFile, &fileInfo) == 0);
    return;
}

自定义DataSet

需要继承torch::data::Dataset,定义私有变量image_paths和labels分别存储图片路径和类别,是两个vector变量。
在构造函数中,调用图像遍历函数来获得所有图像的地址与类别;并且需要重写get()与size()

在get()中根据传入的index来获得指定的图像,而且可以在get()函数中对图像进行一些处理,例如调整大小或数据增强等。然后使用torch::from_blob将图像数据与label都转换为张量。
其中图像还需要使用permute(),将张量转换为Channels x Height x Width的结构、

class myDataset:public torch::data::Dataset<myDataset>{
public:
    int num_classes = 0;
    myDataset(std::string image_dir, std::string type){
        // 调用遍历文件的函数
        load_data_from_folder(image_dir, std::string(type), image_paths, labels, num_classes);
    }
    // 重写 get(),根据传入的index来获得指定的数据
    torch::data::Example<> get(size_t index) override{
        std::string image_path = image_paths.at(index);// 根据index得到指定的图像
        cv::Mat image = cv::imread(image_path);// 读取图像
        cv::resize(image, image, cv::Size(224, 224));// 调整大小,使得尺寸统一,用于张量stack
        int label = labels.at(index);//
        // 将图像数据转换为张量image_tensor,尺寸{image.rows, image.cols, 3},元素的数据类型为byte
        // Channels x Height x Width
        torch::Tensor img_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 });
        //
        torch::Tensor label_tensor = torch::full({ 1 }, label);
        return {img_tensor.clone(), label_tensor.clone()};// 返回图像及其标签
    }
    // Return the length of data
    torch::optional<size_t> size() const override {
        return image_paths.size();
    };
private:
    std::vector<std::string> image_paths;// 所有图像的地址
    std::vector<int> labels;// 所有图像的类别
};

使用自定义数据集

首先创建一个自定义数据集对象,然后它进行一些transform处理

auto mydataset = myDataset(image_dir,".jpg").map(torch::data::transforms::Stack<>());

然后需要使用torch::data::make_data_loader来传入批数据(Batch),对应于pytorch中的torch.utils.data.DataLoader
在这里插入图片描述

这里面的 SequentialSampler 类负责按照我们提供的数据顺序来生成样本。
需要传入数据集对象与批次尺寸

auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
						std::move(mydataset), batch_size);

然后可以通过循环来遍历每个批次中的data(image)与target(标签),也就是自定义数据集中的get() 所返回两个数据:image与label
这里每次取得的数据大小取决于之前 torch::data::make_data_loader() 函数中传入的 batch_size 大小

for(auto &batch: *mdataloader){
   auto data = batch.data;
   auto target = batch.target;
   std::cout<<data.sizes()<<target<<std::endl;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值