研究了下DataLoader大批量加载数据的原理:DataLoader只负责数据的抽象,一次调用getitem只返回一个样本
import torch
from torch.utils.data import DataLoader,Dataset
import numpy as np
import pandas as pd
class DataSet(Dataset):
def __init__(self):
data = np.loadtxt('/Users/yq/Desktop/test_data/new.csv',delimiter=',',dtype=np.float32)
self.x_data = torch.from_numpy(data[:,0:-1])
self.y_data = torch.from_numpy(data[:,-2:)
self.len = data.shape[0]
def __getitem__(self,index):
return self.x_data[index],self.y_data[index]
def __len__(self):
return self.len
dataset = DataSet()
data_loader = DataLoader(dataset=dataset,batch_size=5,shuffle=True)
for i,data in enumerate(data_loader):
item,label = data
print(item)
print(label)
结果示例如下:
在DataLoader中设置了batch_size,shuffle,查看了下源码
for i,data in enumerate(data_loader):
i是range(self.len / batch_size),表示迭代的批次,而每次迭代时,batch_size行数据因为shuffle会被随机选择,如上图所示。
在这里__getitem__内置函数的作用在于:在当前的shuffle序列[9,22,15,16,19],index=9,则返回一行数据(self.x_data[index],self.y_data[index]),以此类推,迭代了len(batch_size),获取了当前批次的数据,开始准备进行训练。