目录
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式。
-
DataLoader类:决定数据如何加载
-
Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理。
为了引入这些概念,我们现在接触一个新的而且非常经典的数据集:MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset # DataLoader是PyTorch中用于加载数据的工具
from torchvision import datasets, transforms # torchvision是一个计算机视觉的库,datasets 和 transforms是其中的模块
import matplotlib.pyplot as plt
# 设置随机种子,使结果可复现
torch.manual_seed(42)

1、数据预处理,该写法非常类似于管道pipe

最低0.47元/天 解锁文章
4121

被折叠的 条评论
为什么被折叠?



