【深度学习案例】手写数字项目实现-2.Python模型训练


  该项目所用到的源码以及所有源码均在GitHub以及Gitee上面开源,下载方式:

GitHub: 
git clone https://github.com/guojin-yan/MNIST_demo.git

Gitee:
git clone https://gitee.com/guojin-yan/MNIST_demo.git

4. Python基于Pytorch框架实现模型训练

4.1 训练环境

CUDA 11.4

Pytorch 1.12.1+cu113

Python 3.9

4.2 定义数据加载器

  新建一个数据加载器文件dataloader.py,用于加载训练数据,后续模型训练时通过数据加载器按批次加载训练集。

  首先导入以下模块,该模块在安装torch模块后就可以使用。

import torchvision
from torch.utils.data import DataLoader

  接下来定义一个class Dataloader()类,并初始化相关变量。主要设置训练集、测试集批次大小即可,后面我们直接调用Pytorch的MNIST数据集 API 接口加载数据集,所以需要定义的变量较少。

'''
    数据集加载类
    功能:
    实现数据集本地读取,并将数据集按照指定要求进行预处理;
    后续模型训练直接调用DataLoader逐步进行训练
    初始化参数:
    batch_size_train:训练集批次
    batch_size_test:测试集批次
'''
class Dataloader():
    def __init__(self,batch_size_train, batch_size_test):
        # 初始化训练集bath size
        self.batch_size_train = batch_size_train
        # 初始化测试集bath size
        self.batch_size_test = batch_size_test
   

  分别定义训练集加载器train_loader()以及test_loader(),用于加载训练集以及测试集。torchvision.datasets.MNIST()是Pytorch提供的MNIST数据集加载接口API函数:'./Datasets/'为指定的数据集本地下载路径;download可以设置是否下载数据集,当指定为True且首次运行时,会将数据集下载到指定的路径下,再次运行时会检测路径下是否有该文件,如果有会直接读取,如果没有将会重新下载;transform指定的为数据处理方式,主要是将数据进行归一化以及数据类型转换处理,还可以对数据进行增强处理;batch_size指定训练时数据集加载的批次大小。

#加载训练集
def train_loader(self):
    # 调用pytorch自带的DataLoader方法加载MNIST训练集;
    # 直接使用pytorch的MNIST数据集API接口加载数据,
    # 第一次使用可以设置为True
    train_load = DataLoader(
        torchvision.datasets.MNIST('./Datasets/', train=True, download=True,
                                transform=torchvision.transforms.Compose([
                                    torchvision.transforms.ToTensor(),
                                    torchvision.transforms.Normalize(
                                        (0.1307,), (0.3081,))
                                ])),
        batch_size = self.batch_size_train, shuffle=True)
    return train_load
#加载测试集
def test_loader(self):
    # 调用pytorch自带的DataLoader方法加载MNIST测试集
    test_load = DataLoader(
        torchvision.datasets.MNIST('./Datasets/', train=False, download=True,
                                transform=torchvision.transforms.Compose([
                                    torchvision.transforms.ToTensor(),
                                    torchvision.transforms.Normalize(
                                        (0.1307,), (0.3081,))
                                ])),
        batch_size = self.batch_size_test, shuffle=True)
    return test_load

4.3 定义网络(net,py)

  接下来定义训练网络,要实现后面我们的网络能够很好的识别我们的手写数字,在此处就要定义一个比较好的网络,下面的网络是我参考的网上一些人所做的模型定义的。

  首先导入一下模块:torch.nn 模块下定义了各种我们常见的网络层以及网络结构,直接调用该模块进行网络构建;torch.nn.functional模块是定义了各种激活函数的模块。

import torch.nn as nn
import torch.nn.functional as F

  我们将网络定义到Net类中,继承``nn.Module`''类模板。在初始化时定义一些在前向传播中所用到的网络层,方便再后面直接使用。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        '''定义相关的计算层'''
        # 定义一个卷积核为1×10的卷积层
        self.conv1 = nn.Conv2d(1, 
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

椒颜皮皮虾྅

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值