Handwritten Digit Recognition

This post implements handwritten digit recognition in Python with PyTorch, training a multilayer perceptron with three hidden layers on the MNIST dataset. It looks at how different optimization algorithms (such as SGD and Adam) and hyperparameters affect training, and tracks how the loss changes over the training iterations. The results show that the choice of optimizer and the model structure are critical to model performance.



A note before we begin:

This post is purely a learning record, not original work; everything here was collected from around the web. I'm a complete beginner among beginners, so experts may want to skip this. Continuously being updated…

Introduction:

The experiment breaks down into the following steps:

  1. Import the required libraries
  2. Get the experimental data: MNIST
  3. Define the hyperparameters (note: hyperparameters are set by hand, while model parameters are adjusted from the data)
  4. Build the model
  5. Define the optimizer
  6. Train the model
  7. Test the model and report the results

Concepts and terms encountered along the way

About the MNIST dataset:

Grayscale images, 28×28 pixels each.
70,000 images in total (60,000 used for training, 10,000 for testing).

Optimization algorithms

So far I only know SGD and Adam.
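Both are constructed the same way in PyTorch; only the update rule differs. A minimal sketch on a tiny stand-in model (the model and data here are made up for illustration):

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 2)              # tiny stand-in classifier
x = torch.randn(8, 4)                # fake batch of 8 samples
y = torch.randint(0, 2, (8,))        # fake labels

# The training step is identical for both; the update rule is what changes.
for opt in (optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
            optim.Adam(model.parameters(), lr=0.001)):
    opt.zero_grad()
    loss = nn.CrossEntropyLoss()(model(x), y)
    loss.backward()
    opt.step()                       # apply the optimizer-specific update
```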

Types of neural networks and related topics

A brief look at some branches related to the feedforward neural network (FNN):

  1. Single-layer perceptron: only an input and an output layer
  2. Fully connected neural network: every neuron in one layer is connected to every neuron in the next
  3. Convolutional neural network (CNN) (basic structure: input layer, convolutional layers, pooling layers, fully connected layers, output layer): only some of the neurons in adjacent layers are connected
  4. BP (backpropagation) neural network
  5. Recurrent neural network (strictly speaking, recurrent networks have feedback connections, so they are not feedforward)

Some concepts inside neural networks:

  • Activation functions, generally falling into these classes:
  1. Rectified linear unit (ReLU) (common in convolutional layers)

  2. Sigmoid (common in fully connected layers) (squashes real numbers into the range 0 to 1)

  3. tanh(x) (common in fully connected layers) (squashes real numbers into the range -1 to 1)

  4. Radial basis functions
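The squashing ranges listed above are easy to check numerically; a quick sketch:

```python
import torch

x = torch.tensor([-2.0, 0.0, 2.0])
print(torch.relu(x))     # negatives clipped to 0: tensor([0., 0., 2.])
print(torch.sigmoid(x))  # squashed into (0, 1); sigmoid(0) = 0.5
print(torch.tanh(x))     # squashed into (-1, 1); tanh(0) = 0
```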

Importing the libraries

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
  1. torchvision is used to download the MNIST dataset
  2. DataLoader is used to load MNIST

Defining the hyperparameters


epochs = 10

batch_size_train = 64

batch_size_test = 1000

learning_rate = 0.01

random_seed = 10

torch.manual_seed(random_seed)

What the seed does: it makes the random numbers reproducible, so each run gives (nearly) the same results.
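With the same seed, PyTorch draws the same "random" numbers every time, which is what makes runs reproducible; a quick check:

```python
import torch

torch.manual_seed(10)
a = torch.randn(3)
torch.manual_seed(10)      # reset the generator to the same state
b = torch.randn(3)
print(torch.equal(a, b))   # True: identical draws from the same seed
```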

Downloading and loading the dataset

train_dataset = torchvision.datasets.MNIST(root = './data/',
                                          train = True,
                                          transform = torchvision.transforms.ToTensor(),
                                          download = True)

test_dataset = torchvision.datasets.MNIST(root = './data/',
                                        train = False,
                                        transform = torchvision.transforms.ToTensor(),
                                        download = True)

Download the data and convert it into tensors at the same time


train_loader = DataLoader(dataset = train_dataset,
                          batch_size = batch_size_train,
                          shuffle = True)

test_loader = DataLoader(dataset = test_dataset,
                         batch_size = batch_size_test,
                         shuffle = True)

Load the data

Building the model

# Build the model (like stacking blocks): a multilayer perceptron with three hidden layers

class my_MPL(nn.Module):
    def __init__(self):
        super(my_MPL, self).__init__()

        # 开始搭建

        self.fc0 = nn.Linear(28*28, 32*14*14)
        self.fc1 = nn.Linear(32*14*14, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)

        self.relu0 = nn.ReLU()
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()

    def forward(self,x):
        x = x.view(-1, 28*28)
        x = self.fc0(x)
        x = self.relu0(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.relu3(x)  # note: a ReLU on the output logits is unusual before CrossEntropyLoss

        return x

MLP = my_MPL()

A few explanations:
x = x.view(-1, 28*28)
-1 stands for a dimension whose size is inferred automatically (it acts as a placeholder)
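A quick sketch of what view(-1, 28*28) does to a batch of images (the batch size 64 here is just an example):

```python
import torch

batch = torch.zeros(64, 1, 28, 28)   # a batch of 64 single-channel 28x28 images
flat = batch.view(-1, 28 * 28)       # -1: infer this dimension (here it becomes 64)
print(flat.shape)                    # torch.Size([64, 784])
```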

super(my_MPL, self).__init__()
calls the parent class's initializer, so the attributes inherited from nn.Module are set up properly

MLP = my_MPL()
don't forget the parentheses

nn.Linear(28*28, 32*14*14)
a fully connected layer takes a 2-D tensor as input and output (a convolutional layer works on 4-D tensors)

Defining the loss function and optimizer

# Define the loss function and optimizer

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(MLP.parameters(), lr=learning_rate, momentum = 0.9)

Explanation:
criterion = nn.CrossEntropyLoss()
the cross-entropy loss function

optimizer = optim.SGD(MLP.parameters(), lr=learning_rate, momentum = 0.9)
the optimizer used here is SGD
momentum is the momentum term
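The momentum update can be sketched in a few lines. This follows the formulation used by torch.optim.SGD (a velocity buffer accumulates past gradients); the gradient values below are made up for illustration:

```python
lr, momentum = 0.01, 0.9
w, v = 1.0, 0.0                # one scalar weight and its velocity buffer
for grad in [0.5, 0.5, 0.5]:   # pretend gradients over three steps
    v = momentum * v + grad    # velocity accumulates past gradients
    w = w - lr * v             # the step grows while gradients point the same way
print(w)
```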

Training the model and printing the results

# Train the model and print the results

device = torch.device("cpu")


MLP = MLP.to(device)

for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 1):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        outputs = MLP(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if i % 100 == 0:
            print(f'{epoch}, {i}:{running_loss / 100}')
            running_loss = 0.0

Explanation:
device = torch.device("cpu")
run on the CPU

loss = criterion(outputs, labels)
measures the error between the outputs and the ground-truth labels
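Step 7 of the plan (testing the model) is not shown in the post; a minimal sketch, assuming the MLP, test_loader and device defined above:

```python
import torch

def evaluate(model, loader, device=torch.device("cpu")):
    """Return classification accuracy of `model` over `loader`."""
    correct, total = 0, 0
    with torch.no_grad():                        # no gradients needed for evaluation
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)  # pick the class with the highest score
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total
```

After training, `print(evaluate(MLP, test_loader))` would report the test accuracy.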

Recorded runs

  1. Run 1

data:
0, 100:2.105787729024887
0, 200:1.3007640141248702
0, 300:1.0818366277217866
0, 400:0.9986378115415573
0, 500:0.9380956190824509
0, 600:0.9220479035377502
0, 700:0.8981404012441635
0, 800:0.8662663304805756
0, 900:0.7892363598942757
1, 100:0.6400791090726853
1, 200:0.6169854035973549
1, 300:0.5674494902789593
1, 400:0.5978360724449158
1, 500:0.5558451463282108
1, 600:0.5791972890496254
1, 700:0.5771136912703514
1, 800:0.546100759357214
1, 900:0.5670404487848282
2, 100:0.536649158000946
2, 200:0.5216396081447602
2, 300:0.515502537637949
2, 400:0.5366841191053391
2, 500:0.5262447279691697
2, 600:0.531260186880827
2, 700:0.5301781237125397
2, 800:0.5141828963160515
2, 900:0.5275057528913021

  2. Run 2 (same as Run 1, but without momentum)

data:
0, 100:2.293597095012665
0, 200:2.2681656312942504
0, 300:2.233996319770813
0, 400:2.1706444311141966
0, 500:2.0568545377254486
0, 600:1.9085291373729705
0, 700:1.7539910292625427
0, 800:1.5965783751010896
0, 900:1.48703693151474
1, 100:1.3507489502429961
1, 200:1.295447991490364
1, 300:1.197473978996277
1, 400:1.1746511435508729
1, 500:1.112881744503975
1, 600:1.1031613755226135
1, 700:1.1045008379220962
1, 800:1.0443149244785308
1, 900:1.0650411289930344
2, 100:1.0381091105937958
2, 200:1.0188828611373901
2, 300:1.0006711131334305
2, 400:1.0161162036657334
2, 500:1.0020228821039199
2, 600:0.9793965613842011
2, 700:0.9800641161203384
2, 800:0.9899540412425994
2, 900:0.9670715373754502

  3. Run 3 (let's add a few more layers)
    data: none. I added too many layers and ran out of memory (the process was killed, exit code 137).

  4. Run 4 (a gentler increase this time)

data:
0, 100:2.3005282425880433
0, 200:2.277735002040863
0, 300:1.9812933230400085
0, 400:1.3266960901021958
0, 500:1.0656914275884628
0, 600:0.9701974713802337
0, 700:0.9579001593589783
0, 800:0.8922015309333802
0, 900:0.8835197502374649
1, 100:0.8535557878017426
1, 200:0.8409706902503967
1, 300:0.8235989719629287
1, 400:0.8365458434820175
1, 500:0.8051434952020645
1, 600:0.8160263752937317
1, 700:0.8093277823925018
1, 800:0.8010191100835801
1, 900:0.8239868038892746
2, 100:0.7825601372122765
2, 200:0.7770071613788605
2, 300:0.7663508641719818
2, 400:0.7956782352924346
2, 500:0.7669226723909378
2, 600:0.7003127360343933
2, 700:0.5361837792396545
2, 800:0.541669200360775
2, 900:0.520277413725853

  5. Run 5

changes:
32*14*14 → 64*14*14

data:

0, 100:2.2994043684005736
0, 200:2.2818156337738036
0, 300:2.0806565570831297
0, 400:1.6807345521450043
0, 500:1.5835756409168242
0, 600:1.5085526263713838
0, 700:1.4409454703330993
0, 800:1.3455421423912048
0, 900:1.271108170747757
1, 100:1.2577989375591279
1, 200:1.2352607268095017
1, 300:1.2508223408460617
1, 400:1.2190198016166687
1, 500:1.221548267006874
1, 600:1.2547485852241516
1, 700:1.2196426820755004
1, 800:1.2275055664777756
1, 900:1.23593525826931
2, 100:1.2183733952045441
2, 200:1.1852529442310333
2, 300:1.183802189230919
2, 400:1.1587913089990616
2, 500:1.1841296607255936
2, 600:1.1642793798446656
2, 700:1.1936270797252655
2, 800:1.1885065108537674
2, 900:1.1721717804670333

  6. Run 6

data:
0, 100:2.153420286178589
0, 200:1.3807597470283508
0, 300:0.8610023042559624
0, 400:0.3984715388715267
0, 500:0.3003653420507908
0, 600:0.26110438015311954
0, 700:0.23661019951105117
0, 800:0.1910202322900295
0, 900:0.19464803928509355
1, 100:0.14816518139094115
1, 200:0.14610067680478095
1, 300:0.14486520748585463
1, 400:0.1376293076016009
1, 500:0.11388071848079562
1, 600:0.1287291205301881
1, 700:0.10425606964156031
1, 800:0.11403923080302775
1, 900:0.09714097636751831
2, 100:0.0853013839572668
2, 200:0.07920281968079507
2, 300:0.0825724380556494
2, 400:0.07912950337398797
2, 500:0.07034362030215562
2, 600:0.06381631655618548
2, 700:0.06913053676486015
2, 800:0.07567000338807701
2, 900:0.07534596136771142
