3.2 YOLOv3
为了节省篇幅，本节只贴出部分代码；完整代码及包的导入请参考 附件\model\yolo.py。（参考：models: Models of MindSpore - Gitee.com）
class YoloBlock 是 Darknet53 输出之后的处理模块，对应 YOLOv3 结构图（参见《YOLOv3人体目标检测模型实现（一）》）中的 Convolutional Set 以及其后的卷积层：
class YoloBlock(nn.Cell):
    """
    Post-backbone processing block used by YOLOv3.

    Chains six `_conv_bn_relu` layers whose kernel sizes alternate 1, 3, 1, 3,
    1, 3, and finishes with a plain 1x1 `nn.Conv2d` that has a bias term.

    Args:
        in_channels: Integer. Input channel, handed to the first layer.
        out_chls: Integer. Middle channel. The ksize=1 layers are built to end
            at `out_chls`, the ksize=3 layers at `out_chls * 2` (assuming the
            first two positional arguments of `_conv_bn_relu` are input and
            output channels -- confirm against the helper).
        out_channels: Integer. Output channel of the final 1x1 convolution.

    Returns:
        Tuple of two tensors `(c5, out)`: the result of `conv4` (the fifth
        layer) and the result of the final convolution `conv6`.

    Examples:
        YoloBlock(1024, 512, 255)
    """

    def __init__(self, in_channels, out_chls, out_channels):
        super().__init__()
        wide_chls = out_chls * 2

        # Attribute names and creation order are kept as-is: both are visible
        # to anything that enumerates or looks up this cell's sub-layers.
        self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1)

        self.conv1 = _conv_bn_relu(out_chls, wide_chls, ksize=3)
        self.conv2 = _conv_bn_relu(wide_chls, out_chls, ksize=1)

        self.conv3 = _conv_bn_relu(out_chls, wide_chls, ksize=3)
        self.conv4 = _conv_bn_relu(wide_chls, out_chls, ksize=1)

        self.conv5 = _conv_bn_relu(out_chls, wide_chls, ksize=3)
        self.conv6 = nn.Conv2d(
            wide_chls,
            out_channels,
            kernel_size=1,
            stride=1,
            has_bias=True,
        )

    def construct(self, x):
        """Apply conv0..conv6 in order; return (conv4 output, conv6 output)."""
        stage0 = self.conv0(x)
        stage1 = self.conv1(stage0)
        stage2 = self.conv2(stage1)
        stage3 = self.conv3(stage2)

        # The conv4 result is returned alongside the final output, so it is
        # kept under its own name before the last two layers consume it.
        branch = self.conv4(stage3)

        stage5 = self.conv5(branch)
        head = self.conv6(stage5)
        return branch, head
class YOLOv3 则将主干网络和 YoloBlock 组合起来（包含上采样），构成结构图中完整的 YOLOv3 模型：
class YOLOv3(nn.Cell):
"""
YOLOv3 Network.
Note:
backbone = darknet53
Args:
backbone_shape: List. Darknet output channels shape.
backbone: Cell. Backbone Network.
out_channel: Integer. Output channel.
Returns:
Tensor, output tensor.
Examples:
YOLOv3(backbone_shape=[64, 128, 256, 512, 1024]
backbone=darknet53(),
out_channel=255)
"""
def __init__(self, backbone_shape, backbone, out_channel):
super(YOLOv3, self).__init__()
self.out_channel = out_channel
self.backbone = backbone
self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel)
self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1)
self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3],
out_chls=backbone_shape[-3],