pytorch-MNIST手写数字识别与特征图可视化

本文介绍了使用PyTorch实现MNIST手写数字识别的教程,包括数据加载、模型训练、结果可视化和特征图可视化。通过训练过程中的损失和准确率展示,以及不同网络结构的准确率比较,揭示了网络结构对模型性能的影响。特征图可视化部分展示了卷积神经网络如何处理图像特征,以理解模型内部工作原理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、MNIST数据集

似乎所有程序员在学习一个新的程序语言时,都想要打印输出一个“hello world”,它代表了你入门了这门语言。那么,MNIST手写数字识别便是入门机器学习和深度学习的“hello world”。跑通MNIST程序便能大致了解机器学习的流程,包括数据的读取、转换(totensor)、归一化、神经网络模型设计、超参数设计、训练、前向传播、后向传播等等。在入门机器学习之前先自己跑通一遍MNIST识别程序具有非凡的意义。

MNIST(Mixed National Institute of Standards and Technologydatabase)是一个手写数字的大型数据库,拥有60,000个示例的训练集和10,000个示例的测试集。更详细的介绍可以查看 Yann LeCun的MNIST数据集官网

2、代码

本程序来自pytorch官方提供的MNIST示例代码,链接:
https://github.com/pytorch/examples/blob/master/mnist/main.py
在经过修改并添加训练结果可视化和特征图可视化等功能,github链接在本文最下方。

下面讲解train.py中的代码:

from __future__ import print_function
import argparse
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
import time

# import network
from model.network.LeNet import LeNet
from model.network.MyNetV1 import MyNetV1
from model.network.MyNetV2 import MyNetV2
from model.network.DefaultNet import DefaultNet
from model.network.LeNet5 import LeNet5
from model.network.MyFullConvNet import MyFullConvNet
from model.network.MyVggNet import MyVggNet

导入训练网络需要的模块,其中值得注意的是:
• argparse模块,该模块允许你在运行.py文件时可以附带参数,如:python train.py --model lenet
• torch基本模块,即pytorch基本的库
• matplotlib模块,用于绘制loss曲线和acc曲线图,也用于显示模型中各层特征图即特征图可视化

2.2 函数参数

通过argparse模块,可以在运行文件时添加运行所需要的参数。这些参数可以用于设置网络模型的超参数,如学习率batch-sizeepochs、训练模型等等。下面贴出代码:

# Training settings
    parser = argparse.ArgumentParser(description="Pytorch MNIST Example")
    parser.add_argument("--batch-size", type=int, default=64, metavar="N",
                        help="input batch size for training (default : 64)")
    parser.add_argument("--test-batch-size", type=int, default=1000, metavar="N",
                        help="input batch size for testing (default : 1000)")
    parser.add_argument("--epochs", type=int, default=64, metavar="N",
                        help="number of epochs to train (default : 64)")
    parser.add_argument("--learning-rate", type=float, default=0.1, metavar="LR",
                        help="number of epochs to train (default : 14)")
    parser.add_argument("--gamma", type=float, default=0.5, metavar="M",
                        help="Learning rate step gamma (default : 0.5)")
    parser.add_argument("--no-cuda", action="store_true", default=True,
                        help="disables CUDA training")
    parser.add_argument("--dry-run", action="store_true", default=False,
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值