【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础

卷积神经网络

在这里插入图片描述

输入层 —输入图片矩阵

  • 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片
  • 在这里插入图片描述

卷积层 —特征提取

  • 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆
    在这里插入图片描述
  • 卷积操作
    在这里插入图片描述
    在这里插入图片描述

激活层 —加强特征

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

池化层 —压缩数据

在这里插入图片描述
在这里插入图片描述

全连接层 —进行分类

输出层 —输出分类概率

在这里插入图片描述

4、基于LeNet实现cifar10数据集分类

1、数据集

在这里插入图片描述

  • DataSet
import torchvision
from torchvision import transforms
def get_train_dataset(data_root):
    transform = transforms.Compose(
        [transforms.ToTensor(),#0~1
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#-1,1
    #(0-0.5)/0.5=-1 (1-0.5)/0.5=1
    train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                            download=True, transform=None)
    return train_dataset
train_dataset = get_train_dataset("dataset")
image,label = train_dataset[1]
print(train_dataset.classes[label])
print(type(image))

transforms 设定处理图片的规则

# Composes several transforms together (authority)
# 把一系列图片操作【组合】起来
transform = transforms.Compose
# Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor
# 将PIL Image格式或者numpy.ndarray格式的数据格式化为可被pytorch快速处理的tensor类型
transforms.ToTensor() 
# Normalize a tensor image with mean and standard deviation
# output[channel] = (input[channel] - mean[channel]) / std[channel]
# 使用均值和方差对数据归一化,保证程序运行时收敛加快,训练次数少点
# 因为tensor是0-1,经过normalize可以变成(-1,1) 所以可以使用0.5,当然并不一定都是0.5
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

在这里插入图片描述

# root :Root directory of dataset where directory(数据集的根目录)
# train  If True, creates dataset from training set
# download If true, downloads the dataset from the internet
# transform A transform that takes in an PIL image and returns a transformed version
torchvision.datasets.CIFAR10(root=data_root, train=True,
                                            download=True, transform=transform)                              

验证检查

train_dataset = get_train_dataset("dataset")
image,label = train_dataset[1]
print(train_dataset.classes[label])
image.show()

  • DataLoader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,
                                              shuffle=True, num_workers=2)

说明

dataset:数据集
batch_size:how many samples per batch to load
shuffle:set to ``True`` to have the data reshuffled
num_workers:
how many subprocesses to use for data 加载数据集采用单进程还是多进程
0 means that the data will be loaded in the main process.数据在主进程加载(windows建议0

测试

val_dataset = get_val_dataset("dataset")
val_loader = get_val_loader(val_dataset)
for data in val_loader:
    images,labels = data
    print(images.shape)
    print(labels)

2、模型

  • 搭建网络

import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6,kernel_size=5,stride=1,padding=0)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,kernel_size=5,stride=1,padding=0)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = self.flatten(x)
        print(x.shape)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

金科铁码

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值