mnist 数据集是一个非常出名的数据集,基本上很多网络都将其作为一个测试的标准,其来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员,一共有 60000 张图片。 测试集(test set) 也是同样比例的手写数字数据,一共有 10000 张图片。
每张图片大小是 28 x 28 的灰度图:
完整代码在:GitHub 一共4个文件
MNIST.py 是主函数
net.py 里面定义了3种网络,训练的时候选择其中一种
readpic.py 用于读取图片,看看能否识别出来
3.jpg 就是自己用画图手写的一个数字和28*28差不多大
net.py 代码:
import torch
import torch.nn as nn
class simpleNet(nn.Module):
"""
简单的三层全连接神经网络
"""
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(simpleNet, self).__init__()
self.layer1 = nn.Linear(in_dim, n_hidden_1)
self.laye