import torch
import torch.nn as nn
class BasicBlockV1b(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BasicBlockV1b, self).__init__()
# 定义了一个二维卷积用于提取特征
self.conv1 = nn.Conv2d(inplanes, planes, 3, stride,
dilation, dilation, bias=False)
# 定义批归一化层,用于加速训练过程并提高模型稳定性
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(planes, planes, 3, 1, previous_dilation,
dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(planes)
# 设置成下采样,用于调整残差块的输入与输出维度
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
# 检查是否有下采样操作,通常是卷积层和批归一化层的组合,用于调整identity的维度与out匹配
if self.downsample is not None:
identity = self.downsample(x)
# 将处理过的输出与跳跃连接输出相加,能减少梯度消失问题
out += identity
out = self.relu(out)
return out
# 定义更复杂的残差块
class BottleneckV1b(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BottleneckV1b, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = norm_layer(planes)
self.conv2 = nn.Conv2d(planes, planes,
深度学习——残差网络ResNet基础结构理解+注释
于 2024-01-06 23:15:30 首次发布