智能数字图像处理之FastRCNN(pytorch)代码解读之train_resnet50_fpn.py

本文详细解读了使用PyTorch实现的智能数字图像处理中FastRCNN模型的训练过程。通过create_model方法创建模型,包括设置分类数、加载预训练权重,并替换分类头部。在main方法中,处理设备选择、数据预处理、加载数据集、定义优化器、学习率调度器,并根据用户参数进行训练和权重保存。代码还包含了训练过程的损失和mAP曲线绘制功能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

解读create_model方法

1. num_classes:分类数

2.backbone = resnet50_fpn_backbone()
    model = FasterRCNN(backbone=backbone, num_classes=91)-》调用faster_rcnn_framework的FasterRCNN方法,传入分类数num_classes为91

3.weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth")
    missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)-》加载预训练的权重文件

4.    if len(missing_keys) != 0 or len(unexpected_keys) != 0:
        print("missing_keys: ", missing_keys)
        print("unexpected_keys: ", unexpected_keys)-》打印预训练权重信息

5.in_features = model.roi_heads.box_predictor.cls_score.in_features-》获取分类器输入特征的数量

6.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)-》用一个新的头替换一个预先训练好的头

解读main方法:

1.device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")-》查看设备是否有可用的GPU

2.data_transform = {-》图像预处理函数
        "train": transforms.Compose([transforms.ToTensor(),
                                     transforms.RandomHorizontalFlip(0.5)]),-》随机水平翻转
        "val": transforms.Compose([transforms.ToTensor()])
    }

3.VOC_root = parser_data.data_path-》数据集根目录

4.train_data_set = VOC2012DataSet(VOC_root, data_transform["train"], True)-》加载数据集的VOC2012DataSet方法

5.加载训练数据集

train_data_loader = torch.utils.data.DataLoader(train_data_set,
                                                    batch_size=1,
                                                    shuffle=False,
                                                    num_workers=0,
                                                    collate_fn=utils.collate_fn)
    val_data_set = VOC2012DataSet(VOC_root, data_transform["val"], False)
    val_data_set_loader = torch.utils.data.DataLoader(val_data_set,
                                                      batch_size=2,
                                                      shuffle=False,
                                                      num_workers=0,
                                                      collate_fn=utils.collate_fn)    -》加载验证数据集

6. model = create_model(num_classes=21)-》创建模型分类数为20个类

7.model.to(device)-》加载模型到G

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值