文章目录
摘要
通过学习CNN模型的训练及验证套路,对模型训练以及模型验证套路有了基本认识,并趁热打铁使用CNN模型实现mnist手写数字识别的实操
Abstract
By learning the training and verification routines of CNN model, I have a basic understanding of the training and verification routines of the model, and use the CNN model to realize the practical operation of mnist handwritten digit recognition
1 完整模型训练套路及模型验证套路
完整模型训练套路(以CIFAR10数据集为例)
1.1 模型及训练代码
model.py
import torch
import torch.nn as nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
# model
class model(nn.Module):
def __init__(self):
super(model, self).__init__()
self.m = Sequential(
Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, 1, 2),
MaxPool2d(2),
Conv2d(32, 64, 5, 1, 2),
MaxPool2d(2),
Flatten(),
Linear(64*4*4, 64),
Linear(64, 10)
)
def forward(self,x):
return self.m(x)
# 在该模块中测试model
if __name__ == '__main__':
m = model()
input = torch.ones([64, 3, 32, 32])
output = m(input)
print(output.shape)
补充:
- argmax()使用
import torch
output = torch.tensor([[0.1, 0.5],
[0.2, 0.4]])
# dim = 1 数组横向比较中较大的下标
print(output.argmax(dim=1)) # tensor([1, 1])
# dim = 0 数组纵向比较中较大的下标
print(output.argmax(dim=0)) # tensor([1, 0])
train,py
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
# 引入网络模型
from model import *
# 训练数据集
train_data = torchvision.datasets.CIFAR10("dataset2", train=True, transform=torchvision.transforms.ToTensor())
# 测试数据集
test_data = torchvision.datasets.CIFAR10("dataset2", train=False, transform=torchvision.transforms.ToTensor())
# 数据长度
train_data_size = len(train_data)
test_data_size =len(test_data)
print("训练集数据长度为:{}".format(train_data_size)) # 50000
print("测试集数据长度为:{}".format(test_data_size)) # 10000
# 利用DataLoader来加载数据集
train_Dataloader = DataLoader(train_data, batch_size=64)
test_Dataloader = DataLoader(test_data, batch_size=64)
# 创建网络模型
mm = model()
# 损失函数
loss_fn = CrossEntropyLoss()
# 优化器
# 学习率:learing_rate = 0.01
# 1e-2 = 1×(10)^(-2)=1/100 = 0.01
learning_rate = 1e-2
optimizer = torch.optim