DataSet类(1)

DataSet类(1)

一、DataSet类

**原文链接:**https://zhuanlan.zhihu.com/p/500839903

DataSet类,可以帮助我们提取我们所需要的数据,用子类继承的DataSet类,给每一个数据进行编号(idx),在后面的神经网络中,初始化DataSet子类的实例后,就可以通过编号去实例对象中读取相应的数据,会自动调用__getitem__方法。同时子类对象也会获得相应真实的Label

**DataSet作用:**提供一种方式去获取数据及其对应真实的Label

DataSet类的子类,需要重写的函数:

获取每一个数据、以及其对应的Label
统计数据集中的数据数量    
1.DataSet类官方解读

在pycharm终端中输入:
在这里插入图片描述

得到以下输出:

Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |  
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |  
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __add__(self, other:'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |  
 |  __getattr__(self, attribute_name)
 |  
 |  __getitem__(self, index) -> +T_co
 |  
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |  
 |  register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta
 |  
 |  register_function(function_name, function) from typing.GenericMeta
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __abstractmethods__ = frozenset()
 |  
 |  __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
 |  
 |  __args__ = None
 |  
 |  __extra__ = None
 |  
 |  __next_in_mro__ = <class 'object'>
 |      The most base type
 |  
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |  
 |  __origin__ = None
 |  
 |  __parameters__ = (+T_co,)
 |  
 |  __tree_hash__ = -9223371872509358054
 |  
 |  functions = {'concat': functools.partial(<function Dataset.register_da...
 |  
 |  ----------------------------------------------------------------------
 |  Static methods inherited from typing.Generic:
 |  
 |  __new__(cls, *args, **kwds)
 |      Create and return a new object.  See help(type) for accurate signature.

该类是一个抽象类,所有的数据集想要在数据与标签之间建立映射,都需要继承这个类,所有的子类都需要重写__getitem__方法,该方法根据索引值获取每一个数据并且获取其对应的Label,子类也可以重写__len__方法,返回数据集的size大小

二、蚂蚁蜜蜂分类数据集:

https://download.pytorch.org/tutorial/hymenoptera_data.zip

代码.py

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        #  将路径进行合并
        self.path = os.path.join(self.root_dir,self.label_dir)
        # os.listdir(self.path) 将路径下的文件 装入列表进行返回
        self.img_path = os.listdir(self.path)


    def __getitem__(self, index):  # idx是index的简称,就是一个编号,以便以后数据集获取后,我们使用索引编号访问每个数据
        img_name = self.img_path[index]  # 按照索引取得列表下的文件名字
        # 路径合并,得到对应索引的文件路径
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        # 获得路径下的图片
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label

    def __len__(self):
        # 得到数据集的长度
        return len(self.img_path)


if __name__ == '__main__':
    root_dir = r"/home/zxz/DEEPLEARNING/DEMO/dataset_1/hymenoptera_data/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    # 得到ants数据集
    ants_dataset = MyData(root_dir,ants_label_dir)
    bees_dataset = MyData(root_dir,bees_label_dir)

    # 获得蚂蚁数据集中的数据  同 ants_dataset.__getitem__(0)
    img,label = ants_dataset[0]
    # 显示图片
    img.show()
    print(label)

    img,label = bees_dataset[0]
    img.show()
    print(label)

    # 将两个数据集进行合并
    train_dataset  = ants_dataset + bees_dataset
    # 得到总数据集的数据量 同 train_dataset.__len__()
    print(len(train_dataset))

    img,label = train_dataset[0]
    img.show()

**将数据集进行拼接:**当开源的数据集存在不足时,我们可以自己制作相关数据集与已有的数据集进行拼接,再去训练

<think>好的,用户想了解Python中的Dataset,特别是其使用方法和示例。首先,我需要确认用户指的是哪个框架的Dataset,因为不同的库可能有不同的实现。根据提供的引用信息,用户提到了PyTorch中的Dataset和DataLoader,引用[2]详细讨论了这两个,所以可以确定用户关注的是PyTorch的Dataset。 接下来,我需要整理Dataset的基本用法。根据PyTorch的文档,Dataset是一个抽象,用户需要继承它并实现__len__和__getitem__方法。这两个方法是必须的,分别用于获取数据集的大小和通过索引获取数据样本。 然后,我应该给出一个具体的示例,帮助用户理解如何自定义Dataset。例如,创建一个用于处理图像数据集的,其中需要包含初始化方法、__len__和__getitem__。在示例中,可能需要使用到像os模块来遍历文件,PIL库加载图像,以及处理标签的方法。同时,需要展示如何与DataLoader结合使用,进行批量加载和数据打乱等操作。 此外,用户可能对DataLoader的参数感兴趣,比如batch_size、shuffle等,这些参数的作用需要简要说明。还要提到Dataset的其他型,如TensorDataset,方便用户在不同场景下选择。 需要确保回答结构清晰,分步骤说明,并包含代码示例。同时,根据用户的要求,生成相关问题,可能涉及自定义数据集、DataLoader的参数、内置数据集的使用、处理大数据集的方法以及数据增强技巧。这些问题应该覆盖用户可能进一步探索的方向,帮助他们深入学习。 最后,检查所有LaTeX格式是否正确,行内公式用$...$,独立公式用$$...$$,并确保引用标识正确添加在相关段落末尾。例如,在提到Dataset和DataLoader时引用[^2],在提到数据处理时引用[^3]。 总结回答结构:首先介绍Dataset的作用,然后详细说明如何自定义Dataset,给出代码示例,接着介绍DataLoader的使用,最后列举常见问题及相关扩展,如TensorDataset和数据增强。确保每个部分都有对应的引用标注,并生成相关问题供用户参考。</think>### Python Dataset使用方法和示例 在PyTorch中,`Dataset`用于封装数据,提供统一的数据访问接口,而`DataLoader`则负责批量加载数据以优化训练流程[^2]。以下是具体的使用方法和示例: --- #### 1. **自定义Dataset** 需要继承`Dataset`并实现两个核心方法: - `__len__()`:返回数据集大小 - `__getitem__(index)`:根据索引返回单个样本 **示例:图像数据集** ```python from torch.utils.data import Dataset from PIL import Image import os class CustomImageDataset(Dataset): def __init__(self, img_dir, transform=None): self.img_dir = img_dir self.img_names = os.listdir(img_dir) self.transform = transform # 数据预处理函数 def __len__(self): return len(self.img_names) def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_names[idx]) image = Image.open(img_path).convert('RGB') label = 0 if "cat" in self.img_names[idx] else 1 # 假设文件名包含标签信息 if self.transform: image = self.transform(image) return image, label ``` --- #### 2. **与DataLoader结合使用** ```python from torch.utils.data import DataLoader # 初始化Dataset实例 dataset = CustomImageDataset(img_dir="./data/train", transform=ToTensor()) # 创建DataLoader dataloader = DataLoader( dataset, batch_size=32, shuffle=True, num_workers=4 # 多进程加载 ) # 训练时迭代数据 for batch_images, batch_labels in dataloader: # 输入模型训练 outputs = model(batch_images) ``` --- #### 3. **内置Dataset的简化用法** PyTorch提供预定义数据集(如`MNIST`、`CIFAR10`): ```python from torchvision.datasets import MNIST from torchvision.transforms import ToTensor mnist_dataset = MNIST(root="./data", train=True, download=True, transform=ToTensor()) ``` --- #### 4. **TensorDataset快速封装** 若数据已存储在张量中: ```python from torch.utils.data import TensorDataset import torch features = torch.randn(1000, 3) # 1000个样本,3维特征 labels = torch.randint(0, 2, (1000,)) # 二分标签 tensor_dataset = TensorDataset(features, labels) ``` --- #### 5. **数据增强技巧** 通过`transform`参数实现数据预处理: ```python from torchvision import transforms transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.ToTensor() ]) ``` --- ### 常见问题解答 1. **如何自定义标签解析?** 在`__init__`中解析标签文件(如CSV),存储为列表或字典,在`__getitem__`中按索引返回。 2. **DataLoader的`num_workers`如何选择?** 通常设置为CPU核心数,但需避免内存溢出。 3. **如何处理超大无法内存加载的数据?** 在`__getitem__`中按需读取文件,避免初始化时全加载。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值