Pytorch:Dataset&Dataloader

本文详细介绍了PyTorch中的训练epoch、batch size、iteration的概念,并通过实例演示了如何使用DataLoader处理自定义Dataset。重点讲解了如何构建DataLoader,以及其在模型训练中的关键作用。

1. epoch

epochbatchsizeIteration
所有训练样本的一次训练一次训练中训练的样本数量迭代:1个iteration等于使用batchsize个样本训练一次
举个例子,训练集有1000个样本,batchsize=10,那么:
训练完整个样本集需要:
100次iteration,1次epoch。

2.dataset

 Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,
 我们再使用DataLoader这个类来更加快捷的对数据进行操作。
 Dataset是一个抽象类,自定义的dataset必须继承。

3. dataloader

1. 将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
2. DataLoader常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), 
num_workers(加载数据的时候使用几个子进程)

在这里插入图片描述
在这里插入图片描述


import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
#Dataset is an abstract class. We can define our class inherited from
#DataLoader is a class to help us loading data in PyTorch this class.
#DiabetesDataset is inherited from abstract class Dataset.
class DiabetesDataset(Dataset):
    def __init__(self,filepath):
        datas=np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len=datas.shape[0]
        self.x_data=torch.from_numpy(datas[:,:-1])
        self.y_data=torch.from_numpy(datas[:,[-1]])

    def __getitem__(self,index):#dataset[index]
        return self.x_data[index],self.y_data[index]

    def __len__(self):#returns length of dataset.
        return self.len

dataset=DiabetesDataset('diabetes.csv')# Construct DiabetesDataset object
train_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=2)
#Initialize loader with batch-size,shuffle, process number.

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1=torch.nn.Linear(8,6)
        self.linear2=torch.nn.Linear(6,4)
        self.linear3=torch.nn.Linear(4,1)
        self.sigmoid=torch.nn.Sigmoid()

    def forward(self, x):
        x=self.sigmoid(self.linear1(x))
        x=self.sigmoid(self.linear2(x))
        x=self.sigmoid(self.linear3(x))
        return x
model=Model()

criterion=torch.nn.BCELoss(reduction='mean')
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)

if __name__ == '__main__':
    for epoch in range(100):
        for i,data in enumerate(train_loader,0):
            #1.Prepare data
            inputs,labels=data
            #2.Forward
            y_predict=model(inputs)
            loss=criterion(y_predict,labels)
            print(epoch,loss.item())
            #3.Backward
            optimizer.zero_grad()
            loss.backward()
            #4.Update
            optimizer.step()

<think>我们正在回答用户关于PyTorchDataset和DataLoader使用方法的问题。根据之前的对话历史,用户已经了解了SRCNN模型的实现和训练流程,现在专注于数据加载部分。 我们将结合用户提供的引用内容,详细说明如何使用Dataset和DataLoader导入数据。 步骤: 1. 创建一个自定义Dataset类,继承自torch.utils.data.Dataset。 2. 实现Dataset的三个核心方法:__init__、__len__和__getitem__。 3. 使用DataLoader封装Dataset,实现批量加载、打乱数据、多进程读取等功能。 下面我们详细展开: </think>### PyTorchDataset和DataLoader导入数据的使用方法 #### 一、核心概念解析 1. **Dataset类** 是表示数据集的抽象类,需要实现三个核心方法: - `__init__`: 初始化数据路径和预处理方法 - `__len__`: 返回数据集样本数量 - `__getitem__`: 根据索引返回单个样本和标签 ```python class CustomDataset(torch.utils.data.Dataset): def __init__(self, data_dir, transform=None): self.file_list = [os.path.join(data_dir, f) for f in os.listdir(data_dir)] self.transform = transform def __len__(self): return len(self.file_list) def __getitem__(self, idx): img = Image.open(self.file_list[idx]) if self.transform: img = self.transform(img) return img ``` 2. **DataLoader类** 包装Dataset实现高效数据加载: - 自动批处理 - 多进程数据加载 - 内存优化 - 数据打乱 #### 二、完整使用流程 ```python import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image import os # 1. 定义自定义数据集 class ImageDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(('.png', '.jpg'))] def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img_path = self.image_paths[idx] image = Image.open(img_path).convert('RGB') # 转换为RGB格式 if self.transform: image = self.transform(image) return image # 2. 定义数据变换 transform = transforms.Compose([ transforms.Resize(256), # 调整大小 transforms.CenterCrop(224), # 中心裁剪 transforms.ToTensor(), # 转为张量 transforms.Normalize( # 标准化 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # 3. 创建数据集实例 train_dataset = ImageDataset( root_dir='./data/train', transform=transform ) # 4. 创建数据加载器 train_loader = DataLoader( dataset=train_dataset, batch_size=32, # 批次大小 shuffle=True, # 打乱数据 num_workers=4, # 多进程加载 pin_memory=True # 加速GPU传输 ) # 5. 使用数据加载器 for batch in train_loader: # batch是形状为[32, 3, 224, 224]的张量 # 在此处输入模型进行训练 ... ``` #### 三、关键参数详解 | 参数 | 默认值 | 说明 | |------|--------|------| | `batch_size` | 1 | 每批次的样本数量 | | `shuffle` | False | 是否打乱数据顺序 | | `num_workers` | 0 | 加载数据的子进程数 | | `pin_memory` | False | 是否将数据复制到CUDA固定内存 | | `drop_last` | False | 是否丢弃最后不完整的批次 | #### 四、最佳实践建议 1. **内存优化**: ```python # 使用生成器避免一次性加载所有数据 def __getitem__(self, idx): img_path = self.image_paths[idx] return Image.open(img_path) # 延迟加载 ``` 2. **多卡训练支持**: ```python # 分布式采样器 sampler = torch.utils.data.distributed.DistributedSampler(dataset) loader = DataLoader(dataset, sampler=sampler) ``` 3. **数据增强策略**: ```python # 训练和验证使用不同变换 train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(10), transforms.ToTensor() ]) val_transform = transforms.Compose([ transforms.ToTensor() ]) ``` #### 五、常见问题解决 1. **数据加载瓶颈**: - 增加`num_workers`(不超过CPU核心数) - 启用`pin_memory=True`(GPU训练时) - 使用SSD替代机械硬盘 2. **内存溢出处理**: ```python # 启用分页加载 loader = DataLoader(..., collate_fn=lambda x: x[0]) ``` 3. **数据不均衡处理**: ```python # 加权随机采样 weights = compute_sample_weights() sampler = WeightedRandomSampler(weights, num_samples) ```
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值