【深度学习】用pytorch实现数字识别

1.背景介绍


      1.1 数据集来源

      本文使用的是MNIST数据集,里面包含了0到9十种数字的28*28像素规格的灰度图片,总共包含6万张训练数据集和1万张的测试数据集 由Yann LeCun、Corinna Cortes和Christopher Burges他们在1998年发布,用于手写数字的图像识别任务.

2.引入包导入数据集

     2.1 引入外部包 

     首先先导入我们本次需要使用的所有外部的的


import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from   util import plot_curve,plot_image,one_hot

    util类的代码如下 自行编写并引入或者直接去除util类的导入直接把引入的方法写在当前类里

import torch
from matplotlib import pyplot as plt


def plot_curve(data):
    """
    下降曲线的绘制
    :param data: 损失值列表
    :return: None
    """
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')  # 绘制损失值曲线
    plt.legend(['value'], loc='upper right')  # 图例
    plt.xlabel('step')  # x轴标签
    plt.ylabel('value')  # y轴标签
    plt.show()  # 显示图形



def plot_image(img, label, name):
    """
    可视化识别结果
    :param img: 图像张量
    :param label: 标签张量
    :param name: 图像标题
    :return: None
    """
    fig = plt.figure()
    for i in range(6):  # 显示前 6 张图像
        plt.subplot(2, 3, i + 1)  # 创建 2 行 3 列的子图
        plt.tight_layout()  # 调整子图布局
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')  # 显示图像
        plt.title("{}: {}".format(name, label[i].item()))  # 设置标题
        plt.xticks([])  # 不显示 x 轴刻度
        plt.yticks([])  # 不显示 y 轴刻度
    plt.show()  # 显示图形



def one_hot(labels, depth=10):
    """
    将标签转换为 one-hot 编码
    :param labels: 标签张量
    :param depth: 类别数(即 one-hot 编码的长度)
    :return: one-hot 编码张量
    """
    # 确保标签是 Long 类型
    labels = labels.long()

    # 创建一个全零的张量
    out = torch.zeros(labels.size(0), depth)

    # 将标签索引变为一列的张量
    idx = labels.view(-1, 1)

    # 使用 scatter 操作将值设为 1
    out.scatter_(dim=1, index=idx, value=1)

    return out


#plot_curve 用于绘制损失曲线,监控训练过程。
#plot_image 用于展示图像和预测结果,帮助检查模型的输出。
#one_hot 用于将标签转换为 one-hot 编码,是分类任务中常见的预处理步骤。
   2.2 加载MNIST数据集
# 第一步加载数据集
def load_dataset():
    #一次性处理数据集图片的数量(根据电脑性能自行调整)
    batch_size = 512

    #加载mnist训练数据集
    train_loader = torch.utils.data.DataLoader(
         
        torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值