PyTorch Learning, Part 3: Loading Datasets

This post covers loading datasets in PyTorch with `DataLoader`: `shuffle` reshuffles the data at every epoch, `batch_size` sets the number of samples per batch, and `num_workers` sets the number of subprocesses used to load data. It also walks through iterating with `enumerate(train_loader, 0)`, explaining the iterable being traversed and the starting value of the counter.


import torch
from torch.utils.data import Dataset  # Dataset is an abstract class: it cannot be instantiated, only subclassed
from torch.utils.data import DataLoader
import numpy as np


class Set_Dataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype='float32')
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

if __name__ == "__main__":
    dataset = Set_Dataset('C:/python3/envs/pytorch/atest_torch/data/diabetes.csv')
    train_loader = DataLoader(dataset, shuffle=True, batch_size=32, num_workers=2)

    model = Model()
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(20):
        for i, data in enumerate(train_loader, 0):
            # 1. Prepare data
            inputs, labels = data
            # 2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4. Update
            optimizer.step()

Output (tail of the last epoch):
19 4 0.6433802247047424
19 5 0.6984792947769165
19 6 0.6620081663131714
19 7 0.6072109341621399
19 8 0.6251083016395569
19 9 0.6068292260169983
19 10 0.6617796421051025
19 11 0.6436219811439514
19 12 0.735340416431427
19 13 0.6616771817207336
19 14 0.6253309845924377
19 15 0.6617479920387268
19 16 0.5704611539840698
19 17 0.5883565545082092
19 18 0.6800984740257263
19 19 0.6618121862411499
19 20 0.6064887642860413
19 21 0.6620679497718811
19 22 0.6439443826675415
19 23 0.6199027299880981
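A quick aside on the two slicing styles used in `__init__`: `xy[:, :-1]` takes every column except the last as features, while `xy[:, [-1]]` uses a *list* index for the last column so the labels stay 2-D, which is the shape `BCELoss` expects to match the model's `(N, 1)` output. A minimal sketch with a hypothetical 4×3 array standing in for the CSV contents:

```python
import numpy as np

# Hypothetical 4-sample array: 2 feature columns + 1 label column,
# mimicking what np.loadtxt returns for a small CSV file.
xy = np.arange(12, dtype='float32').reshape(4, 3)

x = xy[:, :-1]      # every column except the last -> features, shape (4, 2)
y = xy[:, [-1]]     # list index keeps the axis     -> labels,   shape (4, 1)
y_flat = xy[:, -1]  # plain index would flatten      -> shape (4,)

print(x.shape, y.shape, y_flat.shape)  # (4, 2) (4, 1) (4,)
```

If the labels were loaded with the plain `xy[:, -1]` form, the loss computation would trigger a shape-mismatch warning or error, since the prediction is `(N, 1)`.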

Summary:

1. `train_loader = DataLoader(dataset, shuffle=True, batch_size=32, num_workers=2)`
shuffle: set to True to have the data reshuffled at every epoch (default: False), i.e. whether the sampling order is randomized.
batch_size: the number of samples in each batch.
num_workers: the number of subprocesses used for data loading (on Windows, this is why the loop must sit under the `if __name__ == "__main__":` guard).
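With `batch_size=32`, the batch indices printed above run from 0 to 23, i.e. 24 batches per epoch. That is consistent with the dataset having 759 rows (an assumption; the commonly used diabetes.csv from this tutorial series has 759 samples, but the post does not state the count), since the final partial batch still counts:

```python
import math

n_samples = 759   # assumed row count of diabetes.csv
batch_size = 32

# DataLoader with drop_last=False (the default) keeps the final partial batch,
# so the number of batches is the ceiling of n_samples / batch_size.
n_batches = math.ceil(n_samples / batch_size)
print(n_batches)  # 24 -> batch indices i run from 0 to 23, as in the output above
```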

2. `enumerate(train_loader, 0)`
The signature is `enumerate(iterable, start=0)`:
iterable: any iterable sequence; here it is `train_loader`, which yields one `(inputs, labels)` batch per iteration.
start: the value the counter begins at (0 here), so `i` is the batch index within the epoch.
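The counter `enumerate` produces is independent of the data itself; changing `start` only shifts the index. A small sketch with a plain list standing in for the batches a `DataLoader` yields:

```python
batches = ["batch0", "batch1", "batch2"]  # stand-in for DataLoader output

# start=0 (the default): i counts 0, 1, 2 ...
for i, data in enumerate(batches, 0):
    print(i, data)

# start=1: the counter is shifted, but the data is unchanged
for i, data in enumerate(batches, 1):
    print(i, data)
```

In the training loop, `enumerate(train_loader, 0)` therefore gives each batch a 0-based index `i`, which is what appears as the second number in every printed loss line.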
