pytorch简单自定义Datasets

前言

  • 本文记录一下如何简单自定义pytorch中Datasets,官方教程
  • 文件层级目录如下:
    • images
      • 1.jpg
      • 2.jpg
      • 9.jpg
    • annotations_file.csv

数据说明

  • image文件夹中有需要训练的图片,annotations_file.csv中有2列,分别为image_idlabel,即图片名和其对应标签。
    在这里插入图片描述
image_idlabel
1风景
2风景
3风景
4星空
5星空
6星空
7人物
8人物
9人物

代码展示

导入必要包

import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T

自定义Datasets

  • 自定义Datasets之前,首先我们需要准备两个信息:
    • 图片地址
    • 图片标签
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 读取包含图片id和图片标签的csv文件
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 读取图片标签
        label = self.img_labels.iloc[idx, 1]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

定义图片预处理方法

transform = {'train':T.Compose([
    # 将图片缩放至固定尺寸
    T.Resize((224,224)),
    # 应用CIFAR10自动增强策略
    T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
    # 像素值归一化,并转换为ternsor格式
    T.ToTensor(),])}
  • 接下来我们将Datasets实例化
annotations_file = '/kaggle/input/datasets-test/annotations_file.csv'
img_dir = '/kaggle/input/datasets-test/images'
train_data = CustomImageDataset(annotations_file = annotations_file ,img_dir = img_dir, transform = transform['train'])
  • 我们可以使用len(train_data)检查样本完整性,以及Datasets定义正确性,这里输出9,的确只有9张图片,正确无误。

使用DataLoaders加载数据

  • 因为这里数据较少,所以设置batch_size = 2,打乱数据shuffle = true,不丢弃数据drop_last=False,有关DataLoader的更多操作可以参照官方API
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
  • 使用iter()函数和next()函数,取1个batch,检查数据
train_features, train_labels = next(iter(train_dataloader))
  • 取第1个batch中的第1个图片,并将其可视化。
  • 由于train_features[0]的维度为(1,3,224,224),所以使用squeeze()函数从数组中删除单维度条目,即把为1的维度去掉。再使用permute()函数将维度变换(224,224,3),便于plt绘图。
# 维度整理
img = train_features[0].squeeze()
# 转置操作
img = img.permute(1,2,0)
# 绘图
plt.imshow(np.asarray(img))
# 关闭刻度线
plt.axis('off')
plt.show()

请添加图片描述

  • 打印图片标签print(train_labels[0]),输出'人物'

在Datasets中将字符串标签数值化

  • 我们发现上面打印出的标签为字符串,如果我们想要将其数值化,只需要在Datasets__getitem__部分改动一点
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 读取包含图片id和图片标签的csv文件
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 将字符串进行数值编码
        data_category, data_class = pd.factorize(self.img_labels.iloc[:, 1])
        label = data_category[idx]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • 这样就可以了,可以看到其实array格式的数据也是可以读取的,只要保证idx一致,且对应就可以。

划分训练集与验证集

  • 根据前面的说明,其实训练集与验证集的划分就变的很简单了,只需要4个列表/数组,train_pathtrain_labelvaild_pathvaild_label分别表示训练集图片路径、标签、验证集图片路径、标签。DataSets可以这样写:
class CustomImageDataset(Dataset):
    def __init__(self, image_id, image_label, img_dir, transform=None, target_transform=None):
        # 读取包含图片id
        self.image_id = image_id
        # 读取图片标签
        self.image_label = image_label
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.image_id[idx]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 读取图片标签
        label = self.image.label[idx]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • 实例化和加载器可以写作:
transform = {'train':T.Compose([T.Resize((224,224)),T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),T.ToTensor(),]),
             'valid':T.Compose([T.Resize((224,224)),T.ToTensor(),])}
             
# 实例化训练数据
train_data = CustomImageDataset(train_path, train_label, img_dir = './',transform = transform['train'])
# 实例化训练数据
valid_data = CustomImageDataset(valid_path, valid_label, img_dir = './',transform = transform['valid'])

# 训练集数据加载器
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
# 验证集数据加载器
valid_dataloader = DataLoader(valid_data, batch_size=2, shuffle=True, drop_last=False)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

羽星_s

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值