MNIST手写数据集识别练习:
- MNIST数据集是在进行深度学习分类时最常用的一个数据集,是手写的0-9十个数,图像大小为(1, 28, 28)。
- MNIST数据集都是灰度图所以,通道数为1。
- MNIST数据集大概有7w多张图片(6w训练,1w测试)
这段时间观看b站up主 刘二大人 的关于深度学习的讲解,让我获益匪浅。在网络中添加残差模块,使得测试精度再次提升,下面是视频地址。
https://www.bilibili.com/video/BV1Y7411d7Ys?p=11&share_source=copy_web
下面为相关代码:
文中绘制相关曲线使用的是visdom库,优点是根据运行计算出的损失值,实时的进行绘制。
visdom的安装和使用可以浏览下列连接的博客。
https://blog.youkuaiyun.com/qq_42962681/article/details/116271548
为了能更加了解网络的结构和参数的传递,标记了大量的注释。
import torch
import torch.nn as nn
from torch.utils.data import DataLoader # 我们要加载数据集的
from torchvision import transforms # 数据的原始处理
from torchvision import datasets # pytorch十分贴心的为我们直接准备了这个数据集
import torch.nn.functional as F # 激活函数
import torch.optim as optim
import time
import visdom
batch_size = 64
# 我们拿到的图片是pillow,我们要把他转换成模型里能训练的tensor也就是张量的格式
# transform = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 你均值和方差都只传入一个参数,就报错了.
# 这个函数的功能是把输入图片数据转化为给定均值和方差的高斯分布,使模型更容易收敛。图片数据是r,g,b格式,对应r,g,b三个通道数据都要转换。
])
# 加载训练集,pytorch十分贴心的为我们直接准备了这个数据集,注意,即使你没有下载这个数据集
# 在函数中输入download=True,他在运行到这里的时候发现你给的路径没有,就自动下载
train_dataset = datasets.MNIST(root='../datasets/fashion-mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
# 同样的方式加载一下测试集
test_dataset = datasets.MNIST(root='../datasets/fashion-mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
#为了使预测精度更高,我在这里添加了残差网络模块
# 残差网络示意
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.channels = channels
#这里输入通道与输出通道相等,且经过kernel=3的卷积核,由于w,h填充1,所以图像尺寸不变
self