PyTorch学习笔记(一)——Dataset

文章介绍了数据集的三种组织形式,特别是训练集和验证集的划分。重点讲解了PyTorch中Dataset类的作用,包括获取数据和标签,以及统计数据集大小。同时提到了Dataloader的作用,即批量处理数据。文中还展示了如何使用PIL和OpenCV读取图像,以及如何构建自定义的Dataset类来处理图像数据。最后提到了数据集的另一种组织形式,即图片和标签分开存储在txt文件中。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 数据集的组织形式 

train训练集    val验证集

1.1第一种组织形式

文件夹名就是label,里面放的就是相应label的图片

1.2第二种组织形式

 图片和label分开文件夹存放,相应文件名要一致,txt文件里存放相应图片的label

1.3第三种组织形式

label直接为图片的名称

2 pytorch读取数据涉及两个类:Dataset & Dataloader

Dataset:提供一种方式,获取需要的数据和对应的label 值,并完成编号。主要实现两个功能:

  • 获取每一个数据及其对应label
  • 统计数据集中的数据数量(神经网络经常需要对一个数据迭代多次,只有知道当前有多少数据,进行训练时才知道要训练多少次,才能把整个数据集迭代完)

Dataloader:打包(batch_size),为后面的神经网络提供不同的数据形式

2.1 Dataset类

数据的第一种组织形式

1、查看官方文档:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

该类是一个抽象类,所有的数据集想要在数据与标签之间建立映射都需要继承这个类。所有的子类都需要重写__getitem__方法,该方法根据索引值获取每一个数据并且获取其对应的Label。子类也可以重写__len__方法,返回数据集的size大小。

2、导入Dataset类

from torch.utils.data import Dataset

3、创建一个类(需要继承Dataset类),类名为MyData

class MyData(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index): 
        pass

    def __len__(self):
        pass

在实现MyData类之前,首先需要读取图像数据

4、读取图像数据

  • 方法一:使用PIL
  • 方法二:使用OpenCV

使用PIL:

from PIL import Image
img_path = "dataset/train/ants/0013035.jpg"
img = Image.open(img_path)
img.show()

可以在Python控制台看到读取出来的 img,是一个JpegImageFile类的对象,展开还可以查看这个对象的属性,比如size。

 

使用OpenCV:

import cv2
img_path = "dataset/train/ants/0013035.jpg"
cv_img = cv2.imread(img_path)
cv2.imshow('cv_img',cv_img)
cv2.waitKey(0) ##等待时间,毫秒级,0表示任意键终止
cv2.destroyAllWindows()

 可以在Python控制台看到读取出来的cv_img,是一个ndarray类的对象。

 5、获取图片的文件名

借助os模块,从数据集路径中,获取所有文件的文件名,存储在一个列表中。

import os
dir_path = "dataset/train/ants"
img_path_list = os.listdir(dir_path)

 可以使用索引去访问列表中的每个文件名

 6、继续完善MyData

class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir     #根路径 如“dataset/train”
        self.label_dir = label_dir   #子路径 如“ants”
        self.path = os.path.join(self.root_dir,self.label_dir)  #将路径拼接起来
        self.img_path = os.listdir(self.path)  #该路径下的所有图片名变成一个列表

    def __getitem__(self, index):
        img_name = self.img_path[index] #根据索引从列表中获取了图片名
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)  #组装路径,获得了图片的具体路径
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label

    def __len__(self):
        return len(self.img_path)
  • __init__():接收实例化时传入的参数(根路径子路径),然后将两个路径进行组合,就得到了目标数据集的路径。将这个路径作为参数传入listdir函数,从而让img_path中存储该目录下所有文件名(通过索引就可以轻松获取每个文件名)。
  • __getitem__():index是索引编号,用于对初始化中得到的文件名列表进行索引访问(这样就得到了具体的文件名),然后与根路径和子路径再次组合,得到具体数据(图片)的相对路径。然后用PIL读取具体的图像。返回的img是JpegImageFile对象,label是字符串。
  • __len__():获取数据集的长度,在初始化中已经获取了所有文件名的列表,只需知道这个列表的长度。

7、使用MyData

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)

8、完整代码

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir) 
        self.img_path = os.listdir(self.path)

    def __getitem__(self, index):
        img_name = self.img_path[index]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name) 
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label

    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)

8、补充

  • 获取数据集中指定的数据
ants_dataset[0]

img,label = ants_dataset[0]
img.show()

  • 拼接数据集

用处:有时候数据集不足就需要仿造数据集,需要将仿造数据集和真实数据集结合后训练

train_dataset = ants_dataset + bees_dataset

 变成数据集的第二种组织形式:当label比较复杂,存储数据比较多时,是以每张图片对应一个txt文件,txt里存储label信息的方式。

为数据集中的每张图片生成对应的txt文件:

import os

 root_dir = 'dataset-练习/train'
 target_dir = 'ants_image'
 img_path = os.listdir(os.path.join(root_dir, target_dir))
 label = target_dir.split('_')[0]
 out_dir = 'ants_label'
 for i in img_path:
     file_name = i.split('.jpg')[0]
     with open(os.path.join(root_dir, out_dir,"{}.txt".format(file_name)),'w') as f:
         f.write(label)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值