写这篇文章,我主要是想要介绍一种流行的深度学习框架---Pytorch,并且完成一个简单的CNN网络例子来加深对它的认识,我们还使用到了Fashion Mnist数据集,完成这个DL领域的“Hello World”。
相比于TF,Pytorch有很多优点。这些可以自行Google来了解。总之,Pytorch更加符合python的特性,也更加好理解。
数据集
在这个项目中,我将使用Fashion Mnist数据集,可以在kaggle下载csv文件。
网络模型
为了训练图片,我们使用CNN网络来提取图片的各项特征。例如:
在这个网络中,一共有五个卷积层。我们将输入值通过filters来更新数值,这些kernel values通过前向传播来
计算损失函数,然后通过反向传播来再次更新损失函数。在这个项目中,我们使用的网络结构如下图所示:
我们使用了两个卷积层,kenel的大小为5*5,后面是一个全连接层,最后输出一个激活值给最后的输出层。下面,我们来逐个分析代码。
Imports
import torch
import torch.nn as nn
import torchvision.datasets as dsets
from skimage import transform
import torchvision.transforms as transforms
from torch.autograd import Variable
import pandas as pd;
import numpy as np;
from torch.utils.data import Dataset, DataLoader
from vis_utils import * #可视化工具,可以省略
import random;
import math;
声明必要的全局变量:
num_epochs = 5;
batch_size = 100;
learning_rate = 0.001;
加载数据
在pytorch中,为了加载数据,我们需要编写一个类来继承Dataset类型,然后定义数据读取函数和数据获取函数,来看看我们是怎么