- 🍨 This post is a learning-record entry from the 🔗365-Day Deep Learning Training Camp
- 🍖 Original author: K同学啊 | tutoring and custom projects available
I. Background and Development Environment
📌 Week J6: ResNeXt-50 in Practice 📌
- Language: Python 3, PyTorch
- 📌 This week's tasks: 📌
– 1. Read the ResNeXt paper and understand the authors' design ideas
– 2. Compare it with the ResNet50V2 and DenseNet architectures introduced earlier
– 3. Use ResNeXt-50 to complete the monkeypox recognition task
II. Model Architecture
ResNeXt is an image classification network presented by Kaiming He's team at CVPR 2017. It is an upgrade of ResNet: on top of the ResNet design it introduces the concept of cardinality, and like ResNet it comes in ResNeXt-50 and ResNeXt-101 variants. The original paper:
Aggregated Residual Transformations for Deep Neural Networks.pdf
In the paper, the authors point out a problem that was widespread at the time: to raise a model's accuracy, one usually made the network deeper or wider. Effective as that is, it increases both the difficulty of network design and the computational cost, so a small gain in accuracy often comes at a steep price. A better strategy was needed, one that improves accuracy without adding computational cost. This led the authors to the concept of cardinality.
The figure below contrasts a ResNet block (left) with a ResNeXt block (right). In ResNet, the 256-channel input is compressed 4× to 64 channels by a 1×1 convolution, processed by a 3×3 convolution, then expanded back by another 1×1 convolution and added to the input through the residual connection. ResNeXt follows the same overall strategy, but its 256-channel input is split across 32 branches, each compressed 64× to 4 channels before processing; the 32 branches are summed and then added to the input through the residual connection. Here, cardinality denotes the number of identical branches in a block.
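To see why cardinality raises capacity without raising cost, here is a quick back-of-the-envelope count (a sketch of our own, covering only the convolution weights of the two blocks above and ignoring BatchNorm):

```python
# Conv-weight counts for one bottleneck block with 256 input/output channels.
resnet  = 256*64 + 64*64*3*3 + 64*256      # 1x1 reduce + 3x3 + 1x1 expand
resnext = 32 * (256*4 + 4*4*3*3 + 4*256)   # 32 parallel paths of 1x1 + 3x3 + 1x1
print(resnet, resnext)                     # 69632 70144 -- roughly the same budget
```

Both designs land near 70k parameters, which is exactly the paper's point: the 32-branch block matches the plain bottleneck's complexity while delivering better accuracy.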
III. Grouped Convolution
Simply put, the grouped convolution used in ResNeXt splits the feature maps into groups and convolves each group separately, which effectively cuts the amount of computation.
In a grouped convolution, each kernel processes only a subset of the channels. In the figure below, for instance, the red kernel handles only the red channels, the green kernel only the green channels, and the yellow kernel only the yellow channels. Each kernel then has 2 channels and produces one feature map.
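PyTorch exposes grouped convolution natively through the groups argument of nn.Conv2d. A minimal sketch (the layer sizes here are illustrative, not from the original post) shows the parameter savings:

```python
import torch.nn as nn

# Standard conv: every kernel sees all 128 input channels.
full    = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
# Grouped conv: each kernel sees only 128/32 = 4 input channels.
grouped = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=32, bias=False)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(full), count(grouped))  # 147456 vs 4608 -- a 32x reduction
```

With g groups, the weight count drops from C_in·C_out·k² to C_in·C_out·k²/g, which is where the savings described above come from.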
IV. Reproducing ResNeXt-50 in PyTorch
1. Grouped convolution module
```python
import torch
import torch.nn as nn

class GroupedConvBlock(nn.Module):
    def __init__(self, in_channel, kernel_size=3, stride=1, groups=32):
        super(GroupedConvBlock, self).__init__()
        self.g_channel = in_channel // groups  # channels per group
        self.groups = groups
        # One small conv per group, so each group has its own kernels
        self.convs = nn.ModuleList([
            nn.Conv2d(self.g_channel, self.g_channel, kernel_size=kernel_size,
                      stride=stride, padding=kernel_size//2, bias=False)
            for _ in range(groups)
        ])
        self.norm = nn.BatchNorm2d(in_channel)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        g_list = []
        # Convolve each group of channels separately
        for c in range(self.groups):
            g = x[:, c*self.g_channel:(c+1)*self.g_channel, :, :]
            g_list.append(self.convs[c](g))
        x = torch.cat(g_list, dim=1)  # reassemble the groups
        x = self.norm(x)
        x = self.relu(x)
        return x
```
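The explicit loop makes the grouping easy to follow, but it launches one small convolution per group, which is slow. As a sketch (an alternative of our own, not part of the original post), the same layer can be written with a single native grouped convolution:

```python
import torch.nn as nn

class GroupedConvBlockNative(nn.Module):
    ''' Same interface as GroupedConvBlock, using nn.Conv2d's groups argument. '''
    def __init__(self, in_channel, kernel_size=3, stride=1, groups=32):
        super().__init__()
        self.block = nn.Sequential(
            # groups=32 splits the channels into 32 independent convolutions internally
            nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size//2, groups=groups, bias=False),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)
```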
2. Defining the residual block
```python
''' Residual Block '''
class Block(nn.Module):
    def __init__(self, in_channel, filters, kernel_size=3, stride=1, groups=32, conv_shortcut=True):
        super(Block, self).__init__()
        self.shortcut = conv_shortcut
        if self.shortcut:
            # Projection shortcut: 1x1 conv to match the output channels and stride
            self.short = nn.Conv2d(in_channel, 2*filters, kernel_size=1, stride=stride, padding=0, bias=False)
        elif stride > 1:
            # Identity shortcut with downsampling
            self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)
        else:
            self.short = nn.Identity()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
        self.conv2 = GroupedConvBlock(in_channel=filters, kernel_size=kernel_size, stride=stride, groups=groups)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, 2*filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(2*filters)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x2 = self.short(x)      # shortcut branch (projection or identity)
        x1 = self.conv1(x)      # 1x1 reduce
        x1 = self.conv2(x1)     # 3x3 grouped convolution
        x1 = self.conv3(x1)     # 1x1 expand
        x = self.relu(x1 + x2)  # residual addition
        return x
```
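A quick shape check (illustrative sizes, not from the original post) confirms that the two branches line up for both shortcut variants:

```python
import torch

x  = torch.randn(1, 64, 56, 56)
b1 = Block(64, 128, stride=1, groups=32, conv_shortcut=True)    # projection: 64 -> 256 channels
print(b1(x).shape)                                              # torch.Size([1, 256, 56, 56])
b2 = Block(256, 128, stride=1, groups=32, conv_shortcut=False)  # identity: shapes already match
print(b2(torch.randn(1, 256, 56, 56)).shape)                    # torch.Size([1, 256, 56, 56])
```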
3. Stacking residual blocks
In each stack, the first block's input and output shapes differ, so its residual connection needs a 1×1 convolution to raise the channel count before the Add.
The remaining blocks have matching input and output shapes, so they can perform the Add directly.
```python
class Stack(nn.Module):
    def __init__(self, in_channel, filters, blocks, stride=2, groups=32):
        super(Stack, self).__init__()
        self.conv = nn.Sequential()
        # First block uses a projection shortcut to change the shape
        self.conv.add_module(str(0), Block(in_channel, filters, stride=stride, groups=groups, conv_shortcut=True))
        # Remaining blocks keep the shape and use identity shortcuts
        for i in range(1, blocks):
            self.conv.add_module(str(i), Block(2*filters, filters, stride=1, groups=groups, conv_shortcut=False))

    def forward(self, x):
        x = self.conv(x)
        return x
```
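For instance (a sanity check of our own), the first stage used below maps a 64-channel 56×56 tensor to 256 channels while keeping the spatial size, since it is built with stride=1:

```python
import torch

stage = Stack(64, 128, blocks=3, stride=1)
print(stage(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 256, 56, 56])
```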
4. Building the ResNeXt-50 network
```python
''' ResNeXt-50 '''
class ResNeXt50(nn.Module):
    def __init__(self,
                 include_top=True,  # whether to include the fully connected layer at the top of the network
                 preact=False,      # whether to use pre-activation
                 use_bias=True,     # whether the conv layers use a bias
                 input_shape=[32, 3, 224, 224],  # kept for reference; not used below
                 classes=1000,      # number of classes to classify images into
                 pooling=None):     # optional pooling mode when include_top is False
        super(ResNeXt50, self).__init__()
        # Stem: 7x7 conv followed by a 3x3 max-pool
        self.conv1 = nn.Sequential()
        self.conv1.add_module('conv', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))
        if not preact:
            self.conv1.add_module('bn', nn.BatchNorm2d(64))
            self.conv1.add_module('relu', nn.ReLU())
        self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        # Four stages of stacked residual blocks: 3, 4, 6, 3 blocks
        self.conv2 = Stack(64, 128, 3, stride=1)
        self.conv3 = Stack(256, 256, 4, stride=2)
        self.conv4 = Stack(512, 512, 6, stride=2)
        self.conv5 = Stack(1024, 1024, 3, stride=2)
        # Head: pooling and (optionally) the classifier
        self.post = nn.Sequential()
        if preact:
            self.post.add_module('bn', nn.BatchNorm2d(2048))
            self.post.add_module('relu', nn.ReLU())
        if include_top:
            self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            self.post.add_module('flatten', nn.Flatten())
            self.post.add_module('fc', nn.Linear(2048, classes))
        else:
            if pooling == 'avg':
                self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            elif pooling == 'max':
                self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.post(x)
        return x
```
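Before wiring the model into training, a dummy forward pass (a sketch; the 4 classes here stand in for the monkeypox dataset's class count) verifies that the network produces one logit per class:

```python
import torch

net = ResNeXt50(classes=4)
out = net(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 4])
```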
5. Inspecting the model summary
```python
import torchsummary

''' Instantiate the model and move it to the GPU (all of our model runs happen on the GPU) '''
model = ResNeXt50(classes=num_classes).to(device)

''' Display the network structure '''
torchsummary.summary(model, (3, 224, 224))
#torchinfo.summary(model)
print(model)
```
```
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,472
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 256, 56, 56] 16,384
Conv2d-6 [-1, 128, 56, 56] 8,192
BatchNorm2d-7 [-1, 128, 56, 56] 256
ReLU-8 [-1, 128, 56, 56] 0
Conv2d-9 [-1, 4, 56, 56] 144
Conv2d-10 [-1, 4, 56, 56] 144
Conv2d-11 [-1, 4, 56, 56] 144
Conv2d-12 [-1, 4, 56, 56] 144
Conv2d-13 [-1, 4, 56, 56] 144
Conv2d-14 [-1, 4, 56, 56] 144
Conv2d-15 [-1, 4, 56, 56] 144
Conv2d-16 [-1, 4, 56, 56] 144
Conv2d-17 [-1, 4, 56, 56] 144
Conv2d-18 [-1, 4, 56, 56] 144
Conv2d-19 [-1, 4, 56, 56] 144
Conv2d-20 [-1, 4, 56, 56] 144
Conv2d-21 [-1, 4, 56, 56] 144
Conv2d-22 [-1, 4, 56, 56] 144
Conv2d-23 [-1, 4, 56, 56] 144
Conv2d-24 [-1, 4, 56, 56] 144
Conv2d-25 [-1, 4, 56, 56] 144
Conv2d-26 [-1, 4, 56, 56] 144
Conv2d-27 [-1, 4, 56, 56] 144
Conv2d-28 [-1, 4, 56, 56] 144
Conv2d-29 [-1, 4, 56, 56] 144
Conv2d-30 [-1, 4, 56, 56] 144
Conv2d-31 [-1, 4, 56, 56] 144
Conv2d-32 [-1, 4, 56, 56] 144
Conv2d-33 [-1, 4, 56, 56] 144
Conv2d-34 [-1, 4, 56, 56] 144
Conv2d-35 [-1, 4, 56, 56] 144
Conv2d-36 [-1, 4, 56, 56] 144
Conv2d-37 [-1, 4, 56, 56] 144
Conv2d-38 [-1, 4, 56, 56] 144
Conv2d-39 [-1, 4, 56, 56] 144
Conv2d-40 [-1, 4, 56, 56] 144
BatchNorm2d-41 [-1, 128, 56, 56] 256
ReLU-42 [-1, 128, 56, 56] 0
GroupedConvBlock-43 [-1, 128, 56, 56] 0
Conv2d-44 [-1, 256, 56, 56] 32,768
BatchNorm2d-45 [-1, 256, 56, 56] 512
ReLU-46 [-1, 256, 56, 56] 0
Block-47 [-1, 256, 56, 56] 0
Identity-48 [-1, 256, 56, 56] 0
Conv2d-49 [-1, 128, 56, 56] 32,768
BatchNorm2d-50 [-1, 128, 56, 56] 256
ReLU-51 [-1, 128, 56, 56] 0
Conv2d-52 [-1, 4, 56, 56] 144
Conv2d-53 [-1, 4, 56, 56] 144
Conv2d-54 [-1, 4, 56, 56] 144
Conv2d-55 [-1, 4, 56, 56] 144
Conv2d-56 [-1, 4, 56, 56] 144
Conv2d-57 [-1, 4, 56, 56] 144
Conv2d-58 [-1, 4, 56, 56] 144
Conv2d-59 [-1, 4, 56, 56] 144
Conv2d-60 [-1, 4, 56, 56] 144
Conv2d-61 [-1, 4, 56, 56] 144
Conv2d-62 [-1, 4, 56, 56] 144
Conv2d-63 [-1, 4, 56, 56] 144
Conv2d-64 [-1, 4, 56, 56] 144
Conv2d-65 [-1, 4, 56, 56] 144
Conv2d-66 [-1, 4, 56, 56] 144
Conv2d-67 [-1, 4, 56, 56] 144
Conv2d-68 [-1, 4, 56, 56] 144
Conv2d-69 [-1, 4, 56, 56] 144
Conv2d-70 [-1, 4, 56, 56] 144
Conv2d-71 [-1, 4, 56, 56] 144
Conv2d-72 [-1, 4, 56, 56] 144
Conv2d-73 [-1, 4, 56, 56] 144
Conv2d-74 [-1, 4, 56, 56] 144
Conv2d-75 [-1, 4, 56, 56] 144
Conv2d-76 [-1, 4, 56, 56] 144
Conv2d-77 [-1, 4, 56, 56] 144
Conv2d-78 [-1, 4, 56, 56] 144
Conv2d-79 [-1, 4, 56, 56] 144
Conv2d-80 [-1, 4, 56, 56] 144
Conv2d-81 [-1, 4, 56, 56] 144
Conv2d-82 [-1, 4, 56, 56] 144
Conv2d-83 [-1, 4, 56, 56] 144
BatchNorm2d-84 [-1, 128, 56, 56] 256
ReLU-85 [-1, 128, 56, 56] 0
GroupedConvBlock-86 [-1, 128, 56, 56] 0
Conv2d-87 [-1, 256, 56, 56] 32,768
BatchNorm2d-88 [-1, 256, 56, 56] 512
ReLU-89 [-1, 256, 56, 56] 0
Block-90 [-1, 256, 56, 56] 0
Identity-91 [-1, 256, 56, 56] 0
Conv2d-92 [-1, 128, 56, 56] 32,768
BatchNorm2d-93 [-1, 128, 56, 56] 256
ReLU-94 [-1, 128, 56, 56] 0
Conv2d-95 [-1, 4, 56, 56] 144
Conv2d-96 [-1, 4, 56, 56] 144
Conv2d-97 [-1, 4, 56, 56] 144
Conv2d-98 [-1, 4, 56, 56] 144
Conv2d-99 [-1, 4, 56, 56] 144
Conv2d-100 [-1, 4, 56, 56] 144
Conv2d-101 [-1, 4, 56, 56] 144
Conv2d-102 [-1, 4, 56, 56] 144
Conv2d-103 [-1, 4, 56, 56] 144
Conv2d-104 [-1, 4, 56, 56] 144
Conv2d-105 [-1, 4, 56, 56] 144
Conv2d-106 [-1, 4, 56, 56] 144
Conv2d-107 [-1, 4, 56, 56] 144
Conv2d-108 [-1, 4, 56, 56] 144
```