'''
用PyTorch完成手写数字识别
1.准备数据
2.构建模型
3.模型的训练
4.模型的保存
5.模型的评估
'''
import torch
import os
import numpy as np
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
# from torchvision import transforms
BATCH_SIZE = 128
# 1.准备数据集
def get_dataloader(train=True):
transform_fn =Compose([
ToTensor(), # mean和std的形状与通道数相同
Normalize(mean=(0.1307,),std=(0.3081,))
])
dataset = MNIST(root='./data',train=train,transform=transform_fn)
data_loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)
return data_loader
#2.构建模型
class MnistNet(nn.Module):
def __init__(self):
super(MnistNet,self).__init__()
self.fc1 = nn.Linear(28*28*1,28)
self.fc2 = nn.Linear(28,10)
def forward(self,x):
x = x.view(-1,28*28*1)
x = self.fc1(x)
x = F.relu(x)
out = self.fc2(x)
return F.log_softmax(out,dim=-1)
# 模型实例化
model = MnistNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
if os.path.exists("./model/model.pkl"):
model.load_state_dict(torch.load('./model/model.pkl'))
optimizer.load_state_dict(torch.load('./model/optimizer.pkl'))
def train(epoch):
'''实现训练过程'''
data_loader = get_dataloader()
for idx,(input,traget) in enumerate(data_loader):
optimizer.zero_grad()
output = model(input)
loss = F.nll_loss(output,traget)
loss.backward()
optimizer.step()
if idx%10 == 0:
print(epoch,idx,loss.item())
if idx%100 ==0:
torch.save(model.state_dict(),'./model/model.pkl')
torch.save(optimizer.state_dict(), './model/optimizer.pkl')
def test():
loss_list = []
acc_list = []
test_dataloader = get_dataloader(train=False)
for idx,(input,traget) in enumerate(test_dataloader):
with torch.no_grad():
output = model(input)
cur_loss = F.nll_loss(output,traget)
loss_list.append(cur_loss)
pre = output.max(dim=-1)[-1]#获取最大值的位置
cur_acc = pre.eq(traget).float().mean()
acc_list.append(cur_acc)
print('平均准确率,平均损失:',np.mean(acc_list),np.mean(loss_list))
if __name__ == '__main__':
# for i in range(3):
# train(i)
test()
Pytorch实现手写数字识别
最新推荐文章于 2024-07-12 22:38:25 发布
本文介绍使用PyTorch构建的手写数字识别系统。主要内容包括数据预处理、构建神经网络模型、训练模型、保存及加载模型参数以及最终模型的评估。通过本教程可以了解如何利用PyTorch进行基本的计算机视觉任务。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
7666

被折叠的 条评论
为什么被折叠?



