1.背景介绍
1.1 数据集来源
本文使用的是MNIST数据集,里面包含了0到9十种数字的28*28像素规格的灰度图片,总共包含6万张训练数据集和1万张的测试数据集 由Yann LeCun、Corinna Cortes和Christopher Burges他们在1998年发布,用于手写数字的图像识别任务.
2.引入包导入数据集
2.1 引入外部包
首先先导入我们本次需要使用的所有外部的的
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from util import plot_curve,plot_image,one_hot
util类的代码如下 自行编写并引入或者直接去除util类的导入直接把引入的方法写在当前类里
import torch
from matplotlib import pyplot as plt
def plot_curve(data):
"""
下降曲线的绘制
:param data: 损失值列表
:return: None
"""
fig = plt.figure()
plt.plot(range(len(data)), data, color='blue') # 绘制损失值曲线
plt.legend(['value'], loc='upper right') # 图例
plt.xlabel('step') # x轴标签
plt.ylabel('value') # y轴标签
plt.show() # 显示图形
def plot_image(img, label, name):
"""
可视化识别结果
:param img: 图像张量
:param label: 标签张量
:param name: 图像标题
:return: None
"""
fig = plt.figure()
for i in range(6): # 显示前 6 张图像
plt.subplot(2, 3, i + 1) # 创建 2 行 3 列的子图
plt.tight_layout() # 调整子图布局
plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none') # 显示图像
plt.title("{}: {}".format(name, label[i].item())) # 设置标题
plt.xticks([]) # 不显示 x 轴刻度
plt.yticks([]) # 不显示 y 轴刻度
plt.show() # 显示图形
def one_hot(labels, depth=10):
"""
将标签转换为 one-hot 编码
:param labels: 标签张量
:param depth: 类别数(即 one-hot 编码的长度)
:return: one-hot 编码张量
"""
# 确保标签是 Long 类型
labels = labels.long()
# 创建一个全零的张量
out = torch.zeros(labels.size(0), depth)
# 将标签索引变为一列的张量
idx = labels.view(-1, 1)
# 使用 scatter 操作将值设为 1
out.scatter_(dim=1, index=idx, value=1)
return out
#plot_curve 用于绘制损失曲线,监控训练过程。
#plot_image 用于展示图像和预测结果,帮助检查模型的输出。
#one_hot 用于将标签转换为 one-hot 编码,是分类任务中常见的预处理步骤。
2.2 加载MNIST数据集
# 第一步加载数据集
def load_dataset():
#一次性处理数据集图片的数量(根据电脑性能自行调整)
batch_size = 512
#加载mnist训练数据集
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=True,