Pytorch深度学习(五):加载数据集以及mini-batch的使用
一、预备知识
- Dataset是一个抽象函数,不能直接实例化,所以我们要创建一个自己类,继承Dataset
继承Dataset后我们必须实现三个函数:
init()是初始化函数,之后我们可以提供数据集路径进行数据的加载
getitem()帮助我们通过索引找到某个样本
len()帮助我们返回数据集大小
class DiabetesDataset(Dataset):
def __init__(self, filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
self.xdata = torch.from_numpy(xy[:, :-1])
self.ydata = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
return self.xdata[index], self.ydata[index]
def __len__(self):
return self.len
- 用DataLoader为数据进行分组,batch_size是一个组中有多少个样本,shuffle表示要不要对样本进行随机排列。一般来说,训练集我们随机排列,测试集不需要。num_workers表示我们可以用多少进程并行的运算,由于我的版本原因(cuda不好使),只能选择num_workers=0,一般可以写num_workers=2,进行并行运算算提高速度。
dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=