Walkthrough of the create_model method:
1. backbone = resnet50_fpn_backbone() -> by default, uses the resnet50_fpn_backbone function from resnet50_fpn_model
2. model = FasterRCNN(backbone=backbone, num_classes=num_classes) -> builds the model by calling FasterRCNN
3. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -> uses the GPU when CUDA is available, otherwise falls back to the CPU
print(device)
4. model = create_model(num_classes=21) -> creates the model with 21 classes (for PASCAL VOC, presumably the 20 object classes plus background)
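5. train_weights = "./save_weights/model.pth" -> loads the trained weights
# map_location lets weights saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
model.to(device)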
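The ["model"] lookup in step 5 suggests the checkpoint is a dict that stores the state dict under a "model" key. A minimal sketch of that save/load round trip follows. It assumes that layout, and it uses a small nn.Linear as a hypothetical stand-in for the detector and a temporary file instead of ./save_weights/model.pth.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for the detector; any nn.Module behaves the same way here.
model = nn.Linear(4, 2)

# Assumed checkpoint layout: a dict holding the state dict under the "model" key.
checkpoint = {"model": model.state_dict()}
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(checkpoint, path)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Rebuild the same architecture, then restore the weights onto the chosen device.
restored = nn.Linear(4, 2)
state = torch.load(path, map_location=device)
restored.load_state_dict(state["model"])
restored.to(device)
restored.eval()  # inference mode, as one would want before prediction
```

If the checkpoint had been saved as a bare state dict instead, the ["model"] lookup would raise a KeyError, so the loading code has to match how the weights were saved.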
6. category_index = {} -> loads the class JSON file and inverts it, so that the file's values become the keys of category_index
try:
    with open('./pascal_voc_classes.json', 'r') as json_file:
        class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
    print(e)
    exit(-1)
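The inversion in step 6 can be tried in isolation. The sketch below assumes the JSON file maps class names to integer ids (inferred from the {v: k ...} comprehension, not confirmed by these notes). The three-entry dict and the label_for helper are hypothetical, for illustration only.

```python
import json
import os
import tempfile

# Hypothetical three-class file standing in for pascal_voc_classes.json;
# the name -> id layout is an assumption inferred from the inversion.
sample_classes = {"aeroplane": 1, "bicycle": 2, "bird": 3}

path = os.path.join(tempfile.mkdtemp(), "classes.json")
with open(path, "w") as f:
    json.dump(sample_classes, f)

with open(path, "r") as f:
    class_dict = json.load(f)

# Invert name -> id into id -> name so predicted ids can be turned into labels.
category_index = {v: k for k, v in class_dict.items()}


def label_for(class_id):
    """Return the class name for a predicted id, or "unknown" if it is missing."""
    return category_index.get(class_id, "unknown")
```

With this mapping, an integer label predicted by the model can be converted to a readable name via label_for(class_id).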
7. original_img = Image.open("./test.jpg") -> loads the test image
8.data_t