GoogLeNet network structure:

1. First, define a basic convolution module consisting of a convolution layer, a ReLU activation layer, and a forward-propagation function.
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)  # convolution
        x = self.relu(x)  # ReLU activation
        return x
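in_channels -> depth (number of channels) of the input feature map
out_channels -> depth (number of channels) of the output feature map. The out_channels passed to self.conv = nn.Conv2d() is also the number of convolution kernels.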
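To make the link between out_channels and the number of kernels concrete, here is a minimal sketch in plain Python (PyTorch is not needed) that counts the learnable parameters of one such convolution layer. It assumes the default nn.Conv2d settings groups=1 and bias=True. The helper name conv2d_param_count is hypothetical and exists only for this illustration.

```python
def conv2d_param_count(in_channels, out_channels, kernel_size, bias=True):
    """Count learnable parameters of a standard (groups=1) 2D convolution.

    Hypothetical helper for illustration; it is not part of PyTorch.
    """
    # Each of the out_channels kernels spans every input channel,
    # so one kernel holds in_channels * k * k weights.
    weights_per_kernel = in_channels * kernel_size * kernel_size
    total = out_channels * weights_per_kernel
    if bias:
        total += out_channels  # one bias term per kernel
    return total


# Example values chosen for illustration: 192 input channels, 64 kernels of size 1x1.
print(conv2d_param_count(192, 64, kernel_size=1))  # 192 * 64 + 64 = 12352
```

Doubling out_channels doubles the number of kernels and therefore the parameter count, while the spatial size of the output is unaffected.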
2. Define the Inception module:
class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)  # keeps the output size equal to the input size
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)  # keeps the output size equal to the input size
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)  # concatenate along the channel dimension
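ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, and pool_proj each denote the output depth, i.e. the number of convolution kernels, of the corresponding layer.
self.branch1 is the first branch (a 1x1 convolution).
self.branch2 is the second branch (a 1x1 reduction followed by a 3x3 convolution).
self.branch3 is the third branch (a 1x1 reduction followed by a 5x5 convolution).
self.branch4 is the fourth branch (3x3 max pooling followed by a 1x1 convolution).
Note: in the Inception module's forward function, branch1, branch2, branch3, and branch4 are connected in parallel, not in series. The input x is passed to each branch separately, and the results are assigned to the variables branch1, branch2, branch3, and branch4. Finally, torch.cat() merges them into a single tensor; the argument 1 means the tensors are concatenated along the depth (channel) dimension.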
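Concatenating along depth only works if all four branches produce the same height and width, which is what the padding values in the Inception class are for. The sketch below checks this in plain Python using the standard output-size formula for convolution and pooling layers (floor mode, dilation 1). The helper names out_size and inception_output_shape are hypothetical, and the channel numbers in the example are illustrative values, not taken from the text above.

```python
def out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a conv/pool layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def inception_output_shape(height, width, ch1x1, ch3x3, ch5x5, pool_proj):
    """Shape (channels, height, width) after the four branches are concatenated.

    Hypothetical helper: only the last layer of each branch decides that
    branch's output depth, so the *red arguments do not appear here.
    """
    # (kernel_size, padding) of the size-determining layer in each branch:
    # 1x1 conv, 3x3 conv, 5x5 conv, 3x3 max pool (all with stride 1).
    branch_layers = [(1, 0), (3, 1), (5, 2), (3, 1)]
    sizes = {
        (out_size(height, k, 1, p), out_size(width, k, 1, p))
        for k, p in branch_layers
    }
    # Every branch must keep the same spatial size, or the concat would fail.
    assert len(sizes) == 1, "branches disagree on spatial size"
    h, w = sizes.pop()
    # Concatenation along dim 1 adds up the branch depths.
    return (ch1x1 + ch3x3 + ch5x5 + pool_proj, h, w)


# Illustrative values: a 28x28 input with branch depths 64, 128, 32, 32.
print(inception_output_shape(28, 28, 64, 128, 32, 32))  # (256, 28, 28)
```

The output depth is simply the sum of the four branch depths, while the spatial size passes through unchanged.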
3. Define the auxiliary classifier module:
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5,
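The shape comments in InceptionAux, and the 2048 input features of fc1, can be checked by hand. The following plain-Python sketch traces the spatial size through the average pooling layer and then computes the flattened length. It assumes the 14 x 14 input stated in the code comments and pooling with no padding in floor mode; the helper names pool_out_size and aux_flatten_features are hypothetical.

```python
def pool_out_size(size, kernel_size, stride, padding=0):
    """Spatial output size of a pooling layer (floor mode)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def aux_flatten_features(input_side=14, conv_channels=128):
    """Trace the auxiliary classifier up to torch.flatten.

    Hypothetical helper: returns (side after pooling, flattened length).
    """
    # AvgPool2d(kernel_size=5, stride=3): (14 - 5) // 3 + 1 = 4
    side = pool_out_size(input_side, kernel_size=5, stride=3)
    # The 1x1 convolution changes only the depth (to 128), not the size,
    # so flattening yields 128 * 4 * 4 values per sample.
    return side, conv_channels * side * side


side, features = aux_flatten_features()
print(side, features)  # 4 2048
```

The flattened length of 2048 matches the in_features of nn.Linear(2048, 1024), and it is the same for both auxiliary heads because the 1x1 convolution maps either 512 or 528 input channels to 128.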

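This article walks through the GoogLeNet network structure in detail, covering the design of the basic convolution module, the Inception module, and the classifier. The Inception module consists of several branches that process the features in parallel before the results are merged. The article also introduces the role of the auxiliary classifier during training. The complete code is available at the GitHub link provided.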