深度学习之使用BP神经网络识别MNIST数据集

本文围绕基于PyTorch的BP神经网络手写数字识别展开。先介绍torch.nn.LogSoftmax等补充知识点,包括数据处理、数据集、损失函数等相关内容;接着进行代码实现,涵盖BP网络搭建、建立神经网络对象、预测等步骤,最终经15轮训练基本能完成手写数字识别。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

补充知识点

torch.nn.LogSoftmax()

 torchvision.transforms

transforms.Compose

transforms.ToTensor

transforms.Normalize(mean, std)

torchvision.datasets

MNIST(手写数字数据集)

torch.utils.data.DataLoader

torch.nn.NLLLoss() 

torch.nn.CrossEntropyLoss()

torch.nn.NLLLoss() 

enumerate() 函数

next() 函数

pytorch中.detach()

torch.squeeze()

torch.optim.SGD 

独热编码

代码实现

bp网络的搭建

建立我们的神经网络对象

预测

完整代码

结果极其图像显示


补充知识点

torch.nn.LogSoftmax()

torch.nn.LogSoftmax()和我们的softmax差不多只不过就是最后加入了一个log,关于softmax的详情大家可以看这一篇博客深度学习之感知机,激活函数,梯度消失,BP神经网络

我们只需要知道里面的一些常用参数就好,一般就是这个dim

下面是softmax的logsoftmax和他一样。

 torchvision.transforms

这个是一个功能强大的数据处理集合库

这里主要说一下transforms.Compose还有transforms.ToTensor以及transforms.Normalize(mean, std)

transforms.Compose

这个功能函数可以看作一个功能函数容器,它里面可以放多个功能函数(多个的话应该把这些功能函数放在一个列表内),当该功能函数定义好后,其内部的其他功能函数也会随之按我们给定的要求定义好,当我们调用compose的实例的时候,就会按照我们在容器内部摆放的顺序从左至右的依次调用功能函数

transforms.ToTensor

用于对载入的图片数据进行类型转换,将之前PIL图片的数据(应该是np.array()类型的)转换成Tensor数据类型的张量,以便于我们后续的数据使用

(补充:PIL默认输出的图片格式为 RGB)

(下面图片内容来自网络,侵权必删。)

transforms.Normalize(mean, std)

数据归一化处理。

下面是我在学习过程中遇到的问题:

1.归一化就是要把图片3个通道中的数据整理到[-1, 1]或者[0,1]区间,x = (x - mean(x))/std(x)只要输入数据集x就可以直接算出来,为什么Normalize()函数的mean和std(标准差)还需要我们手动输入数值呢?

我的理解是,我们一开始就算好可以极大的减少运算量,如果我们自动的让他算的话,我们这个每一个图片都要算,这样运算量就极大。

2.RGB单个通道的值是[0, 255],所以一个通道的均值应该在127附近才对。我们接下来的代码如下图所示

所填的是0.5,0.5,这是为什么?

 因为我们应用了torchvision.transforms.ToTensor,他会将数据归一化到[0,1](是将数据除以255),transforms.ToTensor( )会把HWC会变成C *H *W(拓展:格式为(h,w,c),像素顺序为RGB),所以我们就应该输入0.5,0.5

(该图片内容来自网络,侵权璧必删)

我们一般来说只需了解一些常用的库函数,我们这里只是提一下我们这次会用到的函数,其余的函数若想了解的话,推荐大家看这位作者写的博客PyTorch之torchvision.tra()nsforms详解[原理+代码实现]-优快云博客

torchvision.datasets

这里面包含了一些我们目前常用的一些数据集

我们这里主要讲一下mnist 

MNIST(手写数字数据集)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。

在我们pytorch里面的torchvision里面是有的,我们可以直接用。

torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

相关参数:

root:就是我们从网上下载的数据集所放的目录,也就是文件路径

train:train=True意思就是下载的是训练集,如果train=Flase,那就是下载的是测试集

transform: 因为我们的数据直接下载过后一般都是要进行处理,所以这个后面跟的是一个transforms数据处理函数的一个实例化对象

download:为True就是从网络上下载数据集保存在我们的root路径中,为False就是不下载。

其余参数可自行查阅

torch.utils.data.DataLoader

(图片内容来自网络,侵权必删)

(其实我们记住常用的,dataset,batch_size,shuffle这三个常用的参数就好了)

torch.nn.NLLLoss() 

torch.nn.CrossEntropyLoss()

首先我们要了解交叉熵损失函数,torch.nn.CrossEntropyLoss()

什么是熵?

熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。交叉熵越小,表示数据越接近真实样本。

(预测的概率就是我们的预测值的准确值)

torch.nn.NLLLoss() 

 torch.nn.NLLLoss输入是一个对数概率向量和一个目标标签,它与torch.nn.CrossEntropyLoss的关系可以描述为:

假设有张量x,先softmax(x)得到y,然后再log(y)得到z,然后我们已知标签b,则:

NLLLoss(z,b)=CrossEntropyLoss(x,b)

代码:

nllloss = nn.NLLLoss()
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
predict = torch.log(torch.softmax(predict, dim=-1))
label = torch.tensor([1, 2])
nllloss(predict, label)
运行结果:tensor(0.2684)

而我们用 torch.nn.CrossEntropyLoss

cross_loss = nn.CrossEntropyLoss()

predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
cross_loss(predict, label)
运行结果: tensor(0.2684)

enumerate() 函数

这是python的一个内置函数。

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

Python 2.3. 以上版本可用,2.6 添加 start 参数。

enumerate(sequence, [start=0])
sequence -- 一个序列、迭代器或其他支持迭代对象。
start -- 下标起始位置的值。
例如:
>>> seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1))       # 下标从 1 开始
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]

next() 函数

python的内置函数

next() 返回迭代器的下一个项目。

next() 函数要和生成迭代器的 iter() 函数一起使用。

返回值:返回下一个项目。

next(iterable[, default])
iterable -- 可迭代对象
default -- 可选,用于设置在没有下一个元素时返回该默认值,如果不设置,又没有下一个元素则会触发 StopIteration 异常。
例如:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
 
# 首先获得Iterator对象:
it = iter([1, 2, 3, 4, 5])
# 循环:
while True:
    try:
        # 获得下一个值:
        x = next(it)
        print(x)
    except StopIteration:
        # 遇到StopIteration就退出循环
        break
结果:
1
2
3
4
5

pytorch中.detach()

在 PyTorch 中,detach() 方法用于返回一个新的 Tensor,这个 Tensor 和原来的 Tensor 共享相同的内存空间,但是不会被计算图所追踪,也就是说它不会参与反向传播,不会影响到原有的计算图,这使得它成为处理中间结果的一种有效方式,通常在以下两种情况下使用:

第一:在计算图中间,需要截断反向传播的梯度计算时。例如,当计算某个 Tensor 的梯度时,我们希望在此处截断反向传播,而不是将梯度一直传递到计算图的顶部,从而减少计算量和内存占用。此时可以使用 detach() 方法将 Tensor 分离出来。

第二:在将 Tensor 从 GPU 上拷贝到 CPU 上时,由于 Tensor 默认是在 GPU 上存储的,所以直接进行拷贝可能会导致内存不一致的问题。此时可以使用 detach() 方法先将 Tensor 分离出来,然后再将分离出来的 Tensor 拷贝到 CPU 上。

torch.squeeze()

详情请看这篇博客深度学习之张量的处理(代码笔记)

torch.optim.SGD 

torch.optim.SGD 是 PyTorch 中用于实现随机梯度下降(Stochastic Gradient Descent,SGD)优化算法的类。SGD 是一种常用的优化算法。

原理部份可以看我的这篇博客:机器学习优化算法(深度学习)-优快云博客

我们主要介绍一下常用的参数:

torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

(图片内容来自网络,侵权必删) 

独热编码

就是如果是0-9这十个数,那我们就用[1,0,0,0,0,0,0,0,0,0]表示0,[0,1,0,0,0,0,0,0,0,0]表示1,等等

其余同理

代码实现

关于bp神经网络的原理,我们可以看这篇博客:深度学习之感知机,激活函数,梯度消失,BP神经网络-优快云博客

我们接下来就只是讲解代码实现。

bp网络的搭建

#搭建bp神经网络
class BPNetwork(torch.nn.Module):
    def __init__(self):
        super(BPNetwork,self).__init__()
        #我们的每张图片都是28*28也就是784个像素点
        #第一个隐藏层
        self.linear1=torch.nn.Linear(784,128)
        #激活函数,这里选择Relu
        self.relu1=torch.nn.ReLU()
        #第二个隐藏层
        self.linear2=torch.nn.Linear(128,64)
        #激活函数
        self.relu2=torch.nn.ReLU()
        #第三个隐藏层:
        self.linear3=torch.nn.Linear(64,32)
        # 激活函数
        self.relu3 = torch.nn.ReLU()
        #输出层
        self.linear4=torch.nn.Linear(32,10)
        # 激活函数
        self.softmax=torch.nn.LogSoftmax()
    #前向传播
    def forward(self,x):
        #修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
        x=x.reshape(x.shape[0],-1)
        #前向传播
        x=self.linear1(x)#784*128
        x=self.relu1(x)
        x=self.linear2(x)#128*64
        x=self.relu2(x)
        x=self.linear3(x)#64*32
        x=self.relu3(x)
        x=self.linear4(x)#输出层32*10
        x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
        #上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
        return x

一些关键点都在注释上

搭建这次的BP神经网络我的隐藏层有三层,分别是128,64,32个神经元,因为我们的图片是28*28=784得,我们需要把其展开成一维,所以第一层网络是784*128得,这样输入层中每一行代表一个样本(或者说一张图片得所有像素点),因为我们的每个神经元都有一个参数,第一层网络中每一列都是一个神经元对应的参数,所以最后就是n*784和784*128两个矩阵相乘,最后得到n*128得矩阵,以此类推,最后因为我们的输出要用到独热编码的思想,所以我们的输出层调整为10个神经元,或者说最后得线性网络是32*10。

建立我们的神经网络对象

#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15#循环得轮数
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
    # 损失值参数
    sumloss = 0
    for imges,labels in trainload:
        a+=1
        #前向传播
        output=model(imges)
        #反向传播
        loss=critimizer(output,labels)
        loss.backward()
        #参数更新
        optimizer.step()
        #梯度清零
        optimizer.zero_grad()
        #损失累加
        sumloss+=loss.item()
    loss_.append(sumloss)
    a_.append(a)
    print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()

注意看注释。

SGD还有损失函数等等上面的补充内容里面都有说明,这里不在多加阐述。

预测

#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
#     bath_index=i
#     (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
    pre=model(imagess[i])#预测
    #第一张图片对应的pre得格式:
    # print(pre)
    # tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
    #          -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
    #        grad_fn= < LogSoftmaxBackward0 >)
    #接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
    pro = list(pre.detach().numpy()[0])
    pre_label=pro.index(max(pro))
    #print(pre_label)

注意看注释

我们实际上也可以不用next()我们可以直接用for循环比如这样。

import d2l.torch as d2l
import math
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
a=0
images=0
labelss=0
for i,j in example:
    a+=1
    index = i
    (imagess, labelss) = j
    print(imagess[0])
    print('数据集中抽取的64份数据的纯数据部份的尺寸:',imagess.shape)
    print(imagess[0].shape)
    print(labelss[0])
    print(labelss[0].shape)

    if a==1:
        break

这样得到的结果我们可以看到:

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\prictice.py 
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3412,
           0.4510,  0.2471,  0.1843, -0.5294, -0.7176, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.7412,
           0.9922,  0.9922,  0.9922,  0.9922,  0.8902,  0.5529,  0.5529,
           0.5529,  0.5529,  0.5529,  0.5529,  0.5529,  0.5529,  0.3333,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4745,
          -0.1059, -0.4353, -0.1059,  0.2784,  0.7804,  0.9922,  0.7647,
           0.9922,  0.9922,  0.9922,  0.9608,  0.7961,  0.9922,  0.9922,
           0.0980, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.8667, -0.4824, -0.8902,
          -0.4745, -0.4745, -0.4745, -0.5373, -0.8353,  0.8510,  0.9922,
          -0.1686, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.3490,  0.9843,  0.6392,
          -0.8588, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.8275,  0.8275,  1.0000, -0.3490,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000,  0.0118,  0.9922,  0.8667, -0.6549,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -0.5373,  0.9529,  0.9922, -0.5137, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000,  0.0431,  0.9922,  0.4667, -0.9608, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.9294,  0.6078,  0.9451, -0.5451, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -0.0118,  0.9922,  0.4275, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -0.4118,  0.9686,  0.8824, -0.5529, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8510,
           0.7333,  0.9922,  0.3020, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.9765,  0.5922,
           0.9922,  0.7176, -0.7255, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7020,  0.9922,
           0.9922, -0.3961, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -0.7569,  0.7569,  0.9922,
          -0.0980, -0.9922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000,  0.0431,  0.9922,  0.9922,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.5216,  0.8980,  0.9922,  0.9922,
          -0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.0510,  0.9922,  0.9922,  0.7176,
          -0.6863, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -0.0510,  0.9922,  0.6235, -0.8588,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
数据集中抽取的64份数据的纯数据部份的尺寸: torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
tensor(7)
torch.Size([])

进程已结束,退出代码0

 我们可以看到一张图片得数据格式(这里面是已经归一化处理过的),因为我们的手写字体识别是单通道的灰度图,所以size是[1,28,28],这里很正确,对于彩色图三通道得来说,会有些不一样。

完整代码

import matplotlib.pyplot as plt
from matplotlib import font_manager

print('BP识别MNIST任务说明---------------------')
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
#导入数据集并且进行数据处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
#搭建bp神经网络
class BPNetwork(torch.nn.Module):
    def __init__(self):
        super(BPNetwork,self).__init__()
        #我们的每张图片都是28*28也就是784个像素点
        #第一个隐藏层
        self.linear1=torch.nn.Linear(784,128)
        #激活函数,这里选择Relu
        self.relu1=torch.nn.ReLU()
        #第二个隐藏层
        self.linear2=torch.nn.Linear(128,64)
        #激活函数
        self.relu2=torch.nn.ReLU()
        #第三个隐藏层:
        self.linear3=torch.nn.Linear(64,32)
        # 激活函数
        self.relu3 = torch.nn.ReLU()
        #输出层
        self.linear4=torch.nn.Linear(32,10)
        # 激活函数
        self.softmax=torch.nn.LogSoftmax()
    #前向传播
    def forward(self,x):
        #修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
        x=x.reshape(x.shape[0],-1)
        #前向传播
        x=self.linear1(x)#784*128
        x=self.relu1(x)
        x=self.linear2(x)#128*64
        x=self.relu2(x)
        x=self.linear3(x)#64*32
        x=self.relu3(x)
        x=self.linear4(x)#输出层32*10
        x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
        #上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
        return x
#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
    # 损失值参数
    sumloss = 0
    for imges,labels in trainload:
        a+=1
        #前向传播
        output=model(imges)
        #反向传播
        loss=critimizer(output,labels)
        loss.backward()
        #参数更新
        optimizer.step()
        #梯度清零
        optimizer.zero_grad()
        #损失累加
        sumloss+=loss.item()
    loss_.append(sumloss)
    a_.append(a)
    print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()
#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
#     bath_index=i
#     (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
    pre=model(imagess[i])#预测
    #第一张图片对应的pre得格式:
    # print(pre)
    # tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
    #          -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
    #        grad_fn= < LogSoftmaxBackward0 >)
    #接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
    pro = list(pre.detach().numpy()[0])
    pre_label=pro.index(max(pro))
    #print(pre_label)
    #图像显示
    img=torch.squeeze(imagess[i]).numpy()

    plt.subplot(8,8,i+1)
    plt.tight_layout()
    plt.imshow(img,cmap='gray',interpolation='none')
    plt.title(f"预测值:{pre_label}",fontproperties=font, fontsize=9)
    plt.xticks([])
    plt.yticks([])
plt.show()

















结果及其图像显示

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py 
BP识别MNIST任务说明---------------------
D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py:49: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
  x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
第1轮的损失:815.2986399680376,抽取次数和:938
第2轮的损失:281.06414164602757,抽取次数和:1876
第3轮的损失:205.76270231604576,抽取次数和:2814
第4轮的损失:159.9431014917791,抽取次数和:3752
第5轮的损失:131.47989665158093,抽取次数和:4690
第6轮的损失:109.93954652175307,抽取次数和:5628
第7轮的损失:95.26143277343363,抽取次数和:6566
第8轮的损失:85.17149402853101,抽取次数和:7504
第9轮的损失:75.1239058477804,抽取次数和:8442
第10轮的损失:68.23363681556657,抽取次数和:9380
第11轮的损失:60.05981844640337,抽取次数和:10318
第12轮的损失:54.82598690944724,抽取次数和:11256
第13轮的损失:51.70861432119273,抽取次数和:12194
第14轮的损失:46.613128249999136,抽取次数和:13132
第15轮的损失:43.05269447225146,抽取次数和:14070

进程已结束,退出代码0

 可以看到,经过15轮得训练后,我们基本上已经完全能识别出来了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值