PyTorch 实现ResNet-50算法
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
前言:
其实在前面的YOLOv5的学习中已经涉及到残差块的知识点,不过没有深入展开,所以今天基于前面的了解,接下来我将用PyTorch实现ResNet-50模块的搭建,并用于鸟类图像识别。
文章目录:
文章目录
一、ResNet -50网络模型介绍
纵观前面的经典网络如AlexNet和VGG网络模型,人们通过增加网络的深度提高了网络的性能,实现了识别精度的提高,但是这个方法不可避免的带来了参数巨大,训练消耗的成本高的问题,后面人们在实验中还发现了一个奇怪的现象,那就是随着网络深度的提高反而准确率下降,出现了梯度爆炸和梯度消失,在2015年,何恺明提出了在网络中加入残差模块来缓解这一问题,同时在视觉挑战赛中取得了优异的成果也证明了它的有效性。
残差的主要思想是,将这一次的输入与通过网络模型后的输出进行相加,使得这次的效果不比经过网络模型后的效果差。
注:X是输入,F(X)是通过网络的输出,identity即为残差部分。
在添加残差块时需要注意,如果经过网络模型前后通道数发生改变,需要在分支部分添加卷积层实现与通过网络输出保持一致。
本文主要以ResNet-50为例,其网络结构如下:
ResNet-50 的结构主要由以下几个部分组成:
-
初始卷积层:
- 7x7 卷积层(64个滤波器)
- ReLU 激活函数
- 批量归一化
- 最大池化(3x3,步长为2)
-
四个阶段的残差块组:
- Stage 1: 3个瓶颈块,每个块包含三个卷积层(1x1, 3x3, 1x1),通道数从64到256。
- Stage 2: 4个瓶颈块,通道数从256到512,第一个块进行下采样(步长为2)。
- Stage 3: 6个瓶颈块,通道数从512到1024,第一个块进行下采样(步长为2)。
- Stage 4: 3个瓶颈块,通道数从1024到2048,第一个块进行下采样(步长为2)。
-
全局平均池化层:将每个特征图的大小缩小到1x1。
-
全连接层:输出类别数(如1000类的ImageNet数据集)
二、实验部分
2.1 导入相关库
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings
warnings.filterwarnings("ignore") #忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
2.2 展示数据的类别
import os,PIL,random,pathlib
data_dir = 'data/bird_photos'
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("/")[2] for path in data_paths]
classeNames
['Cockatoo', 'Black Skimmer', 'Black Throated Bushtiti', 'Bananaquit']
2.3 图片预处理
Transform=transforms.Compose([
transforms.Resize([224,224]),
transforms.RandomHorizontalFlip(), #将图片随机水平翻转
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225]
)
])
total_data=datasets.ImageFolder('data/bird_photos',transform=Transform)
total_data
Dataset ImageFolder
Number of datapoints: 565
Root location: data/bird_photos
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
RandomHorizontalFlip(p=0.5)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
2.4 展示图片
import random,PIL
import matplotlib.pyplot as plt
from PIL import Image
data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(10,4))
for i in range(10):
plt.subplot(2,5,i+1)
plt.axis('off')
image=random.choice(data_paths2) #随机选择一个图片
plt.title(image.parts[-2],fontsize=10) #通过glob对象取出他的文件夹名称,即分类名
plt.imshow(Image.open(str(image))) #显示图片
plt.show()
total_data.class_to_idx
{'Bananaquit': 0,
'Black Skimmer': 1,
'Black Throated Bushtiti': 2,
'Cockatoo': 3}
2.5划分训练集和测试集
train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_dataset,test_dataset=torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
(<torch.utils.data.dataset.Subset at 0x7f285d713f90>,
<torch.utils.data.dataset.Subset at 0x7f27749e5d50>)
train_size,test_size
(452, 113)
from random import shuffle
import torch.utils
batch_size=16
train_dl=torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1)
test_dl=torch.utils.data.DataLoader(test_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1)
for X,y in test_dl:
print("Shape of X [N,C,H,W]",X.shape)
print("Shape of y: ",y.shape,y.dtype)
break
Shape of X [N,C,H,W] torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64
2.6 搭建网络模型
import torch
import torch.nn as nn
from torchinfo import summary
class Bottleneck(nn.Module):
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, downsample=None)