目录
6、加载图像数据和对应的标签,并为训练和测试数据集创建了数据加载器
一、简单介绍残差神经网络
残差神经网络(Residual Neural Network,通常缩写为ResNet)是一种深度学习神经网络架构,用于解决深度神经网络训练中的梯度消失和梯度爆炸等问题。
ResNet引入了残差块(Residual Block),这是网络的基本构建单元。每个残差块包含两个主要部分:恒等映射(Identity Mapping)和残差映射(Residual Mapping)。
在一个传统的神经网络中,每个层都会通过权重和激活函数来变换输入。在ResNet中,每个残差块尝试学习的是将输入映射到期望的输出,而不是直接学习最终的映射。这是通过在残差块中引入跳跃连接(skip connection)来实现的。
跳跃连接允许网络直接将输入添加到层的输出,从而使残差块能够学习残差的变化。这意味着即使网络增加了更多的层,也不会导致梯度消失或梯度爆炸问题,因为梯度可以通过跳跃连接轻松地传播。这使得能够构建非常深的神经网络,如数百层或更多,而不会出现训练问题。
ResNet在图像分类、目标检测、语义分割等计算机视觉任务中取得了巨大成功,成为了深度学习领域的重要突破之一。这个网络架构的核心思想是通过残差连接来让网络更深,从而提高了模型性能。
二、代码实现
1、导入库
from torch.utils.data import Dataset, DataLoader
#用于创建自定义数据集和数据加载器。
import torch
from torch import nn
#PyTorch核心库,用于深度学习模型的构建和训练。
import numpy as np
#用于处理数据的数值计算库。
from torchvision import transforms
#PyTorch中用于数据增强和预处理的工具。
from PIL import Image
#Python Imaging Library,用于图像处理。
import torchvision.models as models
#包含了许多预训练的深度学习模型,包括ResNet。
2、创建ResNet-101模型,并修改全连接层:
resnet_model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
#创建一个ResNet-101模型,并使用默认的预训练权重。
in_features = resnet_model.fc.in_features
#获取ResNet-101最后一个全连接层的输入特征数。
resnet_model.fc = nn.Linear(in_features,20)
#将ResNet-101的全连接层替换为一个新的线性层,输出维度为20。
3、确定需要更新的参数
params_to_update = []
for param in resnet_model.parameters():
if param.requires_grad == True:
params_to_update.append(param)
#使用一个for循环,遍历ResNet-101模型的所有参数,并将 requires_grad 属性为True的参数添加到 params_to_update 列表中。
4、数据转换定义
data_transforms = {
#data_transfo

本文介绍了残差神经网络(ResNet)的工作原理,如何通过残差块和跳跃连接解决深度学习中的梯度问题。随后详细展示了如何在PyTorch中实现ResNet-101模型,包括数据预处理、模型构建、参数选择、损失函数和优化器设定,以及训练和测试过程。
最低0.47元/天 解锁文章
1万+





