引言:
MNIST(Modified National Institute of Standards and Technology)数据集是一个经典的入门级数据集,用于手写数字识别,涵盖了0到9这十个数字。尽管其数据规模较小,但MNIST数据集因其简单性和广泛应用,成为了众多机器学习初学者和研究人员的首选练习项目。
在这篇博客中,我们将基于pytorch实现对MNIST数据集的高效分类。我们将详细介绍数据预处理、模型选择、训练过程以及最终的模型评估和优化方法,希望通过这一过程能够帮助读者更好地理解和掌握图像分类任务的核心概念和实践技巧。
本篇博客是对唐宇迪ai的学习笔记。
环境介绍:
cuda 12.3
python 3.9.2
pytorch 2.3.1 stable
1. 读取数据集
下载数据
from pathlib import Path
import requests
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
if not (PATH / FILENAME).exists():
content = requests.get(URL + FILENAME).content
(PATH / FILENAME).open("wb").write(content)
解压数据
import pickle
import gzip
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
查看数据图
from matplotlib import pyplot
import numpy as np
#numpy是一个处理多为数组的计算库,pytorch是基于numpy之上的机器学习库
pyplot.imshow(x_train[2].reshape((28, 28)), cmap="gray")
#将x_train的第二个样本的元素值重塑为28*28的矩阵图,并用“灰度”来映射该图像
print(x_train.shape)
将变量转化为tensor格式
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
#将这四个变量转化为tensor数组格式