预测头
ultralytics/nn/modules/head.py
class OBB(Detect):
    """YOLOv8 OBB detection head for detection with rotation models.

    Extends ``Detect`` with an extra per-level branch (``cv4``) that predicts
    ``ne`` rotation parameters per spatial position. These are concatenated
    onto the outputs produced by ``Detect.forward``.
    """

    def __init__(self, nc=80, ne=1, ch=()):
        """Initialize OBB with number of classes `nc` and layer channels `ch`.

        Args:
            nc (int): Number of classes, forwarded to ``Detect.__init__``.
            ne (int): Number of extra parameters predicted per position; this is
                the output channel count of each ``cv4`` branch.
            ch (tuple): Input channel count of each feature level; one ``cv4``
                branch is built per entry.
        """
        super().__init__(nc, ch)
        self.ne = ne  # number of extra parameters
        # Hidden width of the angle branch: a quarter of the first level's
        # channels, but never fewer than `ne`.
        c4 = max(ch[0] // 4, self.ne)
        # One branch per feature level: two 3x3 Conv blocks, then a 1x1 Conv2d
        # projecting to `ne` output channels.
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)

    def forward(self, x):
        """Predict rotation angles and append them to the ``Detect`` outputs.

        Args:
            x (list[torch.Tensor]): Per-level feature maps; ``x[i]`` is fed to
                ``self.cv4[i]`` for ``i`` in ``range(self.nl)``. Presumably
                shaped (bs, ch[i], H_i, W_i) -- TODO confirm against the caller.

        Returns:
            - Training: ``(x, angle)``, where ``x`` is whatever
              ``Detect.forward`` returns and ``angle`` has shape (bs, ne, A),
              with A the number of flattened spatial positions summed over all
              levels.
            - Export (``self.export``): ``torch.cat([x, angle], 1)`` with dims 1
              and 2 then swapped by ``permute(0, 2, 1)``.
            - Otherwise: ``(torch.cat([x[0], angle], 1), (x[1], angle))``.
        """
        bs = x[0].shape[0]  # batch size
        # Run each level's angle branch, flatten its spatial dims to
        # (bs, ne, -1), and concatenate all levels along dim 2 -> (bs, ne, A).
        angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2)  # OBB theta logits
        # Map logits to radians: sigmoid lies in (0, 1); subtracting 0.25 and
        # scaling by pi yields the range (-pi/4, 3pi/4).
        angle = (angle.sigmoid() - 0.25) * math.pi  # [-pi/4, 3pi/4]
        # Alternative (disabled) mapping to [0, pi/2]:
        # angle = angle.sigmoid() * math.pi / 2  # [0, pi/2]
        # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
        # It is stored before Detect.forward runs; presumably Detect.forward
        # calls decode_bboxes during inference -- confirm in Detect.
        if not self.training:
            self.angle = angle
        # The angle was computed from the raw feature maps above, before they
        # are handed to Detect.forward here.
        x = Detect.forward(self, x)
        if self.training:
            return x, angle
        # Earlier export variant without the permute, kept for reference:
        # return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
        # Export: append the angle channels on dim 1, then swap dims 1 and 2.
        # NOTE(review): presumably this yields (bs, A, C) for a deployment
        # runtime that expects positions-first output -- confirm which target
        # requires it, since it differs from the commented variant above.
        return torch.cat([x, angle], 1).permute(0, 2, 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
forward 的输入值

self.cv4 网络结构
ModuleList(
(0): Sequential(
(0): Conv(
(conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(

最低0.47元/天 解锁文章
1115

被折叠的 条评论
为什么被折叠?



