R-CNN、Fast R-CNN、Faster R-CNN ——简单实现

深度解析R-CNN系列:从基础到进阶的物体检测模型,

R-CNN

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

class RCNN(nn.Module):
    def __init__(self, num_classes):
        super(RCNN, self).__init__()
        # 使用预训练的 ResNet-50 作为基础网络
        self.backbone = torchvision.models.resnet50(pretrained=True)
        # 替换 ResNet-50 的最后一层全连接层
        self.backbone.fc = nn.Linear(2048, 512)
        # 分类器和回归器
        self.classifier = nn.Linear(512, num_classes)
        self.regessor = nn.Linear(512, 4)  # 4 维向量表示边界框坐标

    def forward(self, x):
        # 基础网络前向传播
        features = self.backbone(x)
        # flatten
        features = features.view(features.size(0), -1)
        # 分类和回归
        cls_output = self.classifier(features)
        reg_output = self.regessor(features)
        return cls_output, reg_output

# 定义损失函数
class RCNNLoss(nn.Module):
    def __init__(self):
        super(RCNNLoss, self).__init__()
        self.cls_loss = nn.CrossEntropyLoss()  # 分类损失
        self.reg_loss = nn.SmoothL1Loss()      # 回归损失

    def forward(self, cls_output, reg_output, cls_target, reg_target):
        # 分类损失
        cls_loss = self.cls_loss(cls_output, cls_target)
        # 回归损失
        reg_loss = self.reg_loss(reg_output, reg_target)
        # 总损失
        total_loss = cls_loss + reg_loss
        return total_loss

# 使用示例
if __name__ == "__main__":
    # 定义模型和损失函数
    num_classes = 10
    model = RCNN(num_classes)
    criterion = RCNNLoss()

    # 假设输入图像 img 是一个 numpy 数组
    img = np.random.randn(3, 224, 224)  # 示例输入图像
    img_tensor = torch.tensor(img).unsqueeze(0)  # 增加 batch 维度

    # 假设分类和回归的目标分别是 cls_target 和 reg_target
    cls_target = torch.randint(0, num_classes, (1,)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值