Pytorch实现mnist手写数字识别

第P1周:Pytorch实现mnist手写数字识别

导入库

#构建和训练神经网络
import torch
#PyTorch的神经网络模块,提供构建网络所需的类和函数
import torch.nn as nn
import matplotlib.pyplot as plt
#用于处理图像数据集的库
import torchvision
import numpy as np
#提供激活函数的库
import torch.nn.functional as F
import warnings

使用dataset下载MNIST数据集,并划分好训练集与测试集

train_data = torchvision.datasets.MNIST("data",
                                        train=True,
                                        transform = torchvision.transforms.ToTensor(),
                                        download=True)
test_data = torchvision.datasets.MNIST("data",
                                       train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
batch_size = 32

使用数据局加载器,处理导入的数据

#是一个 PyTorch 数据加载器(DataLoader),用于加载训练数据集。通常情况下,数据加载器会将数据集分成小批量(batches)进行处理
train_d = torch.utils.data.DataLoader(train_data,
                                      batch_size = batch_size,
                                      shuffle = True)
test_d = torch.utils.data.DataLoader(test_data,
                                      batch_size = batch_size)
#iter(train_dl) 将数据加载器转换为一个迭代器(iterator),使得我们可以使用 Python 的 next() 函数来逐个访问数据加载器中的元素。
#next() 函数用于获取迭代器中的下一个元素。在这里,它被用来获取 train_dl 中的下一个批量数据。
#它将从 next() 函数返回的元素中提取出两个变量:imgs 和 labels
imgs, labels = next(iter(train_d))
"""imgs 变量将包含一个批量的图像数据,而 labels 变量将包含相应的标签数据。"""

展示一下刚刚导入的数据

#指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)
plt.figure(figsize=(20, 5))
for i, imgs in enumerate(imgs[:20]):
    #squeeze()函数的功能是从矩阵shape中,去掉维度为1的,实现维度缩减
    nping = np.squeeze(imgs.numpy())
    # 将整个figure分成2行10列,绘制第i+1个子图
    plt.subplot(2, 10
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值