目录
一、背景简介
ResNet152是深度残差网络(Deep Residual Network)的一种,它是一个非常强大的图像分类模型。该网络由微软研究院提出,其核心思想是通过引入残差模块和瓶颈结构,使得模型可以在更深的层次上有效地学习图像特征,从而避免优化函数陷入局部最优解和梯度消失的问题。基本原理是通过引入残差模块和瓶颈结构,使得模型可以在更深的层次上有效地学习图像特征,从而避免优化函数陷入局部最优解和梯度消失的问题。
ResNet的设计思路是将输入特征通过一系列的卷积层、池化层等操作后,再将其与原始输入特征进行求和,这样就可以保留更多的原始信息,避免信息在多层网络中传递时被丢失。这种残差连接的设计使得网络在训练时可以跳过一些不必要的卷积操作,从而减少计算量和模型大小,同时提高模型的性能。
二、ResNet152残差网络应用实践
ResNet152残差网络在应用实践中表现出色,被广泛应用于各种计算机视觉任务中,如图像分类、目标检测、语义分割等,本次将以图像分类方面为案例进行实践操作演示。
1、定义ResNet152模型
resnet_model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT) #创建ResNet152模型实例
for param in resnet_model.parameters():#冻结模型参数,只剩全连接层
param.requires_grad = False
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features,20)
params_to_update = []
for param in resnet_model.parameters():#遍历ResNet152模型的所有参数
if param.requires_grad == True:
params_to_update.append(param)
具体参数详解如下:
resnet_model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT):这行代码创建了一个ResNet152模型实例。models.resnet152是PyTorch中预定义的ResNet152模型函数。weights=models.ResNet152_Weights.DEFAULT指定了使用默认的预训练权重。for param in resnet_model.parameters(): param.requires_grad = False:这行代码将ResNet152模型中所有参数的requires_grad属性设置为False。这意味着在反向传播(backpropagation)时,这些参数不会更新。这通常用于冻结模型的某些部分,以防止在训练过程中改变其参数。in_features = resnet_model.fc.in_features:这行代码获取ResNet152模型的最后全连接层的输入特征数量。resnet_model.fc = nn.Linear(in_features,20):这行代码将ResNet152模型的最后全连接层替换为一个新的全连接层,该全连接层的输入特征数量与原来的相同(in_features),输出特征数量为20。params_to_update = [] for param in resnet_model.parameters(): if param.requires_grad == True: params_to_update.append(param):这里创建了一个空列表params_to_update,然后遍历ResNet152模型的所有参数。对于每个需要更新的参数(即param.requires_grad == True),将其添加到params_to_update列表中。
通过上述过程创建了一个ResNet152模型,冻结了其所有参数(不更新这些参数),然后修改了最后的全连接层,以使其输出20维的向量,而不是原始的类别向量。同时,准备了一个列表,用于存储需要更新的参数,以便在训练过程中使用。
2、输入图像预处理
data_transforms = { #也可以使用PIL库,smote 人工拟合出来数据
'train':
transforms.Compose([
transforms.Resize([300,300]), #是图像变换大小
transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
transforms.CenterCrop(256),#从中心开始裁剪[256,256]
transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.1),#概率转换成灰度率,3通道就是R=G=B
transf
ResNet152残差网络图像分类实践

本文介绍了ResNet152残差网络,它能避免优化函数陷入局部最优和梯度消失问题。以图像分类为例进行实践,包括定义ResNet152模型、输入图像预处理、定义自定义数据集类、检测计算设备及定义优化器、模型训练等步骤,最后展示完整代码和结果。
最低0.47元/天 解锁文章
326

被折叠的 条评论
为什么被折叠?



