yolov1 net

在这里插入图片描述

class Convention(nn.Module):
    def __init__(self, in_channels, out_channels, conv_size, conv_stride, padding):
        super(Convention, self).__init__()
        self.Conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, conv_size, conv_stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.Conv(x)


class YOLO_V1(nn.Module):

    def __init__(self, B=2, Classes_Num=20):
        super(YOLO_V1, self).__init__()
        self.B = B   #每个框box数量
        self.Classes_Num = Classes_Num #class num

        self.Conv_448 = nn.Sequential(
            Convention(3, 64, 7, 2, 3), #[10, 64, 224, 224]
            nn.MaxPool2d(2, 2), #[10, 64, 112, 112]
        )

        self.Conv_112 = nn.Sequential(
            Convention(64, 192, 3, 1, 1), #[10, 192, 112, 112]
            nn.MaxPool2d(2, 2), #[10, 192, 56, 56]
        )

        self.Conv_56 = nn.Sequential(
            Convention(192, 128, 1, 1, 0), #[10, 128, 56, 56]
            Convention(128, 256, 3, 1, 1), #[10, 256, 56, 56]
            Convention(256, 256, 1, 1, 0), #[10, 256, 56, 56]
            Convention(256, 512, 3, 1, 1), #[10, 512, 56, 56]
            nn.MaxPool2d(2, 2), #[10, 512, 28, 28]
        )

        self.Conv_28 = nn.Sequential(
            Convention(512, 256, 1, 1, 0), #[10, 256, 28, 28]
            Convention(256, 512, 3, 1, 1), #[10, 512, 28, 28]
            Convention(512, 256, 1, 1, 0), #[10, 256, 28, 28]
            Convention(256, 512, 3, 1, 1), #[10, 512, 28, 28]
            Convention(512, 256, 1, 1, 0), #[10, 256, 28, 28]
            Convention(256, 512, 3, 1, 1), #[10, 512, 28, 28]
            Convention(512, 256, 1, 1, 0), #[10, 256, 28, 28]
            Convention(256, 512, 3, 1, 1), #[10, 512, 28, 28]
            Convention(512, 512, 1, 1, 0), #[10, 512, 28, 28]
            Convention(512, 1024, 3, 1, 1), #[10, 1024, 28, 28]
            nn.MaxPool2d(2, 2), #[10, 1024, 14, 14]
        )

        self.Conv_14 = nn.Sequential(
            Convention(1024, 512, 1, 1, 0), #[10, 512, 14, 14]
            Convention(512, 1024, 3, 1, 1), #[10, 1024, 14, 14]
            Convention(1024, 512, 1, 1, 0), #[10, 512, 14, 14]
            Convention(512, 1024, 3, 1, 1), #[10, 1024, 14, 14]
            Convention(1024, 1024, 3, 1, 1),# [10, 1024, 14, 14]
            Convention(1024, 1024, 3, 2, 1),# [10, 1024, 14, 14]
        )
        self.Conv_7 = nn.Sequential(
            Convention(1024, 1024, 3, 1, 1), #[10, 1024, 14, 14]
            Convention(1024, 1024, 3, 1, 1), #[10, 1024, 14, 14]
        )

        self.Fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(7 * 7 * 1024, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 7 * 7 * (B * 5 + Classes_Num)),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.Conv_448(x)
        x = self.Conv_112(x)
        x = self.Conv_56(x)
        x = self.Conv_28(x)
        x = self.Conv_14(x)
        x = self.Conv_7(x) # [10, 1024, 7, 7]
        # batch_size * channel * height * weight -> batch_size * height * weight * channel
        x = x.permute(0, 2, 3, 1).contiguous()
        x = x.view(-1, 7 * 7 * 1024) 
        x = self.Fc(x) #[10, 1470]
        x = x.view((-1, 7, 7, (self.B * 5 + self.Classes_Num)))  #[10, 7, 7, 30]
        return x


if __name__ == '__main__':
    input = torch.randn(10,3,448,448)
    model = YOLO_V1()
    output = model(input)
    print(output.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值