pytorch-MNIST手写数字识别与特征图可视化
1、MNIST数据集
似乎所有程序员在学习一个新的程序语言时,都想要打印输出一个“hello world”,它代表了你入门了这门语言。那么,MNIST手写数字识别便是入门机器学习和深度学习的“hello world”。跑通MNIST程序便能大致了解机器学习的流程,包括数据的读取、转换(totensor)、归一化、神经网络模型设计、超参数设计、训练、前向传播、后向传播等等。在入门机器学习之前先自己跑通一遍MNIST识别程序具有非凡的意义。
MNIST
(Mixed National Institute of Standards and Technologydatabase)是一个手写数字的大型数据库,拥有60,000个示例的训练集和10,000个示例的测试集。更详细的介绍可以查看 Yann LeCun的MNIST数据集官网。
2、代码
本程序来自pytorch官方提供的MNIST示例代码,链接:
https://github.com/pytorch/examples/blob/master/mnist/main.py
在经过修改并添加训练结果可视化和特征图可视化等功能,github链接在本文最下方。
下面讲解train.py中的代码:
from __future__ import print_function
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
import time
# import network
from model.network.LeNet import LeNet
from model.network.MyNetV1 import MyNetV1
from model.network.MyNetV2 import MyNetV2
from model.network.DefaultNet import DefaultNet
from model.network.LeNet5 import LeNet5
from model.network.MyFullConvNet import MyFullConvNet
from model.network.MyVggNet import MyVggNet
导入训练网络需要的模块,其中值得注意的是:
• argparse模块,该模块允许你在运行.py
文件时可以附带参数,如:python train.py --model lenet
• torch基本模块,即pytorch基本的库
• matplotlib模块,用于绘制loss曲线和acc曲线图,也用于显示模型中各层特征图即特征图可视化
2.2 函数参数
通过argparse
模块,可以在运行文件时添加运行所需要的参数。这些参数可以用于设置网络模型的超参数,如学习率
、batch-size
、epochs
、训练模型等等。下面贴出代码:
# Training settings
parser = argparse.ArgumentParser(description="Pytorch MNIST Example")
parser.add_argument("--batch-size", type=int, default=64, metavar="N",
help="input batch size for training (default : 64)")
parser.add_argument("--test-batch-size", type=int, default=1000, metavar="N",
help="input batch size for testing (default : 1000)")
parser.add_argument("--epochs", type=int, default=64, metavar="N",
help="number of epochs to train (default : 64)")
parser.add_argument("--learning-rate", type=float, default=0.1, metavar="LR",
help="number of epochs to train (default : 14)")
parser.add_argument("--gamma", type=float, default=0.5, metavar="M",
help="Learning rate step gamma (default : 0.5)")
parser.add_argument("--no-cuda", action="store_true", default=True,
help="disables CUDA training")
parser.add_argument("--dry-run", action="store_true", default=False,