解读create_model方法
1. num_classes：分类数（包含背景类）
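2.backbone = resnet50_fpn_backbone()-》构建 ResNet50 + FPN 的特征提取网络（backbone）
model = FasterRCNN(backbone=backbone, num_classes=91)-》实例化 faster_rcnn_framework 中的 FasterRCNN 类，传入上面的 backbone，并将分类数 num_classes 设为 91，以便与随后加载的 COCO 预训练权重（fasterrcnn_resnet50_fpn_coco.pth）的类别数保持一致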
3.weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth")
missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)-》加载预训练的权重文件；strict=False 表示允许权重文件与模型中的参数键不完全匹配，模型中有而权重里没有的键返回到 missing_keys，权重里有而模型中没有的键返回到 unexpected_keys
4. if len(missing_keys) != 0 or len(unexpected_keys) != 0:
print("missing_keys: ", missing_keys)
print("unexpected_keys: ", unexpected_keys)-》若存在缺失或多余的参数键，则将其打印出来，便于检查预训练权重是否按预期加载
5.in_features = model.roi_heads.box_predictor.cls_score.in_features-》获取原预测头中分类层 cls_score 的输入特征维度 in_features
6.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)-》用新建的预测头 FastRCNNPredictor 替换预训练模型中原有的预测头，使输出类别数与当前数据集的 num_classes 一致
解读main方法:
1.device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")-》选择运行设备：若有可用的 GPU，则使用命令行参数 parser_data.device 指定的设备，否则使用 CPU
2.data_transform = {-》图像预处理方法：分别为训练集（"train"）和验证集（"val"）定义变换
"train": transforms.Compose([transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5)]),-》以 0.5 的概率随机水平翻转，作为训练集的数据增强（验证集只做 ToTensor 转换）
"val": transforms.Compose([transforms.ToTensor()])
}
3.VOC_root = parser_data.data_path-》数据集根目录
4.train_data_set = VOC2012DataSet(VOC_root, data_transform["train"], True)-》使用 VOC2012DataSet 类构建训练数据集（第三个参数传 True 表示训练集，下面构建验证集时传 False）
5.使用 DataLoader 加载训练数据集（注意：这里 shuffle=False，训练时通常应设为 True 以打乱样本顺序）
train_data_loader = torch.utils.data.DataLoader(train_data_set,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=utils.collate_fn)
val_data_set = VOC2012DataSet(VOC_root, data_transform["val"], False)
val_data_set_loader = torch.utils.data.DataLoader(val_data_set,
batch_size=2,
shuffle=False,
num_workers=0,
collate_fn=utils.collate_fn) -》构建验证数据集，并使用 DataLoader 加载（batch_size=2，不打乱顺序）
6. model = create_model(num_classes=21)-》创建模型，分类数为 21（VOC 的 20 个目标类别 + 1 个背景类）
7.model.to(device)-》将模型移动到 device 指定的设备上（GPU 或 CPU）