Pytorch实现手写数字识别

'''
用PyTorch完成手写数字识别
1.准备数据
2.构建模型
3.模型的训练
4.模型的保存
5.模型的评估
'''
import torch
import os
import numpy as np
from torch import nn
from torch import optim
import torch.nn.functional  as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
# from torchvision import transforms
BATCH_SIZE = 128
# 1.准备数据集
def get_dataloader(train=True):
    transform_fn =Compose([
        ToTensor(), # mean和std的形状与通道数相同
        Normalize(mean=(0.1307,),std=(0.3081,))
    ])
    dataset = MNIST(root='./data',train=train,transform=transform_fn)
    data_loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)
    return data_loader
#2.构建模型
class MnistNet(nn.Module):
    def __init__(self):
        super(MnistNet,self).__init__()
        self.fc1 = nn.Linear(28*28*1,28)
        self.fc2 = nn.Linear(28,10)
    def forward(self,x):
        x = x.view(-1,28*28*1)
        x = self.fc1(x)
        x = F.relu(x)
        out = self.fc2(x)
        return F.log_softmax(out,dim=-1)

    # 模型实例化
model = MnistNet()

optimizer = optim.Adam(model.parameters(), lr=0.001)
if os.path.exists("./model/model.pkl"):
    model.load_state_dict(torch.load('./model/model.pkl'))
    optimizer.load_state_dict(torch.load('./model/optimizer.pkl'))


def train(epoch):
    '''实现训练过程'''
    data_loader = get_dataloader()
    for idx,(input,traget) in enumerate(data_loader):
        optimizer.zero_grad()
        output = model(input)
        loss = F.nll_loss(output,traget)
        loss.backward()
        optimizer.step()
        if idx%10 == 0:
            print(epoch,idx,loss.item())
        if idx%100 ==0:
            torch.save(model.state_dict(),'./model/model.pkl')
            torch.save(optimizer.state_dict(), './model/optimizer.pkl')
def test():
    loss_list = []
    acc_list = []
    test_dataloader = get_dataloader(train=False)
    for idx,(input,traget) in enumerate(test_dataloader):
        with torch.no_grad():
            output = model(input)
            cur_loss = F.nll_loss(output,traget)
            loss_list.append(cur_loss)
            pre = output.max(dim=-1)[-1]#获取最大值的位置
            cur_acc = pre.eq(traget).float().mean()
            acc_list.append(cur_acc)
    print('平均准确率,平均损失:',np.mean(acc_list),np.mean(loss_list))

if __name__ == '__main__':
    # for i in range(3):
    #     train(i)
    test()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值