CenterNet PyTorch Reimplementation

Paper: Objects as Points

Paper link: https://arxiv.org/pdf/1904.07850.pdf

Official code: https://github.com/xingyizhou/CenterNet

Below is the code I wrote.

centerNet.py

 

import torch
from torch import nn
import torch.nn.functional as f
import torchvision.models as models
import numpy as np

"""
这个文件是centerNet的网络结构
"""


# Path to the pretrained ResNet-18 weights
BACKBONE = "G:/工作空间/预训练模型/resnet18-5c106cde.pth"

class SepConv(nn.Module):
    """Depthwise separable convolution: a depthwise conv followed by a 1x1 pointwise conv."""
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # depthwise convolution (one filter per input channel)
        self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size, stride, padding, groups=in_channel)
        # pointwise 1x1 convolution to mix channels and change the channel count
        self.conv2 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        return x


class CenterNet(nn.Module):
    # backbone: path to the pretrained backbone weights
    # class_num: number of classes (20 for the VOC dataset)
    def __init__(self, backbone=None, class_num=20):
        super(CenterNet, self).__init__()
        if backbone is None:
            self.Backbone = BACKBONE
        else:
            self.Backbone = backbone

        self.backbone = models.resnet18(pretrained=False)
        self.backbone.load_state_dict(torch.load(self.Backbone))
        self.softmax = nn.Softmax(dim=1)
        # [1,3,500,500] -> [1,256,32,32]
        self.stage1 = nn.Sequential(*list(self.backbone.children())[:-3])

        """
        # [1,64,125,125] -> [1,128,63,63]
        self.stage2 = nn.Sequential(list(backbone.children())[-5])
        # [1,128,63,63] -> [1,256,32,32]
        self.stage3 = nn.Sequential(list(backbone.children())[-4])
        """

        # 1x1 convolutions to reduce the channel count
        self.conv1 = nn.Conv2d(256, 128, kernel_size=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=1)

        batchNorm_momentum = 0.1
        # head trunk: four (separable conv + BN + ReLU) blocks;
        # each block is constructed separately so the four stages do not share weights
        # (reusing one nn.Sequential instance four times would make them share parameters)
        self.head = nn.Sequential(*[
            nn.Sequential(
                SepConv(64, 64, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(64, momentum=batchNorm_momentum),
                nn.ReLU(),
            )
            for _ in range(4)
        ])
        # classification heatmap head
        self.head_cls = nn.Conv2d(64, class_num, kernel_size=3, padding=1, stride=1)
        # center offset correction head
        self.head_offset = nn.Conv2d(64, 2, kernel_size=3, padding=1, stride=1)
        # box size regression head
        self.head_size = nn.Conv2d(64, 2, kernel_size=3, padding=1, stride=1)


    # Upsampling. mode defaults to "nearest"; mode="bilinear" produces a warning (about align_corners).
    # Note that F.interpolate takes the output size in (H, W) order.
    def upsampling(self, src, out_h, out_w, mode="nearest"):
        # example output shape: torch.Size([1, 256, 50, 64])
        return f.interpolate(src, size=[out_h, out_w], mode=mode)

    def forward(self, input):
        output = self.stage1(input)
        # reduce the channel count from 256 to 128
        output = self.conv1(output)
        # upsample to 1/8 of the input resolution (shape[2] is H, shape[3] is W)
        out_h = input.shape[2] // 8
        out_w = input.shape[3] // 8
        output = self.upsampling(output, out_h, out_w)
        # reduce the channel count from 128 to 64
        output = self.conv2(output)
        # upsample to 1/4 of the input resolution (the final output stride is 4)
        out_h = input.shape[2] // 4
        out_w = input.shape[3] // 4
        output = self.upsampling(output, out_h, out_w)
        output = self.head(output)
        # class prediction
        classes = self.head_cls(output)
        # offset prediction
        offset = self.head_offset(output)
        # box size prediction
        size = self.head_size(output)
        # the class scores should lie in [0, 1], so a sigmoid could be used here
        # classes = nn.Sigmoid()(classes)
        # a softmax over the class dimension is used instead
        classes = self.softmax(classes)
        # exp keeps the regressed sizes positive
        size = torch.exp(size)
        return classes, offset, size





if __name__ == "__main__":
    network = CenterNet()
    img = torch.rand(1,3,500,500)
    output = network(img)
    print(output[0])
    print(output[1])
    print(output[2])
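
With a 500×500 input, the output stride is 4, so each head produces a 125×125 map. As a quick sanity check of the output shapes (this snippet is not from the original post and assumes the local resnet18 weight file at BACKBONE is available):

network = CenterNet()
classes, offset, size = network(torch.rand(1, 3, 500, 500))
print(classes.shape)   # torch.Size([1, 20, 125, 125]) - one heatmap channel per VOC class
print(offset.shape)    # torch.Size([1, 2, 125, 125])  - 2-channel center offsets
print(size.shape)      # torch.Size([1, 2, 125, 125])  - 2-channel box sizes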

lossFunction.py

import torch
import torch.nn as nn
import exp.voc_dataset as dataload
import exp.centerNet as network
import time
import numpy as np


"""
这个脚本是centerNet的三个损失函数
分类损失 Focal loss
校正损失 L1 loss
回归损失 L1 loss
"""

class CenterNetLoss(nn.Module):
    # pred is the network output, with three parts (class scores, offset corrections, size regressions)
    # target is the ground truth from the dataset, with two parts (bboxes and class labels)
    # candidate_num is the number of candidate points, 100 in the paper
    def __init__(self, pred=None, target=None, candidate_num=100):
        super(CenterNetLoss, self).__init__()
        # first unpack the three inputs
        if(pr
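
As the header comment above says, training combines a focal loss on the class heatmap with L1 losses on the center offsets and box sizes; in the paper the total loss is L = L_k + 0.1·L_size + 1·L_off. Below is a minimal, self-contained sketch of that combination, not the original code: the function name centernet_loss, the argument layout, and the assumption that dense target maps (gt_heatmap, gt_offset, gt_size) plus a center mask have already been built from the VOC annotations are all illustrative.

import torch

def centernet_loss(pred_cls, pred_offset, pred_size,
                   gt_heatmap, gt_offset, gt_size, gt_mask,
                   alpha=2, beta=4, lambda_size=0.1, lambda_offset=1.0):
    # pred_cls:    [B, C, H, W] class heatmap, values in (0, 1)
    # pred_offset: [B, 2, H, W] sub-pixel center offsets
    # pred_size:   [B, 2, H, W] box width/height predictions
    # gt_heatmap:  [B, C, H, W] Gaussian-splatted target heatmap (exactly 1 at object centers)
    # gt_offset, gt_size: [B, 2, H, W] dense targets, only meaningful at centers
    # gt_mask:     [B, 1, H, W] 1 at object centers, 0 elsewhere
    eps = 1e-6
    pred_cls = pred_cls.clamp(eps, 1 - eps)
    num_objects = gt_mask.sum().clamp(min=1)

    # penalty-reduced pixel-wise focal loss on the heatmap (alpha=2, beta=4 in the paper)
    pos = gt_heatmap.eq(1).float()
    pos_loss = -((1 - pred_cls) ** alpha) * torch.log(pred_cls) * pos
    neg_loss = -((1 - gt_heatmap) ** beta) * (pred_cls ** alpha) * torch.log(1 - pred_cls) * (1 - pos)
    cls_loss = (pos_loss.sum() + neg_loss.sum()) / num_objects

    # L1 losses, averaged over object centers only
    offset_loss = (torch.abs(pred_offset - gt_offset) * gt_mask).sum() / num_objects
    size_loss = (torch.abs(pred_size - gt_size) * gt_mask).sum() / num_objects

    return cls_loss + lambda_offset * offset_loss + lambda_size * size_loss

With the network above, pred_cls would be the softmax-activated class map and pred_size the exp-activated size map; the candidate_num=100 setting from the paper concerns the top-100 peaks kept when decoding detections at inference time, not the loss itself.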