加载数据集
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
'''
Dataset是一个抽象函数,不能直接实例化,所以我们要创建一个自己类,继承Dataset
继承Dataset后我们必须实现三个函数:
__init__()是初始化函数,之后我们可以提供数据集路径进行数据的加载
__getitem__()帮助我们通过索引找到某个样本
__len__()帮助我们返回数据集大小
'''
class DiabetesDataset(Dataset):
def __init__(self,filepath):
xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
#shape本身是一个二元组(x,y)对应数据集的行数和列数,这里[0]我们取行数,即样本数
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, :-1])
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
return self.x_data[index],self.y_data[index]
def __len__(self):
return self.len
#定义好DiabetesDataset后我们就可以实例化他了
dataset = DiabetesDataset('./data/Diabetes_class.csv.gz')