Mnist作业
1 设计思路
首先,我们要明确我们的任务是什么。任务是对mnist手写数据集进行分类,那么我们用pytorch来做的话,它是内置了Fashion-Mnist和Mnist数据集,我们只要从框架里面下载即可。然后我们创建dataloader,设置batch_size,设置交叉熵损失函数因为这是一个多分类问题。注意一般二分类不用这个损失函数,一般用BCELoss损失函数。然后设置优化器为SGD或者Adam,设置学习率、权重衰减、动量值。设置训练函数,进行训练,将正确率和loss显示在终端上,最后可视化每一个epoch的准确率和损失。
2 实现过程
2.1 导入相关包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch.utils import *
import torchvision
2.2 处理数据集
transfroma = transforms.Compose(
[
transforms.ToTensor(),
]
)
# 内置mnist
train_ds=torchvision.datasets.MNIST(
root='MNIST', # 根路径
train=True, # 是否为训练集
transform=transfroma, # 图像预处理 转为tensor 或者裁剪 归一化
download=True # 是否下载
)
test_ds=torchvision.datasets.MNIST(
root='MNIST',
train=False,
transform=transfroma,
download=True
)
2.3 创建dataloader
BATCH_SIZE = 64 # 一次送入多少数据
# 创建dataloader
train_dl =