一文搞懂Pytorch数据读取机制!

PyTorch中的数据读取通过Dataset和Dataloader完成,Dataset定义数据集,包含数据和标签,支持预处理操作;Dataloader负责分批、打乱顺序读取数据,支持多进程加载和调整batch_size等参数。理解这两个组件对于高效训练深度学习模型至关重要。

在这里插入图片描述

熟悉深度学习的小伙伴一定都知道:深度学习模型训练主要由数据、模型、损失函数、优化器以及迭代训练五个模块组成。如下图所示,Pytorch数据读取机制则是数据模块中的主要分支。

Pytorch数据读取是通过Dataset+Dataloader的方式完成。其中,

  • DataSet:定义数据集。将原始数据样本及对应标签映射到Dataset,便于后续通过index读取数据。同时,还可以在Dataset中进行数据格式变换、数据增强等预处理操作。

  • DataLoader:迭代读取数据集。将数据样本进行分批次Batch、打乱顺序Shuffle等处理,便于训练时迭代读取数据。

Dataset

Dataset用于解决数据从哪里读取以及如何读取的问题。 Pytorch给定的Dataset是一个抽象类,所有自定义的数据集都要继承Dataset,并重写__init__()、__getitem__()和__len__() 类方法,以供DataLoader类直接调用。

  • __init__:数据集初始化。

  • __getitem__:定义指定索引如何获取样本数据,最终返回index对应的样本对{样本数据x:标签y}。

  • __len__:数据集的样本数。

下面是以cifar10数据集为例实现Dataset自定义数据集的代码样例。

from torch.utils.data import Dataset  
from PIL import Image  
import os  
  
class Mydata(Dataset):  
    """  
    步骤一:继承 torch.utils.data.Dataset 类  
    """  
    def __init__(self,data_dir,label_dir):  
        """  
        步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中  
        """  
        self.data_dir = data_dir  
        self.label_dir = label_dir  
        # 用join把路径拼接一起可以避免一些因“/”引发的错误  
        self.path = os.path.join(self.data_dir,self.label_dir)  
        # 将该路径下的所有文件变成一个列表  
        self.img_path = os.listdir(self.path)  
  
    def __getitem__(self,idx):  
        """  
        步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签)  
        """  
        # 根据index(idx),从列表中取出图片  
        # img_path列表里每个元素就是对应图片文件名  
        img_name = self.img_path[idx]  
        # 获得对应图片路径  
        img_item_path = os.path.join(self.data_dir,self.label_dir,img_name)  
        # 使用PIL库下Image工具,打开对应路径图片  
        img = Image.open(img_item_path)  
        label = self.label_dir  
        # 返回图片和对应标签  
        return img,label  
  
    def __len__(self):  
        """  
        步骤四:实现 __len__ 函数,返回数据集的样本总数  
        """  
        return len(self.img_path)  
  
# data_dir,label_dir可自定义数据集目录  
train_custom_dataset = MyData(data_dir,label_dir)  
test_custom_dataset = MyData(data_dir,label_dir)  
  



DataLoader

在实际项目中,当数据量很大,考虑到内存有限、I/O速度等问题,训练中不可能一次性将所有数据加载到内存或者只用一个进行加载数据,此时就需要的是多进程、迭代加载,Dataloader便应运而生。

DataLoader是一个可迭代的数据装载器,组合了数据集和采样器,并在给定数据集上提供可迭代对象。可以完成对数据集中多个对象的集成。

Pytorch的数据读取机制中DataLoader模块包括Sampler和Dataset两个子模块,其中Sampler模块生成索引index;Dataset模块是根据索引读取数据。DataLoader读取数据流程如下图所示。

  • DataLoader:进入DataLoader模块。

  • DataloaderIter:进入__iter__函数判断是否采用多进程,并进入相应的读取机制。

  • Sampler:通过采样,挑选每个Batchsize该读取的数据,并返回这些数据的index。

  • index:一个batchsize数据的索引。

  • DatasetFetcher:获取index对应的数据。

  • Dataset:调用dataset[idx]获取相应数据,并拼接成list。

  • getitem:Dataset的核心,用索引获取数据。

  • img,label:读取到的数据。

  • collate_fn:将读取的数据从list转为batch形式。

  • Batch Data:batch形式数据,第一个元素是图像,第二个元素是标签。

Pytorch中DataLoader类定义如下:

class torch.utils.data.DataLoader(  
     """  
     构建可迭代的数据装载器,训练时,每一个for循环,每一次迭代,  
     从DataLoader中获取一个batch_size大小的数据  
     """  
     dataset,  
     batch_size=1,  
     shuffle=False,  
     sampler=None,  
     batch_sampler=None,  
     num_workers=0,  
     collate_fn=None,  
     pin_memory=False,  
     drop_last=False,  
)  



  • dataset:需要加载的数据集,Dataset对象。

  • batch_size:每批次读取样本数。例如batch_size=16表示每批次读取16个样本。

  • shuffle:每个epoch是否乱序。shuffle=True表示在取数据时打乱样本顺序,以减少过拟合发生的可能。

  • sampler:索引index。

  • batch_sampler:将返回一个索引的sampler进行包装,按照设定的batch_size返回一组索引。

  • num_workers:同步/异步读取数据。num_workers=0表示数据加载是同步的,在主进程中完成。num_workers的值设为大于0时,即开启多进程方式异步加载数据,可提升数据读取速度。

  • pin_memory:是否将数据拷贝到拷贝到临时缓冲区。

  • collate_fn:将多个样本组合在一起变成一个mini-batch,不指定该函数的话会调用Pytorch内部默认的函数。

  • drop_last:丢弃不完整的批次样本,drop_last=True表示当数据集样本数不能被batch_size整除时,则丢弃最后一个不完整的batch样本。

补充说明

Epoch:所有训练样本都已输入到模型中,称为一个epoch

Iteration:一批样本(batch_size)输入到模型中,称为一个Iteration。

Batchsize:一批样本的大小,称为Batchsize。用于决定一个epoch有多少个Iteration。

代码实现示例如下。

import torch  
import torch.utils.data as Data  
  
BATCH_SIZE = 5  
  
x = torch.linspace(1, 10, 10)  
y = torch.linspace(10, 1, 10)  
  
# 将数据集转换为torch可识别的类型  
torch_dataset = Data.TensorDataset(x, y)  
  
loader = Data.DataLoader(  
    dataset=torch_dataset,  
    batch_size=BATCH_SIZE,  
    shuffle=True,  
    num_workers=0  
)  
  
for epoch in range(3):  
    for step, (batch_x, batch_y) in enumerate(loader):  
        print('epoch', epoch,  
              '| step:', step,  
              '| batch_x', batch_x.numpy(),  
              '| batch_y:', batch_y.numpy())  



通过上述方法即可初始化一个数据读取器loader,用于加载训练数据集torch_dataset。

### Seq2Seq模型的基本概念 Seq2Seq(Sequence to Sequence)模型是一种深度学习架构,主要用于处理输入和输出均为序列的任务。其核心思想是通过端到端的方式直接学习输入-输出序列之间的复杂映射关系,而不需要依赖于繁琐的特征工程和规则设计。该模型最初广泛应用于机器翻译领域,但其适用范围远不止于此[^1]。 ### Seq2Seq模型的工作原理 Seq2Seq模型通常由两个主要部分组成:**编码器(Encoder)** 和 **解码器(Decoder)**。在一些改进版本中,还会加入**注意力机制(Attention Mechanism)** 来进一步提升模型性能。 #### 编码器(Encoder) 编码器的作用是将输入序列(例如一个句子)转换为一个包含语义信息的**上下文向量(Context Vector)**,通常也称为**隐藏状态(Hidden State)**。这个过程是通过递归神经网络(RNN)、长短时记忆网络(LSTM)或门控循环单元(GRU)等时序模型完成的。编码器逐步读取输入序列中的每个元素,并更新其内部状态,最终输出一个包含整个序列信息的固定长度向量[^3]。 #### 解码器(Decoder) 解码器根据编码器生成的上下文向量,逐步生成目标序列。它同样使用RNN、LSTM或GRU等结构,但其任务是从上下文向量出发,逐个生成目标序列中的元素。解码器的第一个输入通常是特殊的起始符号(如 `<sos>`),然后根据前一个时刻的输出和隐藏状态生成下一个元素,直到遇到结束符号(如 `<eos>`)为止[^3]。 #### 注意力机制(Attention Mechanism) 在标准的Seq2Seq模型中,解码器仅依赖于编码器的最后一个隐藏状态。然而,这种做法在处理长序列时容易丢失信息。注意力机制的引入允许解码器在生成每个目标词时,关注输入序列中与之最相关的部分,而不是仅仅依赖于一个固定的上下文向量。这种方式显著提升了模型对长序列的处理能力和生成质量[^1]。 ### Seq2Seq模型的训练方式 在训练过程中,Seq2Seq模型通常采用**教师强制(Teacher Forcing)**策略,即在解码阶段,使用真实的前一个目标词作为当前时刻的输入,而不是使用模型自己预测的结果。这有助于加快训练过程并提高收敛速度。但在推理阶段,则需要使用自回归的方式生成目标序列。 为了提升生成质量,推理阶段常采用**束搜索(Beam Search)**策略,即在每一步保留多个可能的候选序列,最终选择概率最高的完整序列作为输出[^2]。 ### 应用场景 Seq2Seq模型因其强大的序列建模能力,被广泛应用于多个自然语言处理(NLP)任务中,包括但不限于: - **机器翻译**:将一种语言的句子翻译成另一种语言。 - **文本摘要**:从长文本中生成简短的摘要。 - **对话系统**:构建基于文本的对话机器人。 - **语音识别**:将语音信号转换为文本。 - **图像描述生成**:根据图像内容生成自然语言描述。 ### 示例代码(基于PyTorch) 以下是一个简化的Seq2Seq模型实现,使用LSTM作为编码器和解码器: ```python import torch import torch.nn as nn class Seq2Seq(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Seq2Seq, self).__init__() self.encoder = nn.LSTM(input_size, hidden_size, batch_first=True) self.decoder = nn.LSTM(hidden_size, output_size, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): encoder_out, (h, c) = self.encoder(x) decoder_out, _ = self.decoder(encoder_out) output = self.fc(decoder_out) return output ``` ###
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值