pytorch 手写数字识别+测试自己的图片+线性网络

基础代码网上找了
但是没法保存模型/测试自己的图片
加入了test.py
可以保存模型和测试自己的图片

main.py

#pytorch的包
import torch 
from torch import nn  # 神经网络相关工作
from torch.nn import functional as F  # 常用函数
from torch import optim  # 优化工具包
import torchvision #计算机视觉
from torchvision import transforms #对于图片进行转换
#画图的包
from matplotlib import pyplot as plt #画图
from matplotlib.colors import Normalize #归一化
#其他
from utils import plot_image, plot_curve, one_hot #自己的函数
import random

#超参数
batch_size = 512  #一次处理多张图片
LR=0.001 #学习率

#step1.加载数据集
#训练集
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data', #保存路径
        train=True, #训练
        download=True,#允许下载
        # 下载的数据为numpy格式转换为tensor格式,正则化使原本[0,1]的数据在0附近均匀分布
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,),(0.3081,))
        ])),
        batch_size=batch_size,#批量大小
        shuffle=True#shuffle把数据随机打散
    ) 
#测试集
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'mnist_data', 
        train=False, 
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.1307,),(0.3081,))
            ])),
        batch_size=batch_size,
        shuffle=False
    )

#step2.创建网络
class Net(nn.Module):
    #初始化网络
    def __init__(self):
        super(Net,self).__init__() #调用父类方法初始化函数为网络
        # 每一层为xw+b
        self.fc1 = nn.Linear(28*28,512)  #第一层28*28是由图片像元数决定的,512为经验决定的
        self.fc2 = nn.Linear(512,128)  #512->128
        self.fc3 = nn.Linear(128,32)  #128->32
        self.fc4 = nn.Linear(32,10)  #最后一层的输出值为10(因为10分类,输出必须为10)  

    #网络的正向传播(计算过程)
    def forward (self, x):
        # x: [b,1,28,28]  b张图片,单通道,每个图片28*28       
        x 
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值