PyTorch入门项目之卷积神经网络实现Fashion-MNIST图像分类(详细注释!)

PyTorch入门项目,卷积神经网络实现图像分类任务


前言

PyTorch目前是学界最流行的深度学习框架之一。对于已经有了深度学习基础的学生而言,急需一个小项目学习练手。然而,网上讲授许多Torch视频教程虽称快速,但很多还是废话较多,部分人可能没有这个耐心看下去,所以我记录了我初次使用torch的一个全流程的训练步骤:导入包、数据集准备、网络搭建、训练、测试、准确率打印、模型保存和loss曲线的内容。


提示:本文基于PyTorch1.7.1+cu101、Python3.7,仿AlexNet,基于FashionMNIST数据集;(没有版本说明和代码注释的都是耍流氓)

一、PyTorch安装

关于PyTorch的安装,这里不做过多的叙述,方法很多。
首先,根据自己显卡的CUDA Version来选择对应版本:
cmd命令下输入nvidia-smi 查看:
cmd命令下输入nvidia-smi 查看
比如,我的是11.2,可以选择支持11.2及以下版本的Torch版本安装。进入Torch官网:https://pytorch.org/选择:
pytorch
不知道怎么安装cudatoolkit的话,最简单的就是创建好conda环境后(conda环境配置等问题在此不叙述),进入环境直接复制:

pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio===0.9.0

(可能会比较慢,建议用镜像)
当然,不建议选最新的,教程比较少,可以点previous version of PyTorch选择稍低一两个版本的。

二、实践步骤

1.引入相关库

代码如下(示例):(没有相关module的pip install 下)

import torch
import torchvision  # 数据集用
import torch.nn as nn  # 搭建网络
import torch.utils.data as Data  # 加载数据
import time   # 计时用
import matplotlib.pyplot as plt  # 绘图用

2.超参数设置

代码如下(示例):

# hyper_parameters setting
DOWNLOAD = True  # 'True'表示需要下载数据集,'False' 表示已经下载好.
BATCH_SIZE = 128   #根据自己电脑实际设置 16 32 64等
EPOCH = 10  # 迭代次数
learning_rate = 0.05  #学习率
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 自动选择训练设备,cpu / gpu(如果支持)

根据需要设置超参数,方便更改程序。

3.数据集准备

以FashionMNIST为例,该数据集可看做是MNIST的升级版,是一些服饰的分类。包含10个类,训练集的图片有60000张,测试集10000张,均为28*28的灰度图像。


## load dataset

train_data = torchvision.datasets.FashionMNIST(
 root='./FashionMnist',  # 数据存放目录
 train=True,  # 用于训练
 transform=torchvision.transforms.ToTensor(),  # 转换为[0,1]的tensor
 download=DOWNLOAD, #是否下载数据集
)

test_data = torchvision.datasets.FashionMNIST(
 root='./FashionMnist', 
 train=False, # 用于测试  
 transform=torchvision.transforms.ToTensor(),
 download=DOWNLOAD)

print(train_data[0]) # 打印查看数据情况 是一个tensor 图像数据和标签

# 数据里的特征和标签按batch分别打包
# shuffle 表示是否随机打乱数据  
# num_workers 表示是否开启新线程加速数据的加载,但设为其他数字时windows系统常报错,所以最好设为0

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

test_loader = Data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

其他数据集的导入

评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值