Mnist数字分类任务-学习笔记

引言:

MNIST(Modified National Institute of Standards and Technology)数据集是一个经典的入门级数据集,用于手写数字识别,涵盖了0到9这十个数字。尽管其数据规模较小,但MNIST数据集因其简单性和广泛应用,成为了众多机器学习初学者和研究人员的首选练习项目。

在这篇博客中,我们将基于pytorch实现对MNIST数据集的高效分类。我们将详细介绍数据预处理、模型选择、训练过程以及最终的模型评估和优化方法,希望通过这一过程能够帮助读者更好地理解和掌握图像分类任务的核心概念和实践技巧。

本篇博客是对唐宇迪ai的学习笔记。

环境介绍:

cuda 12.3

python 3.9.2

pytorch 2.3.1 stable

1. 读取数据集

下载数据

from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

解压数据

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

查看数据图

from matplotlib import pyplot
import numpy as np
#numpy是一个处理多为数组的计算库,pytorch是基于numpy之上的机器学习库

pyplot.imshow(x_train[2].reshape((28, 28)), cmap="gray")
#将x_train的第二个样本的元素值重塑为28*28的矩阵图,并用“灰度”来映射该图像

print(x_train.shape)

将变量转化为tensor格式

x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
#将这四个变量转化为tensor数组格式
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值