Pytorch深度学习(五):加载数据集以及mini-batch的使用

Pytorch深度学习(五):加载数据集以及mini-batch的使用

一、预备知识

  • Dataset是一个抽象函数,不能直接实例化,所以我们要创建一个自己类,继承Dataset
    继承Dataset后我们必须实现三个函数:
    init()是初始化函数,之后我们可以提供数据集路径进行数据的加载
    getitem()帮助我们通过索引找到某个样本
    len()帮助我们返回数据集大小
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.xdata = torch.from_numpy(xy[:, :-1])
        self.ydata = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.xdata[index], self.ydata[index]

    def __len__(self):
        return self.len
  • DataLoader为数据进行分组,batch_size是一个组中有多少个样本,shuffle表示要不要对样本进行随机排列。一般来说,训练集我们随机排列,测试集不需要。num_workers表示我们可以用多少进程并行的运算,由于我的版本原因(cuda不好使),只能选择num_workers=0,一般可以写num_workers=2,进行并行运算算提高速度。
dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值