1.yolo.py
logger = logging.getLogger(__name__)
logger: the logging object, the most basic object in the logging module. It is created with logging.getLogger(name), where name may be omitted. By convention the logger name matches the module it belongs to, e.g. a chat module, database module, or authentication module.
Reference: logging.getLogger(logger), u011159607's blog on CSDN
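A minimal sketch of the same pattern (the basicConfig call is only for this demo; yolo.py configures logging elsewhere):

import logging

logging.basicConfig(level=logging.INFO)  # demo-only handler/format setup
logger = logging.getLogger(__name__)     # logger named after the current module
logger.info('model initialised')         # e.g. INFO:__main__:model initialised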
import thop # for FLOPS computation
THOP is a handy third-party library for PyTorch that counts a model's FLOPs and number of parameters.
Reference: 'CNN 模型所需的计算力flops是什么?怎么计算?' on Zhihu (zhihu.com)
Reference: 'THOP: 统计 PyTorch 模型的 FLOPs 和参数量', yiran103's blog on CSDN
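A minimal usage sketch (assumes thop and torchvision are installed; ResNet-18 is just an illustration). profile() returns an operation count and a parameter count; thop counts multiply-accumulates (MACs), which are commonly doubled to report FLOPs:

import torch
import torchvision
from thop import profile

model = torchvision.models.resnet18()
img = torch.zeros(1, 3, 224, 224)            # one dummy 224x224 RGB image
ops, params = profile(model, inputs=(img,))
print(f'{ops / 1e9:.2f} GMACs, {params / 1e6:.2f} M parameters')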
class Detect
self.register_buffer('anchors', a) # shape(nl,na,2)
Reference: torch.nn.Module.register_buffer(name, tensor), 敲代码的小风's blog on CSDN
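A minimal sketch of what register_buffer does: the tensor is stored in state_dict and moved by .to()/.cuda() together with the module, but it is not a learnable parameter (the AnchorHolder module name is made up for illustration):

import torch
import torch.nn as nn

class AnchorHolder(nn.Module):
    def __init__(self):
        super().__init__()
        a = torch.tensor([[10., 13.], [16., 30.], [33., 23.]])
        self.register_buffer('anchors', a)   # saved with the model, not trained

m = AnchorHolder()
print('anchors' in m.state_dict())   # True
print(list(m.parameters()))          # [] -> no gradients flow into it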
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
Reference: '详解PyTorch中的ModuleList和Sequential' on Zhihu (zhihu.com)
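A minimal sketch of the same construction: nn.ModuleList registers each 1x1 conv as a submodule so its parameters are seen by the optimizer, which a plain Python list would not do (the channel numbers below are just example values):

import torch.nn as nn

ch = (128, 256, 512)   # example input channels of the three detection layers
na, no = 3, 85         # anchors per layer, outputs per anchor (80 classes + 5)
m = nn.ModuleList(nn.Conv2d(x, no * na, 1) for x in ch)
print(m[0])            # Conv2d(128, 255, kernel_size=(1, 1), stride=(1, 1))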
def forward(self, x):
    # x = x.copy()  # for profiling
    z = []  # inference output
    self.training |= self.export
    for i in range(self.nl):
        x[i] = self.m[i](x[i])  # conv
        bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
        x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        if not self.training:  # inference
            if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

            y = x[i].sigmoid()
            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
            z.append(y.view(bs, -1, self.no))

    return x if self.training else (torch.cat(z, 1), x)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
permute() reorders the dimensions of a tensor.
contiguous() usually follows permute() and other shape-changing operations: after such an operation the tensor may no longer occupy a single contiguous block of memory but be made up of several data blocks, while view() requires a contiguous block. Calling contiguous() copies the tensor into a memory-contiguous layout so that view() can be applied afterwards.
Reference: 'PyTorch中permute的用法', york1996's blog on CSDN
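A minimal sketch of the permute + contiguous + view chain with the same shapes as the Detect head (255 = 3 anchors x 85 outputs):

import torch

x = torch.zeros(1, 255, 20, 20)                      # (bs, na*no, ny, nx)
x = x.view(1, 3, 85, 20, 20).permute(0, 1, 3, 4, 2)  # -> (bs, na, ny, nx, no)
print(x.is_contiguous())                             # False after permute
x = x.contiguous()                                   # copy into contiguous memory
print(x.view(1, -1, 85).shape)                       # torch.Size([1, 1200, 85])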
return x if self.training else (torch.cat(z, 1), x)#??
torch.cat concatenates tensors (cat is short for concatenate). When calling torch.cat((A, B), dim), every dimension except the concatenation dimension dim must have the same size so that the tensors align.
dim=0 concatenates vertically (along rows);
dim=1 concatenates horizontally (along columns).
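A minimal sketch of both cases:

import torch

A = torch.ones(2, 3)
B = torch.zeros(2, 3)
print(torch.cat((A, B), dim=0).shape)   # torch.Size([4, 3]), stacked vertically
print(torch.cat((A, B), dim=1).shape)   # torch.Size([2, 6]), stacked horizontally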
@staticmethod
def _make_grid(nx=20, ny=20):
    yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
    return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
References:
'torch.meshgrid()函数解析', 小娜美要努力努力's blog on CSDN
'【Pytorch】torch.stack()的使用' on Zhihu (zhihu.com)
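A minimal sketch of how meshgrid and stack build the grid of cell offsets (small nx/ny so the result is easy to inspect):

import torch

nx, ny = 4, 3
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
grid = torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
print(grid.shape)          # torch.Size([1, 1, 3, 4, 2])
print(grid[0, 0, 1, 2])    # tensor([2., 1.]) -> x offset 2, y offset 1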
class Model
def __init__(self, cfg='yolov3.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
    super(Model, self).__init__()
    if isinstance(cfg, dict):
        self.yaml = cfg  # model dict
    else:  # is *.yaml
        import yaml  # for torch hub
        self.yaml_file = Path(cfg).name
        with open(cfg) as f:
            self.yaml = yaml.load(f, Loader=yaml.SafeLoader)  # model dict

    # Define model
    ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
    if nc and nc != self.yaml['nc']:
        logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
        self.yaml['nc'] = nc  # override yaml value
    if anchors:
        logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
        self.yaml['anchors'] = round(anchors)  # override yaml value
    self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
    self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
    # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

    # Build strides, anchors
    m = self.model[-1]  # Detect()
    if isinstance(m, Detect):
        s = 256  # 2x min stride
        m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
        m.anchors /= m.stride.view(-1, 1, 1)
        check_anchor_order(m)
        self.stride = m.stride
        self._initialize_biases()  # only run once
        # print('Strides: %s' % m.stride.tolist())

    # Init weights, biases
    initialize_weights(self)
    self.info()
    logger.info('')
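A minimal usage sketch, assuming models/yolo.py from this repo is importable and 'yolov3.yaml' is on disk (paths, class count and the exact return unpacking may differ between repo versions):

import torch
from models.yolo import Model

model = Model(cfg='yolov3.yaml', ch=3, nc=80)  # build the network from the yaml
model.eval()                                   # Detect() now also returns decoded boxes
img = torch.zeros(1, 3, 640, 640)              # dummy input image
with torch.no_grad():
    pred, raw = model(img)                     # pred: (bs, total anchors, no)
print(pred.shape, len(raw))                    # raw: per-layer feature maps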