# R-CNN
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
class RCNN(nn.Module):
    """ResNet-50 backbone feeding a classification head and a 4-value regression head.

    The backbone's final fully connected layer is replaced by a projection to
    a 512-dimensional feature vector; two independent linear heads then read
    that vector.
    """

    def __init__(self, num_classes):
        """Build the backbone and both heads.

        Args:
            num_classes: Number of outputs produced by the classification head.
        """
        super().__init__()

        # `pretrained=True` is deprecated since torchvision 0.13 in favour of
        # the `weights=` argument. Use the new API when it exists and fall
        # back to the legacy flag on older releases; both select the same
        # ImageNet V1 checkpoint.
        weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
        if weights_enum is not None:
            self.backbone = torchvision.models.resnet50(
                weights=weights_enum.IMAGENET1K_V1
            )
        else:
            self.backbone = torchvision.models.resnet50(pretrained=True)

        # Swap the stock ImageNet classifier for a 512-d projection. The input
        # width is read from the layer being replaced instead of hard-coding it.
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 512)

        self.classifier = nn.Linear(512, num_classes)
        # NOTE: the attribute name (spelling included) is kept as-is because
        # it determines the state_dict keys of saved checkpoints.
        self.regessor = nn.Linear(512, 4)

    def forward(self, x):
        """Run the backbone and both heads.

        Args:
            x: Input batch handed directly to the ResNet-50 backbone
                (presumably float images shaped (N, 3, H, W) -- confirm
                against callers).

        Returns:
            A tuple ``(cls_output, reg_output)`` with shapes
            ``(N, num_classes)`` and ``(N, 4)``. The four regression values
            are presumably bounding-box coordinates or offsets -- confirm
            against the training targets.
        """
        features = self.backbone(x)
        # Collapse everything after the batch dimension so the heads always
        # receive a 2-D (N, features) tensor.
        features = features.view(features.size(0), -1)
        cls_output = self.classifier(features)
        reg_output = self.regessor(features)
        return cls_output, reg_output
class RCNNLoss(nn.Module):
    """Combined loss: cross-entropy on class scores plus Smooth-L1 on regression outputs.

    The total is ``cls_loss + reg_weight * reg_loss``. With the default
    ``reg_weight`` of 1.0 this is the plain sum of the two terms.
    """

    def __init__(self, reg_weight=1.0):
        """Create the two component criteria.

        Args:
            reg_weight: Multiplier applied to the regression term before it
                is added to the classification term. Defaults to 1.0.
        """
        super().__init__()
        self.cls_loss = nn.CrossEntropyLoss()
        self.reg_loss = nn.SmoothL1Loss()
        self.reg_weight = reg_weight

    def forward(self, cls_output, reg_output, cls_target, reg_target):
        """Compute the combined loss.

        Args:
            cls_output: Class scores, passed to ``nn.CrossEntropyLoss`` as input.
            reg_output: Regression predictions, passed to ``nn.SmoothL1Loss``
                as input.
            cls_target: Classification targets for ``nn.CrossEntropyLoss``.
            reg_target: Regression targets for ``nn.SmoothL1Loss``.

        Returns:
            The classification loss plus ``reg_weight`` times the regression
            loss, as a single tensor.
        """
        cls_loss = self.cls_loss(cls_output, cls_target)
        reg_loss = self.reg_loss(reg_output, reg_target)
        return cls_loss + self.reg_weight * reg_loss
if __name__ == "__main__":
num_classes = 10
model = RCNN(num_classes)
criterion = RCNNLoss()
img = np.random.randn(3, 224, 224)
img_tensor = torch.tensor(img).unsqueeze(0)
cls_target = torch.randint(0, num_classes, (1,)