基于pytorch和Fashion Mnist数据集建立简单的CNN网络来实现图片分类

本文介绍了如何使用Pytorch构建一个简单的CNN网络对Fashion Mnist数据集进行图像分类。通过加载数据、构建网络、训练模型和评估模型,展示了Pytorch在深度学习中的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

写这篇文章,我主要是想要介绍一种流行的深度学习框架---Pytorch,并且完成一个简单的CNN网络例子来加深对它的认识,我们还使用到了Fashion Mnist数据集,完成这个DL领域的“Hello World”。

相比于TF,Pytorch有很多优点。这些可以自行Google来了解。总之,Pytorch更加符合python的特性,也更加好理解。

数据集

在这个项目中,我将使用Fashion Mnist数据集,可以在kaggle下载csv文件。

网络模型

为了训练图片,我们使用CNN网络来提取图片的各项特征。例如:

在这个网络中,一共有五个卷积层。我们将输入值通过filters来更新数值,这些kernel values通过前向传播来 

计算损失函数,然后通过反向传播来再次更新损失函数。在这个项目中,我们使用的网络结构如下图所示:

 

我们使用了两个卷积层,kenel的大小为5*5,后面是一个全连接层,最后输出一个激活值给最后的输出层。下面,我们来逐个分析代码。

Imports

import torch 
import torch.nn as nn
import torchvision.datasets as dsets
from skimage import transform
import torchvision.transforms as transforms
from torch.autograd import Variable
import pandas as pd;
import numpy as np;
from torch.utils.data import Dataset, DataLoader
from vis_utils import *  #可视化工具,可以省略
import random;
import math;

声明必要的全局变量:

num_epochs = 5;
batch_size = 100;
learning_rate = 0.001;

加载数据

在pytorch中,为了加载数据,我们需要编写一个类来继承Dataset类型,然后定义数据读取函数和数据获取函数,来看看我们是怎么

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值