第P1周:Pytorch实现mnist手写数字识别
导入库
import torch
import torch. nn as nn
import matplotlib. pyplot as plt
import torchvision
import numpy as np
import torch. nn. functional as F
import warnings
使用dataset下载MNIST数据集,并划分好训练集与测试集
train_data = torchvision. datasets. MNIST( "data" ,
train= True ,
transform = torchvision. transforms. ToTensor( ) ,
download= True )
test_data = torchvision. datasets. MNIST( "data" ,
train= False ,
transform= torchvision. transforms. ToTensor( ) ,
download= True )
batch_size = 32
使用数据局加载器,处理导入的数据
train_d = torch. utils. data. DataLoader( train_data,
batch_size = batch_size,
shuffle = True )
test_d = torch. utils. data. DataLoader( test_data,
batch_size = batch_size)
imgs, labels = next ( iter ( train_d) )
"""imgs 变量将包含一个批量的图像数据,而 labels 变量将包含相应的标签数据。"""
展示一下刚刚导入的数据
plt. figure( figsize= ( 20 , 5 ) )
for i, imgs in enumerate ( imgs[ : 20 ] ) :
nping = np. squeeze( imgs. numpy( ) )
plt. subplot( 2 , 10