import numpy as np
import torch
from torch.utils.data import Dataset # Dataset是一个抽象类,不能实例化,只能被其他的子类继承
from torch.utils.data import DataLoader
class DiabetesDataset(Dataset): # 这个类继承自Dataset
def __init__(self, filepath): # filepath是文件路径
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32) # 加载数据集
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, :-1])
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index): # 通过索引支持下标操作
return self.x_data[index], self.y_data[index]
def __len__(self): # 返回数据集中的数据条数
return self.len
dataset = DiabetesDataset('diabetes.csv') # 实例化自定义的类
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
# dataset传递数据集,batch_size传递一个batch中的样本数量,shuffle表示数据是否打乱,num_workers表示是否要并行化读取数据
# -------------------------------------------
PyTorch深度学习实践——加载数据集
最新推荐文章于 2024-07-18 12:25:42 发布