YOLOv3 源码学习笔记

这篇博客详细介绍了YOLO模型的前向传播过程,包括日志管理、THOP库的使用来计算FLOPs,以及模型结构中的ModuleList、permute和contiguous函数的应用。此外,还展示了如何初始化模型参数,并在训练和推断模式下操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.yolo.py 

logger = logging.getLogger(__name__)

logger:日志对象,logging模块中最基础的对象,用logging.getLogger(name)方法进行初始化,name可以不填。通常logger的名字我们对应模块名,如聊天模块、数据库模块、验证模块等。

参考连接:(2条消息) logging.getLogger(logger)_u011159607的博客-优快云博客_logging.getlogger

import thop  # for FLOPS computation

THOP 是 PyTorch 非常实用的一个第三方库,可以统计模型的 FLOPs 和参数量 

参考链接:CNN 模型所需的计算力flops是什么?怎么计算? - 知乎 (zhihu.com)

 (2条消息) THOP: 统计 PyTorch 模型的 FLOPs 和参数量_yiran103的专栏-优快云博客

class Detection

self.register_buffer('anchors', a)  # shape(nl,na,2)

参考链接:(2条消息) torch.nn.Module.register_buffer(name, tensor)_敲代码的小风-优快云博客

 

self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

参考链接:详解PyTorch中的ModuleList和Sequential - 知乎 (zhihu.com) 

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

 
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

permute函数将tensor的维度换位

contiguous()一般在permute()等改变形状和计算返回的tensor后面,因为改变形状后,有的tensor并不是占用一整块内存,而是由不同的数据块组成,而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguous()这个函数,把tensor变成在内存中连续分布的形式。

参考链接:(2条消息) PyTorch中permute的用法_york1996的博客-优快云博客_permute函数

return x if self.training else (torch.cat(z, 1), x)#??

torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起。使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。

dim=0时,就是(竖着拼)

dim=1时,就是(横着拼)

    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

参考链接:

(2条消息) torch.meshgrid()函数解析_小娜美要努力努力的博客-优快云博客_torch.meshgrid

【Pytorch】torch.stack()的使用 - 知乎 (zhihu.com)

class model

    def __init__(self, cfg='yolov3.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super(Model, self).__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.SafeLoader)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once
            # print('Strides: %s' % m.stride.tolist())

        # Init weights, biases
        initialize_weights(self)
        self.info()
        logger.info('')

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值