models/modules/asf.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScalSeq(nn.Module):
    def __init__(self, c1, c2, kernel_size=(1, 1, 1)):
        super().__init__()
        self.conv_1x1 = nn.Conv2d(c1, c2, 1, 1, 0)
        self.conv3d = nn.Conv3d(c2, c2, kernel_size=kernel_size, bias=False)
        self.bn = nn.BatchNorm3d(c2)
        self.act = nn.LeakyReLU(0.1)
        self.pool3d = nn.MaxPool3d((3, 1, 1))

    def forward(self, x3, x4, x5):
        # x3: P3 (B, C, H, W), x4: P4, x5: P5
        _, _, H, W = x3.shape
        x4 = F.interpolate(x4, size=(H, W), mode='nearest')
        x5 = F.interpolate(x5, size=(H, W), mode='nearest')
        x3 = self.conv_1x1(x3)
        x4 = self.conv_1x1(x4)
        x5 = self.conv_1x1(x5)
        # Stack the three scales along a depth dimension: (B, C, 3, H, W)
        x = torch.stack([x3, x4, x5], dim=2)
        x = self.conv3d(x)   # 3D conv across the scale dimension
        x = self.bn(x)
        x = self.act(x)
        x = self.pool3d(x)   # max-pool over the 3 scales -> (B, C, 1, H, W)
        return x.squeeze(2)  # back to (B, C, H, W)
class ChannelAttention(nn.Module):
    def __init__(self, c, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(c, c // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(c // ratio, c, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
class LocalAttention(nn.Module):
    def __init__(self, k=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=k, padding=k // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attn = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(attn))
class ASF_Attention(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.ch_attn = ChannelAttention(c)
        self.local_attn = LocalAttention()

    def forward(self, x_low, x_high):
        # Fuse low-resolution semantic and high-resolution spatial information
        x_attended = x_low * self.ch_attn(x_low)
        x_fused = x_attended + x_high
        x_out = x_fused * self.local_attn(x_fused)
        return x_out
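
# --- Illustrative smoke test (not part of the original module) ---
# A minimal shape-check sketch, assuming the classes above and hypothetical
# channel/resolution values (64 channels, 80x80 P3). Run this file directly
# to verify that the fusion path preserves the P3 spatial size.
if __name__ == "__main__":
    p3 = torch.randn(1, 64, 80, 80)   # stride-8 feature
    p4 = torch.randn(1, 64, 40, 40)   # stride-16 feature (same channel count assumed)
    p5 = torch.randn(1, 64, 20, 20)   # stride-32 feature
    fused = ScalSeq(64, 64)(p3, p4, p5)
    out = ASF_Attention(64)(fused, p3)
    print(fused.shape, out.shape)     # expected: torch.Size([1, 64, 80, 80]) twice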
models/modules/nms.py
import torch
def calc_iou(box1, box2):
    """Compute pairwise IoU between two sets of boxes in xyxy format."""
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.unbind(dim=-1)
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.unbind(dim=-1)
    inter_x1 = torch.max(b1_x1[:, None], b2_x1[None, :])
    inter_y1 = torch.max(b1_y1[:, None], b2_y1[None, :])
    inter_x2 = torch.min(b1_x2[:, None], b2_x2[None, :])
    inter_y2 = torch.min(b1_y2[:, None], b2_y2[None, :])
    inter_area = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)
    area1 = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    area2 = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    union_area = area1[:, None] + area2[None, :] - inter_area
    return inter_area / union_area.clamp(min=1e-7)
def soft_nms(bboxes, scores, iou_threshold=0.5, sigma=0.5, score_threshold=0.001):
“”"
Apply Soft-NMS to bounding boxes.
Args:
bboxes: (N, 4) tensor
scores: (N,) tensor
iou_threshold: float
sigma: Gaussian weight decay parameter
Returns:
keep_indices: list of indices to keep
“”"
keep = []
idxs = scores.argsort(descending=True)
while len(idxs) > 0: i = idxs[0] keep.append(i.item()) if len(idxs) == 1: break ious = calc_iou(bboxes[i:i + 1], bboxes[idxs[1:]]).flatten() weights = torch.exp(-ious ** 2 / sigma) scores[idxs[1:]] *= weights idxs = idxs[1:][scores[idxs[1:]] > score_threshold] return keep
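
# --- Illustrative usage sketch (not part of the original module) ---
# Hypothetical boxes/scores to exercise soft_nms; the two overlapping boxes
# are both kept, but the second one's score is decayed rather than dropped.
if __name__ == "__main__":
    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [50., 50., 60., 60.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    print(soft_nms(boxes, scores.clone()))  # e.g. [0, 2, 1]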
models/modules/p2_head.py
import torch
import torch.nn as nn
from ultralytics.nn.modules import C2f
from models.modules.asf import ScalSeq, ASF_Attention
class P2Head(nn.Module):
    def __init__(self, in_channels, out_channels=128):
        super().__init__()
        self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.concat_conv = nn.Conv2d(out_channels * 2, out_channels, 1)
        self.c2f = C2f(out_channels, out_channels, n=2)
        self.scalseq = ScalSeq(out_channels, out_channels)

    def forward(self, high_feat, low_feat):
        up_feat = self.up_conv(high_feat)                    # e.g. P3 -> 80x80
        concat_feat = torch.cat([up_feat, low_feat], dim=1)  # skip connection (low_feat must have out_channels channels)
        x = self.concat_conv(concat_feat)
        x = self.c2f(x)
        x = self.scalseq(x, x, x)                            # self-aggregation using ScalSeq
        return x
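
# --- Illustrative shape check (not part of the original module) ---
# A minimal sketch with hypothetical channel counts: a 256-channel stride-8
# feature is upsampled and fused with a 128-channel stride-4 skip feature.
if __name__ == "__main__":
    head = P2Head(in_channels=256, out_channels=128)
    p3 = torch.randn(1, 256, 40, 40)   # higher-level feature
    p2 = torch.randn(1, 128, 80, 80)   # high-resolution skip feature
    print(head(p3, p2).shape)          # expected: torch.Size([1, 128, 80, 80])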
models/sod_yolo.py
from ultralytics.nn.tasks import DetectionModel
from ultralytics.nn.modules import Conv, C2f, SPPF
from models.modules.asf import ScalSeq, ASF_Attention
from models.modules.p2_head import P2Head
import torch.nn as nn
class SODYOLO(DetectionModel):
    def __init__(self, cfg='yolov8m.yaml', ch=3, nc=None):
        super().__init__(cfg, ch, nc)
        # Assume the original backbone outputs [c2, c3, c4, c5].
        m = self.model
        # fuse_layer_index / detect_layer_index hold the positions of the
        # P3/P4/P5 feature layers and the Detect head in the parsed model;
        # they depend on the model YAML layout.
        ch_in = [m[c].cv2.conv.out_channels for c in self.fuse_layer_index]  # P3, P4, P5 channels

        # Replace the neck with ASF-based fusion
        self.neck_scalseq = ScalSeq(ch_in[0], ch_in[0])
        self.neck_asf = ASF_Attention(ch_in[0])

        # Add the P2 head (from the P3 path up to stride 4)
        self.p2_head = P2Head(ch_in[1], ch_in[0])

        # Modify the detect layer input channels to accept the extra P2 output;
        # no, na and nc are inherited unchanged from the original Detect head.
        det_idx = self.detect_layer_index
        self.model[det_idx].in_channels = [ch_in[0]] * 4  # P2, P3, P4, P5

    def forward(self, x):
        y = []
        p2_feat, p3_feat, p4_feat, p5_feat = None, None, None, None
        # Forward through the backbone manually
        for i, m in enumerate(self.model):
            if isinstance(m, Conv):  # stem / downsampling convs
                x = m(x)
            elif isinstance(m, C2f):
                x = m(x)
                if i == 9:
                    p2_feat = x   # stride 4
                elif i == 17:
                    p3_feat = x   # stride 8
                elif i == 25:
                    p4_feat = x   # stride 16
                elif i == 33:
                    p5_feat = x   # stride 32
            elif isinstance(m, SPPF):
                x = m(x)
            else:
                break

        # Neck: ASF fusion on P3/P4/P5
        fused_p3 = self.neck_scalseq(p3_feat, p4_feat, p5_feat)
        fused_p3 = self.neck_asf(fused_p3, p3_feat)

        # P2 branch
        p2_out = self.p2_head(p3_feat, p2_feat)

        # Detect layer inputs: [P2, P3, P4, P5]
        det_inputs = [p2_out, fused_p3, p4_feat, p5_feat]
        det_layer = self.model[self.detect_layer_index]
        for d in det_inputs:
            y.append(det_layer([d])[0])
        return y
utils/datasets.py
import os
import cv2
import torch
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
class LoadImagesAndLabels(Dataset):
    def __init__(self, root, img_size=640, augment=False, rect=False):
        super().__init__()
        self.root = root
        self.img_size = img_size
        self.augment = augment
        self.rect = rect

        # Image and label paths
        self.img_dir = os.path.join(root, 'images')
        self.label_dir = os.path.join(root, 'labels')
        self.img_files = sorted(
            [os.path.join(self.img_dir, x) for x in os.listdir(self.img_dir)
             if x.endswith(('.jpg', '.png', '.jpeg'))])
        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
                            for x in self.img_files]
        assert len(self.img_files) == len(self.label_files), "Number of images and labels does not match!"

        # Pre-read image sizes for rectangular training (optional)
        if rect:
            shapes = np.array([cv2.imread(im_file).shape[:2] for im_file in self.img_files])
            self.shapes = np.round(shapes * (img_size / shapes.max(axis=1).reshape(-1, 1))).astype(int)
        else:
            self.shapes = None

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        label_path = self.label_files[idx]

        # Load image
        img = cv2.imread(img_path)
        assert img is not None, f"Failed to read image: {img_path}"
        h0, w0 = img.shape[:2]
        r = self.img_size / max(h0, w0)
        if r != 1:
            interp = cv2.INTER_LINEAR
            new_shape = (int(w0 * r), int(h0 * r))
            img = cv2.resize(img, new_shape, interpolation=interp)
        img = img.transpose(2, 0, 1)[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).float() / 255.0

        # Load labels: (class_id, x, y, w, h), normalized
        labels = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                lb_lines = f.readlines()
            for line in lb_lines:
                parts = list(map(float, line.strip().split()))
                cls_id, x, y, w, h = parts
                labels.append([cls_id, x, y, w, h])
        labels = torch.tensor(labels) if len(labels) > 0 else torch.zeros((0, 5))
        return img, labels, img_path, idx
utils/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
def smooth_BCE(eps=0.1):
    """Return the smoothed positive/negative label values."""
    return 1.0 - 0.5 * eps, 0.5 * eps
class BCEBlurWithLogitsLoss(nn.Module):
    def __init__(self, alpha=0.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, pred, true):
        # Simple surrogate: squared error plus an alpha-weighted absolute-error term
        loss = ((pred - true) ** 2).mean() + self.alpha * (pred - true).abs().mean()
        return loss
class ComputeLoss:
    def __init__(self, model, autobalance=False):
        device = next(model.parameters()).device
        h = model.hyp  # hyperparameters

        # Loss components and coefficients; box regression uses CIoU via bbox_iou() below
        self.cls_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(h.get('cls_pw', 1.0), device=device))
        self.obj_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(h.get('obj_pw', 1.0), device=device))
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive/negative label values

        self.nc = model.nc                      # number of classes
        self.nl = len(model.model[-1].anchors)  # number of detection layers
        self.anchors = model.model[-1].anchors.to(device)
        self.balance = [4.0, 1.0, 0.4, 0.1] if self.nl == 4 else [4.0, 1.0, 0.4]  # per-layer objectness balance
        self.ssi = 0  # stride-8 index

    def __call__(self, preds, targets):
        """
        preds: list of tensors, one per head, each of shape (bs, na, grid_h, grid_w, nc + 5)
        targets: (num_targets, 6) -> (batch_idx, cls_id, x, y, w, h)
        """
        lbox = torch.zeros(1, device=preds[0].device)
        lobj = torch.zeros(1, device=preds[0].device)
        lcls = torch.zeros(1, device=preds[0].device)
        tcls, tbox, indices, anchors = self.build_targets(preds, targets)

        for i, pred in enumerate(preds):
            b, a, gj, gi = indices[i]  # image, anchor, grid y, grid x
            n = b.shape[0]
            if n == 0:
                lobj += self.obj_loss(pred[..., 4], torch.zeros_like(pred[..., 4])).sum() * self.balance[i]
                continue
            ps = pred[b, a, gj, gi]  # prediction subset matched to targets

            # Regression loss (xywh, CIoU)
            pxy = ps[:, :2].sigmoid() * 2 - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), dim=1)
            iou = bbox_iou(pbox, tbox[i], CIoU=True)
            lbox += (1.0 - iou).mean()

            # Objectness loss
            tobj = torch.zeros_like(pred[..., 4])
            tobj[b, a, gj, gi] = iou.detach().clamp(0).type(tobj.dtype)
            lobj += self.obj_loss(pred[..., 4], tobj) * self.balance[i]

            # Classification loss
            if self.nc > 1:
                t = torch.full_like(ps[:, 5:], self.cn)
                t[range(n), tcls[i]] = self.cp
                lcls += self.cls_loss(ps[:, 5:], t)

        lbox *= 0.05
        lobj *= 1.0
        lcls *= 0.5
        loss = lbox + lobj + lcls
        return loss, torch.stack((loss, lbox, lobj, lcls)).detach()

    def build_targets(self, preds, targets):
        """Build training targets: classes, boxes, indices and anchors per layer."""
        na, nt = self.anchors.shape[1], targets.shape[0]  # anchors per layer, number of targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=targets.device)
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), dim=2)

        g = 0.5
        off = torch.tensor([
            [0, 0], [1, 0], [0, 1], [-1, 0], [0, -1],  # center, j, k, l, m
        ], device=targets.device).float() * g

        for i in range(self.nl):
            anchors_ = self.anchors[i]
            gain[2:6] = torch.tensor(preds[i].shape)[[3, 2, 3, 2]]

            # Map targets onto the current feature map
            t = targets * gain
            if nt:
                r = t[..., 4:6] / anchors_[:, None]
                j = torch.max(r, 1 / r).max(dim=2).values < 4.0
                t = t[j]

                # Offset to neighbouring grid cells
                gxy = t[:, 2:4]
                gxi = gain[[2, 3]] - gxy
                j, k = ((gxy % 1 < g) & (gxy > 1)).T
                l, m = ((gxi % 1 < g) & (gxi > 1)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            b, c = t[:, :2].long().T
            gxy = t[:, 2:4]
            gwh = t[:, 4:6]
            gij = (gxy - offsets).long()
            gi, gj = gij.T
            a = t[:, 6].long()
            indices.append((b, a, gj.clamp_(0, int(gain[3]) - 1), gi.clamp_(0, int(gain[2]) - 1)))
            tbox.append(torch.cat((gxy - gij, gwh), dim=1))
            anch.append(anchors_[a])
            tcls.append(c)
        return tcls, tbox, indices, anch
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
if xywh:
(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.T, box2.T
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
else:
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
    area1 = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    area2 = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    union = area1 + area2 - inter + eps
    iou = inter / union
    if CIoU or DIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        c2 = cw ** 2 + ch ** 2 + eps                            # convex diagonal squared
        rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4       # center distance squared
        if DIoU:
            return iou - rho2 / c2
        elif CIoU:
            v = (4 / torch.pi ** 2) * (torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps))) ** 2
            with torch.no_grad():
                alpha = v / (v - iou + (1 + eps))
            return iou - (rho2 / c2 + alpha * v)
    return iou
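
# --- Illustrative sanity check (not part of the original module) ---
# Hypothetical xywh boxes; CIoU is slightly below the plain IoU (~0.39)
# because of the center-distance penalty.
if __name__ == "__main__":
    b1 = torch.tensor([[5.0, 5.0, 4.0, 4.0]])  # centered at (5, 5), 4x4
    b2 = torch.tensor([[6.0, 6.0, 4.0, 4.0]])  # shifted by (1, 1)
    print(bbox_iou(b1, b2, xywh=True, CIoU=True))  # ~0.35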
utils/metrics.py
“”"
SOD-YOLO Evaluation Metrics (Pure PyTorch)
No torchvision dependency. Fully self-contained.
“”"
import torch
import numpy as np
from typing import List, Optional
# Utility: xywh -> xyxy
def xywh2xyxy(x: torch.Tensor) -> torch.Tensor:
“”"
Convert bounding boxes from [x_center, y_center, w, h] to [x1, y1, x2, y2]
“”"
y = x.clone() if isinstance(x, torch.Tensor) else torch.tensor(x)
y[…, 0] = x[…, 0] - x[…, 2] / 2 # x1
y[…, 1] = x[…, 1] - x[…, 3] / 2 # y1
y[…, 2] = x[…, 0] + x[…, 2] / 2 # x2
y[…, 3] = x[…, 1] + x[…, 3] / 2 # y2
return y
# Bounding-box IoU (supports xyxy and xywh formats)
def bbox_iou(box1: torch.Tensor,
box2: torch.Tensor,
xywh: bool = False,
eps: float = 1e-7) -> torch.Tensor:
“”"
Compute IoU between two sets of boxes.
Args:
box1: (N, 4)
box2: (M, 4)
xywh: if input is in xywh format
eps: small value to avoid div by zero
Returns:
iou: (N, M)
“”"
if xywh:
b1_x1, b1_y1 = box1[…, 0] - box1[…, 2] / 2, box1[…, 1] - box1[…, 3] / 2
b1_x2, b1_y2 = box1[…, 0] + box1[…, 2] / 2, box1[…, 1] + box1[…, 3] / 2
b2_x1, b2_y1 = box2[…, 0] - box2[…, 2] / 2, box2[…, 1] - box2[…, 3] / 2
b2_x2, b2_y2 = box2[…, 0] + box2[…, 2] / 2, box2[…, 1] + box2[…, 3] / 2
else:
b1_x1, b1_y1, b1_x2, b1_y2 = box1.unbind(dim=-1)
b2_x1, b2_y1, b2_x2, b2_y2 = box2.unbind(dim=-1)
inter_x1 = torch.max(b1_x1.unsqueeze(1), b2_x1.unsqueeze(0)) inter_y1 = torch.max(b1_y1.unsqueeze(1), b2_y1.unsqueeze(0)) inter_x2 = torch.min(b1_x2.unsqueeze(1), b2_x2.unsqueeze(0)) inter_y2 = torch.min(b1_y2.unsqueeze(1), b2_y2.unsqueeze(0)) inter = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0) area1 = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) area2 = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) union = area1.unsqueeze(1) + area2.unsqueeze(0) - inter + eps iou = inter / union return iou
# Non-Maximum Suppression
# Pure PyTorch implementation; no torchvision.ops.nms required
def non_max_suppression(
prediction: torch.Tensor,
conf_thres: float = 0.25,
iou_thres: float = 0.45,
classes: Optional[List[int]] = None,
agnostic: bool = False,
max_det: int = 300
) -> List[torch.Tensor]:
“”"
Perform NMS on detection predictions.
Args:
prediction: (batch_size, num_boxes, 85) [x,y,w,h,obj,cls…]
conf_thres: confidence threshold
iou_thres: IoU threshold for suppression
max_det: max number of detections per image
Returns:
list of tensors with shape (num_dets, 6): [x1, y1, x2, y2, conf, cls]
“”"
assert 0 <= conf_thres <= 1, f’Invalid conf_thres={conf_thres}’
assert 0 <= iou_thres <= 1, f’Invalid iou_thres={iou_thres}’
device = prediction.device bs = prediction.shape[0] nc = prediction.shape[2] - 5 # number of classes xc = prediction[…, 4] > conf_thres # candidates output = [torch.zeros((0, 6), device=device)] * bs for xi, x in enumerate(prediction): x = x[xc[xi]] # filter by confidence if x.shape[0] == 0: continue # Compute confidence: obj_conf * cls_conf x[:, 5:] *= x[:, 4:5] # Convert to xyxy box = xywh2xyxy(x[:, :4]) conf, j = x[:, 5:].max(1, keepdim=True) conf_flat = conf.view(-1) v = torch.cat((box, conf, j.float()), dim=1) # Filter by class if classes is not None: include = torch.tensor(classes, device=device) v = v[(v[:, 5:6] == include).any(1)] if v.shape[0] == 0: continue # Class-agnostic or class-aware NMS if agnostic: scores = v[:, 4] boxes = v[:, :4] else: scores = v[:, 4] boxes = v[:, :4].clone() boxes += v[:, 5:6] * 4096 # 移动不同类别的框避免跨类干扰 # Sort by score scores_sorted, order = scores.sort(descending=True) keep = [] while len(order) > 0: i = order[0].item() keep.append(i) if len(order) == 1: break # Compute IoU ious = bbox_iou(boxes[i:i+1], boxes[order[1:]]) # Keep only those with IoU <= threshold idxs_to_keep = (ious.squeeze(0) <= iou_thres).nonzero(as_tuple=False).flatten() order = order[idxs_to_keep + 1] # +1 because we used order[1:] if len(keep) >= max_det: break keep = keep[:max_det] output[xi] = v[keep] return output
# Average precision computation (AP@0.5:0.95)
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
    """
    Compute precision, recall, AP and F1 per class.
    """
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]

    px = np.linspace(0, 1, 1000)
    ap, p, r = [], [], []
    for cls in unique_classes:
        idx = pred_cls == cls
        if not idx.any():
            continue
        n_p = idx.sum()
        n_gt = (target_cls == cls).sum()
        if n_p == 0 or n_gt == 0:
            continue
        fpc = (1 - tp[idx]).cumsum()
        tpc = tp[idx].cumsum()
        recall = tpc / (n_gt + 1e-16)
        precision = tpc / (tpc + fpc)
        r.append(recall[-1])
        p.append(precision[-1])
        # AP: area under the interpolated precision-recall curve (1000 points)
        ap.append(np.trapz(np.interp(px, recall, precision), px))

    p = np.array(p)
    r = np.array(r)
    ap = np.array(ap)
    f1 = 2 * p * r / (p + r + 1e-16) if len(p) > 0 else np.zeros_like(p)
    return p, r, ap, f1, unique_classes.astype(int)
# Match detections to ground-truth labels
def process_batch(detections, labels, iouv=torch.linspace(0.5, 0.95, 10), img_shape=(640, 640)):
    """
    Return the correct-prediction matrix.
    detections: (N, 6) [x1, y1, x2, y2, conf, cls] in pixels
    labels: (M, 5) [cls_id, x, y, w, h] (normalized xywh)
    img_shape: (h, w) of the network input, used to de-normalize the labels
    """
    if detections.shape[0] == 0:
        return torch.zeros((0, len(iouv)), dtype=torch.bool)
    if labels.shape[0] == 0:
        return torch.zeros((detections.shape[0], len(iouv)), dtype=torch.bool)

    # Convert normalized label xywh to pixel xyxy
    tbox = xywh2xyxy(labels[:, 1:])
    h, w = img_shape
    tbox *= torch.tensor([w, h, w, h], device=tbox.device, dtype=tbox.dtype)
    tcls = labels[:, 0]

    correct = torch.zeros(detections.shape[0], len(iouv), dtype=torch.bool, device=detections.device)
    used = []
    for d_idx, (*xyxy, conf, cls_pred) in enumerate(detections):
        det_box = torch.tensor(xyxy).float().to(tbox.device).unsqueeze(0)
        best_iou, best_t_idx = 0.0, -1
        for t_idx in range(len(tbox)):
            if t_idx in used or abs(cls_pred - tcls[t_idx]) > 1e-6:
                continue
            iou = bbox_iou(det_box, tbox[t_idx].unsqueeze(0)).item()
            if iou > best_iou:
                best_iou = iou
                best_t_idx = t_idx
        if best_iou >= iouv[0] and best_t_idx != -1:
            correct[d_idx] = iouv <= best_iou
            used.append(best_t_idx)
    return correct.cpu().numpy()
# Main evaluation routine
@torch.no_grad()
def evaluate(model, dataloader, iou_thres=0.5, conf_thres=0.001, nc=10):
“”"
Evaluate model mAP, Precision, Recall.
“”"
model.eval()
stats = []
for batch_i, (imgs, targets, paths, shapes) in enumerate(dataloader): imgs = imgs.to(next(model.parameters()).device) targets = targets.to(imgs.device) # Forward pass out = model(imgs) if isinstance(out, dict): out = list(out.values()) out_cat = [] for o in out: bs, _, ny, nx = o.shape o = o.permute(0, 2, 3, 1).reshape(bs, ny * nx, -1) out_cat.append(o) out = torch.cat(out_cat, dim=1) # NMS output = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres) # Collect results for si, pred in enumerate(output): labels = targets[targets[:, 0] == si, 1:] stat = { ‘p’: pred.cpu(), ‘t’: labels.cpu(), ‘path’: paths[si], ‘shape’: shapes[si] } stats.append(stat) if len(stats) == 0: return 0.0, 0.0, 0.0, 0 tp, conf, pred_cls, target_cls = [], [], [], [] for s in stats: pred = s[‘p’] labels = s[‘t’] if pred.shape[0] == 0: continue if len(labels) == 0: continue matches = process_batch(pred, labels) for i, match_row in enumerate(matches): if match_row.any(): tp.append(1.0) conf.append(float(pred[i][4])) pred_cls.append(int(pred[i][5])) target_cls.append(int(labels[match_row.argmax(), 0])) else: tp.append(0.0) conf.append(float(pred[i][4])) pred_cls.append(int(pred[i][5])) if len(tp) == 0: return 0.0, 0.0, 0.0, len(stats) tp = np.array(tp) conf = np.array(conf) pred_cls = np.array(pred_cls) target_cls = np.array(target_cls) p, r, ap, f1, _ = ap_per_class(tp, conf, pred_cls, target_cls) map50 = ap.mean() if len(ap) > 0 else 0.0 precision = p.mean() if len(p) > 0 else 0.0 recall = r.mean() if len® > 0 else 0.0 return map50, precision, recall, len(stats)
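
# --- Illustrative smoke test (not part of the original module) ---
# A minimal sketch with random predictions (hypothetical shapes, nc = 10, i.e.
# 5 + 10 = 15 channels per box) exercising xywh2xyxy + NMS, plus a tiny
# ap_per_class check on hand-made statistics.
if __name__ == "__main__":
    preds = torch.rand(1, 100, 15)
    preds[..., :4] *= 640  # xywh in pixels
    dets = non_max_suppression(preds, conf_thres=0.25, iou_thres=0.45)
    print(dets[0].shape)   # (num_dets, 6): x1, y1, x2, y2, conf, cls

    p, r, ap, f1, cls_ids = ap_per_class(
        tp=np.array([1.0, 0.0, 1.0]),
        conf=np.array([0.9, 0.8, 0.7]),
        pred_cls=np.array([0, 0, 1]),
        target_cls=np.array([0, 1, 1]),
    )
    print(p, r, ap)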
test.py
import torch
import cv2
import os
from glob import glob
from models.sod_yolo import SODYOLO
from utils.metrics import non_max_suppression
import numpy as np
# Color palette (BGR)
colors = [
(255, 0, 0), (0, 255, 0), (0, 0, 255),
(255, 255, 0), (255, 0, 255), (0, 255, 255),
(128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0)
]
def load_model(weights='weights/sod_yolo_exp1/best.pt', nc=10):
    model = SODYOLO(nc=nc)
    state_dict = torch.load(weights, map_location='cpu')['model']
    model.load_state_dict(state_dict)
    model.eval()
    return model
def infer_and_visualize(
    model,
    source='inference/images',
    output='inference/output',
    img_size=640,
    conf_thres=0.4,
    iou_thres=0.5
):
    os.makedirs(output, exist_ok=True)
    files = glob(os.path.join(source, '*.jpg')) + glob(os.path.join(source, '*.png'))
    names = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
             'tricycle', 'awning-tricycle', 'bus', 'motor']

    for img_path in files:
        # Read image
        img0 = cv2.imread(img_path)
        h, w = img0.shape[:2]
        img = cv2.resize(img0, (img_size, img_size))
        img = img.transpose(2, 0, 1)[::-1]  # HWC to CHW, BGR to RGB
        img = torch.from_numpy(np.ascontiguousarray(img)).float() / 255.0
        img = img.unsqueeze(0)  # (1, 3, 640, 640)

        # Inference
        with torch.no_grad():
            preds = model(img)

        # Merge the output layers into (1, num_boxes, no)
        out_cat = []
        for o in preds:
            bs, _, ny, nx = o.shape
            o = o.permute(0, 2, 3, 1).reshape(bs, ny * nx, -1)
            out_cat.append(o)
        out = torch.cat(out_cat, dim=1)
        dets = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres)[0]

        # Visualize
        if len(dets):
            dets[:, :4] = scale_boxes((img_size, img_size), dets[:, :4], (h, w))
            for *xyxy, conf, cls in dets:
                label = f"{names[int(cls)]} {conf:.2f}"
                c = int(cls)
                color = colors[c % len(colors)]
                plot_one_box(xyxy, img0, label=label, color=color, line_thickness=2)

        # Save
        out_path = os.path.join(output, os.path.basename(img_path))
        cv2.imwrite(out_path, img0)
        print(f"Saved: {out_path}")
def scale_boxes(img1_shape, boxes, img0_shape):
    """Rescale boxes from img1_shape (network input) to img0_shape (original image)."""
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
    pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2
    boxes[:, [0, 2]] -= pad[0]
    boxes[:, [1, 3]] -= pad[1]
    boxes[:, :4] /= gain
    boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(0, img0_shape[1])  # clip x to image width
    boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(0, img0_shape[0])  # clip y to image height
    return boxes
def plot_one_box(x, im, label=None, color=(128, 128, 128), line_thickness=3):
    """Draw one box on the image."""
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(im, c1, c2, color, thickness=line_thickness, lineType=cv2.LINE_AA)
    if label:
        tf = max(line_thickness - 1, 1)
        t_size = cv2.getTextSize(label, 0, fontScale=tf / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA)
        cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tf / 3, (255, 255, 255), thickness=tf, lineType=cv2.LINE_AA)
if __name__ == "__main__":
    model = load_model()
    infer_and_visualize(model)
train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import yaml
import os
from tqdm import tqdm
from models.sod_yolo import SODYOLO
from utils.datasets import LoadImagesAndLabels
from utils.loss import ComputeLoss
from utils.metrics import evaluate
# Configuration
class Args:
    img_size = 640
    batch_size = 4
    epochs = 50
    data = 'data/visdrone.yaml'
    cfg = 'models/sod_yolo.yaml'
    hyp = 'config/hyp.scratch-low.yaml'
    device = 'cpu'   # train on CPU
    workers = 0      # set to 0 on Windows to avoid multiprocessing issues
    name = 'sod_yolo_exp1'
args = Args()
os.makedirs(f"weights/{args.name}", exist_ok=True)
# Load hyperparameters
with open(args.hyp) as f:
hyp = yaml.load(f, Loader=yaml.SafeLoader)
# Build the model
model = SODYOLO(cfg=args.cfg, nc=10)  # VisDrone has 10 classes
model.hyp = hyp
model = model.to(args.device)
# Optimizer
pg = [p for p in model.parameters() if p.requires_grad]  # all trainable parameters
optimizer = optim.SGD(pg, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
# Data loader
dataset = LoadImagesAndLabels(
    root=os.path.join(args.data, 'images/train'),
    img_size=args.img_size,
    augment=True
)
dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    collate_fn=lambda x: x  # keep the batch as a list of samples; merged manually in the loop
)
# Loss function
compute_loss = ComputeLoss(model)
# Training loop
model.train()
best_map = 0.0
for epoch in range(args.epochs):
print(f"\nEpoch {epoch + 1}/{args.epochs}")
pbar = tqdm(dataloader)
total_loss = 0.0
    for i, batch in enumerate(pbar):
        imgs, targets, _, _ = zip(*batch)

        # Convert to batched tensors; prepend the image index so targets
        # become (batch_idx, cls_id, x, y, w, h) as ComputeLoss expects
        imgs = torch.stack(imgs).to(args.device)
        targets = torch.cat([
            torch.cat((torch.full((t.shape[0], 1), bi, dtype=t.dtype), t), dim=1)
            for bi, t in enumerate(targets)
        ]).to(args.device)

        # Forward pass
        preds = model(imgs)
        loss, loss_items = compute_loss(preds, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        avg_loss = total_loss / (i + 1)
        pbar.set_description(f"Loss: {avg_loss:.4f}")

    scheduler.step()

    # Validate every 5 epochs (and on the last epoch)
    if epoch % 5 == 0 or epoch == args.epochs - 1:
        map50, p, r, _ = evaluate(model, dataloader, iou_thres=0.5, conf_thres=0.001)
        print(f"[Val] mAP50: {map50:.4f}, Prec: {p:.4f}, Rec: {r:.4f}")
        if map50 > best_map:
            best_map = map50
            torch.save({
                'model': model.state_dict(),
                'epoch': epoch,
                'best_map': best_map
            }, f"weights/{args.name}/best.pt")
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'map': map50
        }, f"weights/{args.name}/last.pt")
        model.train()  # evaluate() switches to eval mode; restore training mode
print(“✅ Training completed.”)
val.py
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from models.sod_yolo import SODYOLO
from utils.datasets import LoadImagesAndLabels
from utils.metrics import evaluate
def main():
    device = 'cpu'
    weights_path = 'weights/sod_yolo_exp1/best.pt'
    data_dir = 'datasets/VisDrone2019-DET/images/val'
    assert os.path.exists(weights_path), f"Weights file not found: {weights_path}"

    # Load the model
    model = SODYOLO(nc=10)
    state_dict = torch.load(weights_path, map_location=device)['model']
    model.load_state_dict(state_dict)
    model.to(device).eval()

    # Dataset
    dataset = LoadImagesAndLabels(data_dir, img_size=640)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=lambda x: x)

    # Evaluation
    with torch.no_grad():
        map50, precision, recall, _ = evaluate(model, dataloader, iou_thres=0.5, conf_thres=0.001)

    print(f"\n📊 Final results:")
    print(f"mAP@0.5:   {map50:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
if __name__ == "__main__":
    main()