python DataSet+ Dataloader 深度学习编程细节_数据集 pytorchDataset的构建与使用

本文介绍了深度学习中数据集的处理,包括如何构建自定义的`DataSet`类,重写`__init__`, `__getitem__`, `__len__`方法,并使用`DataLoader`进行批量加载。示例中展示了简单的数据集构造和一个MLP模型的训练流程。同时提到了图像分割和自监督任务的数据集构建方法,以及如何处理可能出现的异常情况,如`IndexError`。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  • 深度学习中许多网络的设计都需数据集的预处理功能辅助,本文对DataSet + Dataloader 的使用做介绍。

DataSet构建(简单示例)

        构建数据集需要继承torch.utils.data.dataset的Dataset类重写init,getitem(self, mask),len三个方法。然后使用torch.utils.data import DataLoader来加载你创建的数据集Dataset。

import argparse
import os
import random
import shutil
import time
import warnings
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np
import os, imageio


from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):
    def __init__(self, data, label):#传入参数是我们的数据集(data)和标签集(label)
        self.data = data
        self.label = label
        self.length = data.shape[0]

    def __getitem__(self, mask):# 获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。
        label = self.label[mask]
        data = self.data[mask]
        return label, data

    def __len__(self):
        # print(self.length)
        return self.length



train_set = MyDataSet(xb,yb)# xb,yb为所有的数据
# train_set = MyDataSet(data=X_train, label=Y_train)
num_epoch = 100     # number of epochs to train on
batch_size = 1024  # training batch size
train_data = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

class MLP(nn.Module):
    def __init__(self,depth=4,mapping_size=2,hidden_size=256):
        super().__init__()
        layers = []
        layers.append(nn.Linear(mapping_size,hidden_size))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Linear(hidden_size,hidden_size))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_size,3))
        self.layers = nn.Sequential(*layers)
    def forward(self,x):
        return torch.sigmoid(self.layers(x))
model = MLP()
for epoch in range(num_epoch ):
    model.train()
    for batchsz, (label, data) in enumerate(train_data):
        # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
        print("第 {} 个Batch size of label {} and size of data{}".format(batchsz, label.shape, data.shape))

图像的分割处理数据集的构建

添加链接描述
添加链接描述

构建自监督任务的数据集(用一个数据集构建正负样本)

from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):
    def __init__(self, data, label):#传入参数是我们的数据集(data)和标签集(label)
        self.data = data
        self.label = label
        self.length = data.shape[0]

    def __getitem__(self, mask):# 获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。
        label = self.label[mask]
        data = self.data[mask]
        return label, data

    def __len__(self):
        # print(self.length)
        return self.length

C&G

后续(+捕获异常)

image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)
IndexError: list index out of range

先加个捕获异常:

    def __getitem__(self, index):
        video_name = self.samples[index].split('/')[-2]
        frame_name = int(self.samples[index].split('/')[-1].split('.')[-2])

        batch = []
        for i in range(self._time_step+self._num_pred):
            try:
                image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)
            except :
                print('error from --- model utils')
                print(frame_name)
                print(i)
            if self.transform is not None:
                batch.append(self.transform(image))

        return np.concatenate(batch, axis=0)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值