import torch
import torch.nn as nn
import torchvision.models as models
class VGGBase(nn.Module):
    """VGG16 convolutional trunk used as the SSD backbone.

    Keeps every child of torchvision's ``vgg16().features`` except the very
    last one. In torchvision's standard VGG16 layout that last child is the
    final 2x2 max-pool, so the output is the ReLU-activated conv5_3 map.
    This is a 512-channel map, matching what ``ExtraLayers.conv6`` and the
    first prediction head in this file consume. Assuming the standard
    layout, four stride-2 pools remain, so a 300x300 input should give an
    18x18 map.

    Args:
        pretrained: load ImageNet weights (downloaded on first use).
            Defaults to True, which was the previously hard-coded behaviour.
    """

    def __init__(self, pretrained=True):
        super(VGGBase, self).__init__()
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained`` in favour of ``weights``.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            vgg = models.vgg16(weights=weights)
        else:
            # Older torchvision only understands the boolean flag.
            vgg = models.vgg16(pretrained=pretrained)
        # Drop only the final max-pool. The former ``[:-2]`` slice (commented
        # as removing "the last two pooling layers") actually removed the last
        # ReLU *and* the last pool, leaving conv5_3's output un-rectified.
        # The ReLU has no parameters, so state_dict keys are unchanged.
        self.features = nn.Sequential(*list(vgg.features.children())[:-1])

    def forward(self, x):
        """Run the truncated VGG16 trunk on ``x`` (an NCHW image batch)."""
        return self.features(x)
class ExtraLayers(nn.Module):
    """Extra convolutional stages applied on top of the backbone feature map.

    The input must have 512 channels (``conv6``'s ``in_channels``).
    ``conv6`` is a 3x3 conv with dilation 6 and padding 6, and ``conv7`` is
    a 1x1 conv, so both keep the input's height and width. ``conv8_2`` and
    ``conv9_2`` are stride-2 3x3 convs with padding 1, each mapping a side
    length H to ceil(H / 2). Every conv is followed by a ReLU.
    """

    def __init__(self):
        super(ExtraLayers, self).__init__()
        self.conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
        self.conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
        self.conv8_1 = nn.Conv2d(1024, 256, kernel_size=1)
        self.conv8_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.conv9_1 = nn.Conv2d(512, 128, kernel_size=1)
        self.conv9_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

    def forward(self, x, return_intermediates=False):
        """Apply the extra stages.

        Args:
            x: feature map of shape (N, 512, H, W).
            return_intermediates: when False (default), return only the
                final conv9_2 map, which was the original single-tensor
                behaviour. When True, return the list
                ``[conv7, conv8_2, conv9_2]`` (1024, 512 and 256 channels),
                i.e. every map the SSD prediction heads consume.
        """
        x = torch.relu(self.conv6(x))
        conv7_out = torch.relu(self.conv7(x))
        x = torch.relu(self.conv8_1(conv7_out))
        conv8_out = torch.relu(self.conv8_2(x))
        x = torch.relu(self.conv9_1(conv8_out))
        conv9_out = torch.relu(self.conv9_2(x))
        if return_intermediates:
            return [conv7_out, conv8_out, conv9_out]
        return conv9_out


class PredictionLayers(nn.Module):
    """Box-regression and classification heads, one pair per feature map.

    Args:
        num_classes: number of scores predicted per default box.
        in_channels: channel count of each feature map, in the order the
            maps are passed to ``forward``.
        num_anchors: number of default boxes per spatial location for each
            feature map. Must be the same length as ``in_channels``.

    Raises:
        ValueError: if ``in_channels`` and ``num_anchors`` differ in length.
    """

    def __init__(self, num_classes, in_channels=(512, 1024, 512, 256),
                 num_anchors=(4, 6, 6, 6)):
        super(PredictionLayers, self).__init__()
        if len(in_channels) != len(num_anchors):
            raise ValueError(
                f"in_channels ({len(in_channels)}) and num_anchors "
                f"({len(num_anchors)}) must have the same length"
            )
        self.num_classes = num_classes
        self.loc_layers = nn.ModuleList()
        self.conf_layers = nn.ModuleList()
        # Interleave loc/conf creation exactly as before, so that parameter
        # initialisation consumes the RNG in the same order.
        for channels, anchors in zip(in_channels, num_anchors):
            # 4 box offsets per anchor.
            self.loc_layers.append(
                nn.Conv2d(channels, anchors * 4, kernel_size=3, padding=1))
            # num_classes scores per anchor.
            self.conf_layers.append(
                nn.Conv2d(channels, anchors * num_classes, kernel_size=3, padding=1))

    def forward(self, features):
        """Predict box offsets and class scores for every default box.

        Args:
            features: sequence of NCHW feature maps, one per head, in the
                same order as ``in_channels``.

        Returns:
            ``(loc_preds, conf_preds)`` of shapes ``(N, P, 4)`` and
            ``(N, P, num_classes)``. ``P`` is the sum over maps of
            ``H * W * anchors``.

        Raises:
            ValueError: if the number of feature maps differs from the number
                of heads. Previously ``zip`` silently dropped the extras.
        """
        features = list(features)
        if len(features) != len(self.loc_layers):
            raise ValueError(
                f"expected {len(self.loc_layers)} feature maps, got {len(features)}"
            )
        loc_preds = []
        conf_preds = []
        for x, loc_head, conf_head in zip(features, self.loc_layers, self.conf_layers):
            batch = x.size(0)
            # (N, A*k, H, W) -> (N, H, W, A*k) -> (N, H*W*A, k)
            loc_preds.append(
                loc_head(x).permute(0, 2, 3, 1).contiguous().view(batch, -1, 4))
            conf_preds.append(
                conf_head(x).permute(0, 2, 3, 1).contiguous()
                .view(batch, -1, self.num_classes))
        return torch.cat(loc_preds, 1), torch.cat(conf_preds, 1)


class SSD(nn.Module):
    """SSD-style detector: VGG backbone, extra stages, prediction heads.

    Args:
        num_classes: number of scores predicted per default box.
    """

    def __init__(self, num_classes):
        super(SSD, self).__init__()
        self.num_classes = num_classes
        self.base_net = VGGBase()
        self.extra_layers = ExtraLayers()
        self.prediction_layers = PredictionLayers(num_classes)

    def forward(self, x):
        """Return ``(loc_preds, conf_preds)`` for the image batch ``x``.

        All four feature maps are fed to the heads: the backbone output plus
        the conv7, conv8_2 and conv9_2 outputs of the extra stages. Before,
        only the backbone output and the final extra map were passed. That
        paired a 256-channel map with the 1024-channel head and crashed.
        """
        base_features = self.base_net(x)
        extra_features = self.extra_layers(base_features, return_intermediates=True)
        features = [base_features, *extra_features]
        return self.prediction_layers(features)
# Build the detector: 20 object classes plus 1 background class.
num_classes = 21
model = SSD(num_classes)
model.eval()

# Dummy input batch: 8 three-channel images of 300x300 pixels.
images = torch.randn(8, 3, 300, 300)

# Forward pass with autograd disabled.
with torch.no_grad():
    loc_preds, conf_preds = model(images)

# Print the two output shapes, one per line.
print(loc_preds.shape, conf_preds.shape, sep="\n")
# SSD单发多框检测—基础结构 (SSD: Single Shot MultiBox Detector — basic structure)
# 于 2024-05-17 22:39:06 首次发布 (first published 2024-05-17 22:39:06)