深入理解ResNet残差网络:从理论到实践
d2l-zh 项目地址: https://gitcode.com/gh_mirrors/d2l/d2l-zh
引言
在深度学习领域,随着网络深度的增加,模型性能往往会遇到瓶颈。传统的深度神经网络在层数增加到一定程度后,会出现"退化"现象:随着网络深度的增加,训练误差不降反升。这种现象并非由过拟合引起,而是因为深层网络难以优化。2015年,何恺明等人提出的ResNet(残差网络)通过引入"残差学习"的概念,成功解决了这一问题,并在ImageNet等多项视觉识别任务中取得了突破性成果。
残差网络的核心思想
函数类与网络表达能力
在理解ResNet之前,我们需要思考一个基本问题:为什么简单地增加网络深度会导致性能下降?
从数学角度看,我们可以将神经网络视为一个函数类$\mathcal{F}$,它包含了该网络架构(配合特定超参数)能够表示的所有函数。当我们设计更深层的网络时,实际上是扩展了这个函数类。理想情况下,更大的函数类应该包含更接近真实目标函数$f^*$的近似。
然而,关键在于函数类的嵌套关系。只有当更大的函数类完全包含较小的函数类时,增加网络深度才能保证模型表达能力严格增强。否则,更深的网络可能反而会偏离最优解。
残差学习原理
ResNet的核心创新在于提出了**残差块(Residual Block)**结构。传统网络层直接学习目标映射$H(x)$,而残差块则学习残差映射$F(x) = H(x) - x$。这样,原始映射可以表示为$H(x) = F(x) + x$。
这种设计的优势在于:
- 当最优映射接近恒等映射时,残差映射更容易学习(只需将权重趋近于零)
- 通过跳跃连接(shortcut connection)实现了信息的直接传播
- 缓解了梯度消失问题,使深层网络更容易训练
残差块实现详解
基本残差块结构
一个标准的残差块包含以下组件:
- 两个3×3卷积层,保持空间维度不变
- 每个卷积层后接批量归一化(BatchNorm)和ReLU激活
- 跳跃连接将输入直接加到第二个卷积层的输出上
- 最后的ReLU激活
当输入输出维度匹配时,跳跃连接可以直接相加;当需要改变维度时,需要通过1×1卷积调整通道数和空间分辨率。
代码实现
以下是残差块的关键实现代码(以PyTorch为例):
class Residual(nn.Module):
def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
super().__init__()
self.conv1 = nn.Conv2d(input_channels, num_channels,
kernel_size=3, padding=1, stride=strides)
self.conv2 = nn.Conv2d(num_channels, num_channels,
kernel_size=3, padding=1)
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels, num_channels,
kernel_size=1, stride=strides)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm2d(num_channels)
self.bn2 = nn.BatchNorm2d(num_channels)
def forward(self, X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3:
X = self.conv3(X)
Y += X
return F.relu(Y)
构建完整ResNet模型
网络架构
ResNet-18的整体架构如下:
-
初始层:
- 7×7卷积,64通道,步长2
- 3×3最大池化,步长2
- 批量归一化和ReLU激活
-
四个残差块阶段:
- 第一阶段:2个残差块,64通道
- 第二阶段:2个残差块,128通道(下采样)
- 第三阶段:2个残差块,256通道(下采样)
- 第四阶段:2个残差块,512通道(下采样)
-
全局平均池化和全连接层
模型实现
构建完整ResNet的关键在于正确处理各阶段的维度变化:
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
blk = []
for i in range(num_residuals):
if i == 0 and not first_block:
blk.append(Residual(input_channels, num_channels,
use_1x1conv=True, strides=2))
else:
blk.append(Residual(num_channels, num_channels))
return blk
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(), nn.Linear(512, 10))
模型训练与评估
在Fashion-MNIST数据集上训练ResNet-18,我们可以观察到:
- 相比传统深层网络,ResNet训练更稳定
- 即使网络深度增加,也不会出现性能退化
- 收敛速度较快,最终准确率较高
典型训练参数:
- 学习率:0.05
- 批量大小:256
- 训练周期:10
总结与展望
ResNet的创新之处在于:
- 通过残差学习解决了深层网络退化问题
- 结构简单且易于扩展(ResNet-18到ResNet-152)
- 为后续网络设计提供了新思路
残差连接的思想不仅适用于卷积网络,也被成功应用于循环神经网络、Transformer等架构。理解ResNet的工作原理对于掌握现代深度学习模型设计至关重要。
思考题
- 残差连接与传统的跳跃连接(如Inception中的连接)有何本质区别?
- 为什么在残差块中使用两个3×3卷积而不是一个更大的卷积核?
- 如何将残差思想应用到自然语言处理任务中?
- 当网络极深时(如ResNet-152),残差块的设计可能需要哪些调整?
通过深入理解ResNet的设计哲学,我们可以更好地应用和创新深度学习模型,解决更复杂的实际问题。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考