PyTorch DataLoader、DataSet

本文深入探讨了PyTorch中DataLoader组件的工作原理,详细解释了如何通过设置batch_size和shuffle参数来实现数据的批量加载与随机化处理,这对于理解深度学习训练流程中的数据预处理环节至关重要。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

研究了下DataLoader大批量加载数据的原理:DataLoader只负责数据的抽象,一次调用getitem只返回一个样本

import torch
from torch.utils.data import DataLoader,Dataset
import numpy as np
import pandas as pd

class DataSet(Dataset):
    def __init__(self):
        data = np.loadtxt('/Users/yq/Desktop/test_data/new.csv',delimiter=',',dtype=np.float32)
        self.x_data = torch.from_numpy(data[:,0:-1])
        self.y_data = torch.from_numpy(data[:,-2:)
        self.len = data.shape[0]
    def __getitem__(self,index):
        return self.x_data[index],self.y_data[index]
    def __len__(self):
        return self.len


dataset = DataSet()
data_loader = DataLoader(dataset=dataset,batch_size=5,shuffle=True)


for i,data in enumerate(data_loader):
    item,label = data
    print(item)
    print(label)

结果示例如下:

在DataLoader中设置了batch_size,shuffle,查看了下源码

for i,data in enumerate(data_loader):

i是range(self.len / batch_size),表示迭代的批次,而每次迭代时,batch_size行数据因为shuffle会被随机选择,如上图所示。

在这里__getitem__内置函数的作用在于:在当前的shuffle序列[9,22,15,16,19],index=9,则返回一行数据(self.x_data[index],self.y_data[index]),以此类推,迭代了len(batch_size),获取了当前批次的数据,开始准备进行训练。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值