PyTorch深度学习实践 第八讲---Mini-batch数据集

Demo8 数据集划分处理(加载数据集)
来源 B站 刘二大人

说明:

  1. DataSet 是抽象类,不能实例化对象,主要是用于构造我们的数据集。继承DataSet的类需要重写方法返回len和根据索引获取对应的值。getitem()目的是为支持下标(索引)操作。
  2. DataLoader 用来帮助我们处理数据,具体就是做shuffle(提高数据集的随机性),和根据batch_size拿出Mini-Batch进行训练,DataLoader可实例化对象。
    DataLoader的作用示意图
    代码说明:
  3. 需要mini_batch 就需要import DataSet和DataLoader
  4. 继承DataSet的类需要重写init,getitem,len魔法函数。分别是为了加载数据集,获取数据索引,获取数据总量。
  5. DataLoader对数据集先打乱(shuffle),然后划分成mini_batch。
  6. len函数的返回值 除以 batch_size 的结果就是每一轮epoch中需要迭代的次数。
  7. inputs, labels = data中的inputs的shape是[32,8],labels 的shape是[32,1]。也就是说mini_batch在这个地方体现的
  8. diabetes.csv数据集老师给了下载地址,可自行下载

课程代码

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype = np.float32)
       
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值