PyTorch 实现ResNet-50算法

PyTorch 实现ResNet-50算法

前言

其实在前面的YOLOv5的学习中已经涉及到残差块的知识点,不过没有深入展开,所以今天基于前面的了解,接下来我将用PyTorch实现ResNet-50模块的搭建,并用于鸟类图像识别。

文章目录

一、ResNet -50网络模型介绍

​ 纵观前面的经典网络如AlexNet和VGG网络模型,人们通过增加网络的深度提高了网络的性能,实现了识别精度的提高,但是这个方法不可避免的带来了参数巨大,训练消耗的成本高的问题,后面人们在实验中还发现了一个奇怪的现象,那就是随着网络深度的提高反而准确率下降,出现了梯度爆炸和梯度消失,在2015年,何恺明提出了在网络中加入残差模块来缓解这一问题,同时在视觉挑战赛中取得了优异的成果也证明了它的有效性。

​ 残差的主要思想是,将这一次的输入与通过网络模型后的输出进行相加,使得这次的效果不比经过网络模型后的效果差。

在这里插入图片描述

注:X是输入,F(X)是通过网络的输出,identity即为残差部分。

在添加残差块时需要注意,如果经过网络模型前后通道数发生改变,需要在分支部分添加卷积层实现与通过网络输出保持一致。

在这里插入图片描述

本文主要以ResNet-50为例,其网络结构如下:

在这里插入图片描述

ResNet-50 的结构主要由以下几个部分组成:

  1. 初始卷积层

    • 7x7 卷积层(64个滤波器)
    • ReLU 激活函数
    • 批量归一化
    • 最大池化(3x3,步长为2)
  2. 四个阶段的残差块组

    • Stage 1: 3个瓶颈块,每个块包含三个卷积层(1x1, 3x3, 1x1),通道数从64到256。
    • Stage 2: 4个瓶颈块,通道数从256到512,第一个块进行下采样(步长为2)。
    • Stage 3: 6个瓶颈块,通道数从512到1024,第一个块进行下采样(步长为2)。
    • Stage 4: 3个瓶颈块,通道数从1024到2048,第一个块进行下采样(步长为2)。
  3. 全局平均池化层:将每个特征图的大小缩小到1x1。

  4. 全连接层:输出类别数(如1000类的ImageNet数据集)

二、实验部分

2.1 导入相关库

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

2.2 展示数据的类别

import os,PIL,random,pathlib

data_dir = 'data/bird_photos'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("/")[2] for path in data_paths]
classeNames
['Cockatoo', 'Black Skimmer', 'Black Throated Bushtiti', 'Bananaquit']

2.3 图片预处理

Transform=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.RandomHorizontalFlip(), #将图片随机水平翻转
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )

])

total_data=datasets.ImageFolder('data/bird_photos',transform=Transform)
total_data
Dataset ImageFolder
    Number of datapoints: 565
    Root location: data/bird_photos
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

2.4 展示图片

import random,PIL
import matplotlib.pyplot as plt
from PIL import Image


data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(10,4))
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.axis('off')
    image=random.choice(data_paths2)  #随机选择一个图片
    plt.title(image.parts[-2],fontsize=10) #通过glob对象取出他的文件夹名称,即分类名
    plt.imshow(Image.open(str(image)))  #显示图片
plt.show()

在这里插入图片描述

total_data.class_to_idx
{'Bananaquit': 0,
 'Black Skimmer': 1,
 'Black Throated Bushtiti': 2,
 'Cockatoo': 3}

2.5划分训练集和测试集

train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_dataset,test_dataset=torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
(<torch.utils.data.dataset.Subset at 0x7f285d713f90>,
 <torch.utils.data.dataset.Subset at 0x7f27749e5d50>)
train_size,test_size
(452, 113)
from random import shuffle
import torch.utils


batch_size=16

train_dl=torch.utils.data.DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1)
test_dl=torch.utils.data.DataLoader(test_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1)
for X,y in test_dl:
    print("Shape of X [N,C,H,W]",X.shape)
    print("Shape of y: ",y.shape,y.dtype)
    break
Shape of X [N,C,H,W] torch.Size([16, 3, 224, 224])
Shape of y:  torch.Size([16]) torch.int64

2.6 搭建网络模型

import torch
import torch.nn as nn
from torchinfo import summary

class Bottleneck(nn.Module):
    def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, downsample=None)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值