深入理解pytorch-MNIST手写体识别,特征图可视化,自己动手搭建神经网络模型(超详细)

本文详述了使用PyTorch进行MNIST手写体识别的全过程,包括数据预处理、网络模型构建、训练、测试及结果可视化。作者不仅介绍了LeNet模型,还分享了自己设计的网络结构,实现超过99.5%的测试准确率。此外,通过特征图可视化揭示了神经网络内部工作原理。

0、有什么?

  • 特征图可视化
  • 训练结果可视化
  • 自己动手构造神经网络

根据需要点击目录观看。

1、前言

这篇博客的主要目的,是记录自己的深度学习的学习过程。在做完这个mnist手写体识别之后,能明显的感觉到对深度学习的理解更加深刻、对流程更加了解。关于MNIST手写体识别的博客,网上数不胜数,但是:

深入理解MNIST识别程序,并且进行拓展比如特征图可视化、学着自己搭建神经网络的却很少,

所以打算好好总结一下这次学习过程,也希望能帮助到其他小伙伴。篇幅较长,可以根据需要点击目录观看。

之前刚开始做的时候写过一篇pytorch-mnist的博客,原文链接是:
https://blog.youkuaiyun.com/qq_24489997/article/details/110432634?spm=1001.2014.3001.5502

现在这一篇算是一个补充完整版吧。附上GitHub代码链接:
https://github.com/cssdcc1997/pytorch-mnist

如果小伙伴们觉得有帮助的话别忘了点个star哦!

2、MNIST数据集

似乎所有程序员在学习一个新的程序语言时,都想要打印输出一个“hello world”,它代表了你入门了这门语言。那么,MNIST手写数字识别便是入门机器学习和深度学习的“hello world”。跑通MNIST程序便能大致了解机器学习的流程,包括数据的读取、转换(totensor)、归一化、神经网络模型设计、超参数设计、训练、前向传播、后向传播等等。在入门机器学习之前先自己跑通一遍MNIST识别程序具有非凡的意义。

MNIST(Mixed National Institute of Standards and Technologydatabase)是一个手写数字的大型数据库,拥有60,000个示例的训练集和10,000个示例的测试集。更详细的介绍可以查看 Yann LeCun的MNIST数据集官网

3、深度学习最基本流程

下面是深度学习基本流程的伪代码:

data = get_data()						# 1、获取数据
transformed_data = transform(data)		# 2、将数据进行预处理
for image in transformed_data:			# 遍历数据集
	result = model(image)				# 3、将数据送入模型,模型输出结果
	loss = compute_loss(result, target)	# 4、计算loss损失
	loss.backward()						# 5、进行反向传播
	optimizer.step()					# 6、优化器更新网络
save(model)								# 7、训练完毕,保存模型

深度学习的基本流程大致可以分为7步:
获取数据; 数据预处理; 数据送入模型,得到输出; 根据输出计算loss; 根据loss进行反向传播; 优化器更新网络参数; 训练完毕保存模型。

具体细节这篇博客就不细讲啦,我觉得初学者想学习深度学习也可以从这7步入手,比如:

  1. 如何对数据进行预处理?裁剪、缩放以及数据增强的方法。
  2. 如何设计网络、选择网络?
  3. 常用的loss函数有哪些?
  4. 反向传播的具体实现,是否能手写出反向传播的过程?
  5. 优化器如何选择?都有哪些优缺点?
  6. 优化器如何更新参数?
  7. 什么时候保存模型最好?

4、代码

程序根据pytorch官方提供的MNIST示例代码进行修改,链接:
https://github.com/pytorch/examples/blob/master/mnist/main.py

在经过修改后,添加了训练结果可视化和特征图可视化等功能,并且尝试自己设计神经网络模型,GitHub链接上面已经给出啦。

下面讲解train.py中的代码:

4.1 导入模块

from __future__ import print_function
import argparse
import os
#import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
import time

# import network
from model.network.LeNet import LeNet
from model.network.MyNetV1 import MyNetV1
from model.network.MyNetV2 import MyNetV2
from model.network.DefaultNet import DefaultNet
from model.network.MyFullConvNet import MyFullConvNet
from model.network.MyVggNet import MyVggNet
from model.network.NewNet import NewNet

导入训练网络需要的模块,其中值得注意的是:

  • argparse模块,该模块允许你在运行.py文件时可以附带参数,如:python train.py --model lenet
  • torch基本模块,即pytorch基本的库
  • matplotlib模块,用于绘制loss曲线和acc曲线图,也用于显示模型中各层特征图即特征图可视化

4.2 函数参数

通过argparse模块,可以在运行文件时添加运行所需要的参数。这些参数可以用于设置网络模型的超参数,如学习率batch-sizeepochs、训练模型等等。下面贴出代码:

# Training settings
    parser = argparse.ArgumentParser(description="Pytorch MNIST Example")
    parser.add_argument("--batch-size", type=int, default=64, metavar="N",
                        help="input batch size for training (default : 64)")
    parser.add_argument("--test-batch-size", type=int, default=1000, metavar="N",
                        help="input batch size for testing (default : 1000)")
    parser.add_argument("--epochs", type=int, default=64, metavar="N",
                        help="number of epochs to train (default : 64)")
    parser.add_argument("--learning-rate", type=float, default=0.1, metavar="LR",
                        help="the learning rate (default : 0.1)")
    parser.add_argument("--gamma", type=float, default=0.5, metavar="M",
                        help="Learning rate step gamma (default : 0.5)")
    parser.add_argument("--use-cuda", action="store_true", default=True,
                        help="Using CUDA training")
    parser.add_argument("--dry-run", action="store_true", default=False,
                        help="quickly check a single pass")
    parser.add_argument("--seed", type=int, default=1, metavar="S",
                        help="random seed (default : 1)")
    parser.add_argument("--log-interval", type=int, default=10, metavar="N",
                        help="how many batches to wait before logging training status")
    parser.add_argument("--save-model", action = "store_true", default=True
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值