小白学Pytorch使用(1):Mnist手写数字集分类
任务背景
使用Pytorch框架搭建神经网络进行Mnist手写数字集分类。数据可以自行下载进行本地链接,也可直接通过代码下载。
mnist数据集下载链接
一、导入库
# pathlib路径操作模块
from pathlib import Path
# requests是⽤Python语⾔编写,基于urllib,采⽤Apache2 Licensed开源协议的 HTTP 库
import requests
# 持续化模块:把Python对象直接保存到文件里,而不需要先把它们转化为字符串再保存,也不需要用底层的文件访问操作,直接把它们写入到一个二进制文件里
import pickle
# 压缩和解压缩模块
import gzip
# 数组及数组运算函数库
import numpy as np
# 2D绘图库
from matplotlib import pyplot
# Pytorch框架
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
二、下载数据集
# 下载mnist手写数字集,Path中更换自己的文件路径
DATA_PATH = Path("D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第二,三章:神经网络实战分类与回归任务/神经网络实战分类与回归任务/data")
# Mnist数据集存储文件夹(可舍去)
PATH = DATA_PATH/'mnist'
# 建立PATH路径,若自己手动建立也可注释掉
PATH.mkdir(parents=True, exist_ok=True)
# Mnist数据集下载网页
# "http://deeplearning.net/data/mnist/"是老师给的下载网页,但我这无法连接
URL = "https://resources.oreilly.com/live-training/inside-unsupervised-learning/-/tree/master/data/mnist_data/"
# 数据集名称mnist.pkl.gz,可打开网页验证
FILENAME = 'mnist.pkl.gz'
# 若本地没有相应文件则进行下载
if not (PATH/FILENAME).exists():
content = requests.get(URL + FILENAME).content # requests.get()请求指定的页面信息,并返回实体主体
(PATH/FILENAME).open('wb').write(content)
三、处理数据集
# 打开压缩包获取训练集、测试集,gzip进行解压
with gzip.open