resnet50v2
前言
前面对resnet有了一定了解,主要关注其残差结构,其缓解了梯度消失的问题,是卷积神经网络的一大创举,这周主要学习resnet50v2,其是对resnet的改进。
resnet50v2简介
ResNet50V2 是对 ResNet50 的改进版本,它在保留了 ResNet 架构的基本特征的同时,进行了一些改动和优化,以提升模型性能和训练效率。以下是 ResNet50V2 相对于 ResNet50 的一些主要差别:
预激活:ResNet50V2 使用了预激活(pre-activation)结构,在每个残差块的卷积层之前先应用了批归一化和 ReLU 激活函数,而 ResNet50 没有采用这种结构。
预激活结构使得梯度在网络中传播更加稳定,有助于加速训练过程,并提高了模型的性能。
残差块改进:ResNet50V2 中的残差块进行了改进,使用了更多的 Batch Normalization 层来进一步增强模型的稳定性和收敛速度。另外,ResNet50V2 中的残差连接的实现方式与 ResNet50 稍有不同,具体体现在残差连接中的恒等映射和卷积操作之间的位置。
网络参数数量:ResNet50V2 中可能会有更多的参数,这是因为它使用了更多的 Batch Normalization 层和预激活结构,从而增加了网络的深度和复杂度。
性能:由于预激活结构和改进的残差块设计,ResNet50V2 往往比 ResNet50 具有更好的性能,包括更快的收敛速度和更高的准确率。
在一些复杂的图像分类任务中,ResNet50V2 可能会比 ResNet50 更适用。
代码实现
class Block2(nn.Module):
def __init__(self, in_channel, filters, kernel_size=3, stride=1, conv_shortcut=False):
super(Block2, self).__init__()
self.preact = nn.Sequential(
nn.BatchNorm2d(in_channel),
nn.ReLU(True)
)
self.shortcut = conv_shortcut
if self.shortcut:
self.short = nn.Conv2d(in_channel, 4 * filters, 1, stride=stride, padding=0, bias=False)
elif stride > 1:
self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)
else:
self.short = nn.Identity()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channel, filters, 1, stride=1, bias=False),
nn.BatchNorm2d(filters),
nn.ReLU(True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(filters, filters, kernel_size, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(filters),
nn.ReLU(True)
)
self.conv3 = nn.Conv2d(filters, 4 * filters, 1, stride=1, bias=False)
def forward(self, x):
x1 = self.preact(x)
if self.shortcut:
x2 = self.short(x1)
else:
x2 = self.short(x)
x1 = self.conv1(x1)
x1 = self.conv2(x1)
x1 = self.conv3(x1)
x = x1 + x2
return x
class Stack2(nn.Module):
def __init__(self, in_channel, filters, blocks, stride=2):
super(Stack2, self).__init__()
self.conv = nn.Sequential()
self.conv.add_module(str(0), Block2(in_channel, filters, conv_shortcut=True))
for i in range(1, blocks - 1):
self.conv.add_module(str(i), Block2(4 * filters, filters))
self.conv.add_module(str(blocks - 1), Block2(4 * filters, filters, stride=stride))
def forward(self, x):
x = self.conv(x)
return x
''' 构建ResNet50V2 '''
class ResNet50V2(nn.Module):
def __init__(self,
include_top=True, # 是否包含位于网络顶部的全链接层
preact=True, # 是否使用预激活
use_bias=True, # 是否对卷积层使用偏置
input_shape=[224, 224, 3],
classes=4,
pooling=None): # 用于分类图像的可选类数
super(ResNet50V2, self).__init__()
self.conv1 = nn.Sequential()
self.conv1.add_module('conv', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))
if not preact:
self.conv1.add_module('bn', nn.BatchNorm2d(64))
self.conv1.add_module('relu', nn.ReLU())
self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
self.conv2 = Stack2(64, 64, 3)
self.conv3 = Stack2(256, 128, 4)
self.conv4 = Stack2(512, 256, 6)
self.conv5 = Stack2(1024, 512, 3, stride=1)
self.post = nn.Sequential()
if preact:
self.post.add_module('bn', nn.BatchNorm2d(2048))
self.post.add_module('relu', nn.ReLU())
if include_top:
self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
self.post.add_module('flatten', nn.Flatten())
self.post.add_module('fc', nn.Linear(2048, classes))
else:
if pooling == 'avg':
self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
elif pooling == 'max':
self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = self.post(x)
return x
模型验证
使用鸟类数据集
for epoch in range(epochs):
# 更新学习率(使用自定义学习率时使用)
# adjust_learning_rate(optimizer, epoch, learn_rate)
model.train()
epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
# scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)
model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
# 保存最佳模型到 best_model
if epoch_test_acc > best_acc:
best_acc = epoch_test_acc
best_model = copy.deepcopy(model)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
# 获取当前的学习率
lr = optimizer.state_dict()['param_groups'][0]['lr']
template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
epoch_test_acc * 100, epoch_test_loss, lr))
结果
总结
resnet50v2是对resnet的改进,但是个人感觉resnet更为简洁,不过熟悉改造模型的过程也有助于加深理解和给自己以后改进模型提供一些启示。