import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
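# Dependencies beyond the standard library (install if missing):
#   pip install torch torchvision albumentations datasets matplotlib numpy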
#----------------------------------------------------------------------------------------------------------------------#
# Model definition
# Basic building block: Conv2d -> BatchNorm -> SiLU
class Conv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, groups=1):
super().__init__()
if padding is None:
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU()  # YOLOv8 uses the SiLU activation
def forward(self, x):
return self.act(self.bn(self.conv(x)))
# Basic building block: C2f (split, stack convs, concatenate all intermediates)
class C2f(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks=2, shortcut=False):
        super().__init__()
        hidden_channels = int(out_channels * 0.5)
        self.conv1 = Conv(in_channels, hidden_channels * 2, 1)
        # conv2 fuses the two initial halves plus every block output
        self.conv2 = Conv(hidden_channels * (2 + num_blocks), out_channels, 1)
        self.blocks = nn.ModuleList(
            Conv(hidden_channels, hidden_channels, 3) for _ in range(num_blocks)
        )
        self.shortcut = shortcut  # kept for API parity; unused in this simplified block
def forward(self, x):
y = list(self.conv1(x).chunk(2, 1)) # Split into two halves
for block in self.blocks:
y.append(block(y[-1]))
return self.conv2(torch.cat(y, dim=1))
# Basic building block: SPPF (Spatial Pyramid Pooling - Fast)
class SPPF(nn.Module):
def __init__(self, in_channels, out_channels, k=5):
super().__init__()
hidden_channels = in_channels // 2
self.conv1 = Conv(in_channels, hidden_channels, 1)
self.pool = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
self.conv2 = Conv(hidden_channels * 4, out_channels, 1)
def forward(self, x):
x = self.conv1(x)
pool1 = self.pool(x)
pool2 = self.pool(pool1)
pool3 = self.pool(pool2)
return self.conv2(torch.cat([x, pool1, pool2, pool3], dim=1))
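# Note: three chained k=5 max-pools cover the same receptive fields as the
# original SPP's parallel 5/9/13 pooling, but run faster (hence the F in SPPF).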
# Backbone
class Backbone(nn.Module):
def __init__(self):
super().__init__()
self.stage1 = nn.Sequential(
Conv(3, 64, 3, 2),
Conv(64, 128, 3, 2),
C2f(128, 128, 3)
)
self.stage2 = nn.Sequential(
Conv(128, 256, 3, 2),
C2f(256, 256, 6)
)
self.stage3 = nn.Sequential(
Conv(256, 512, 3, 2),
C2f(512, 512, 6)
)
self.stage4 = nn.Sequential(
Conv(512, 1024, 3, 2),
C2f(1024, 1024, 3),
SPPF(1024, 1024)
)
    def forward(self, x):
        x = self.stage1(x)   # stride 4
        c2 = self.stage2(x)  # stride 8,  256 channels
        c3 = self.stage3(c2) # stride 16, 512 channels
        c4 = self.stage4(c3) # stride 32, 1024 channels
        return c2, c3, c4    # three scales for the neck
# Feature fusion neck (FPN top-down + PAN bottom-up)
class Neck(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = Conv(1024, 512, 1)
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.c2f1 = C2f(512 + 512, 512, 3)
self.conv2 = Conv(512, 256, 1)
self.c2f2 = C2f(256 + 256, 256, 3)
self.conv3 = Conv(256, 256, 3, 2)
self.c2f3 = C2f(256 + 256, 512, 3)
self.conv4 = Conv(512, 512, 3, 2)
self.c2f4 = C2f(512 + 512, 1024, 3)
    def forward(self, c2, c3, c4):
        # Top-down path: upsample deep features and fuse with shallower ones
        p5 = self.conv1(c4)                                               # 512, stride 32
        p4 = self.c2f1(torch.cat([self.upsample(p5), c3], dim=1))         # 512, stride 16
        p4_r = self.conv2(p4)                                             # 256, stride 16
        p3_out = self.c2f2(torch.cat([self.upsample(p4_r), c2], dim=1))   # 256, stride 8
        # Bottom-up path: downsample and fuse back in
        p4_out = self.c2f3(torch.cat([self.conv3(p3_out), p4_r], dim=1))  # 512, stride 16
        p5_out = self.c2f4(torch.cat([self.conv4(p4_out), p5], dim=1))    # 1024, stride 32
        return p3_out, p4_out, p5_out  # strides 8 / 16 / 32
# Decoupled detection head
class DecoupledHead(nn.Module):
def __init__(self, in_channels, num_classes=80):
super().__init__()
        # Separate 3x3 conv branches for classification and regression
self.cls_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
self.reg_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        # Prediction layers (1x1 convs)
self.cls_pred = nn.Conv2d(in_channels, num_classes, 1)
self.reg_pred = nn.Conv2d(in_channels, 4, 1) # tx, ty, tw, th
self.obj_pred = nn.Conv2d(in_channels, 1, 1) # objectness
self.act = nn.SiLU()
def forward(self, x):
c = self.act(self.cls_conv(x))
r = self.act(self.reg_conv(x))
cls = self.cls_pred(c)
reg = self.reg_pred(r)
obj = self.obj_pred(r)
return torch.cat([reg, obj, cls], dim=1)
# Multi-scale detection head
class Detect(nn.Module):
def __init__(self, num_classes=80):
super().__init__()
self.num_classes = num_classes
        self.strides = [8, 16, 32]  # strides of the three output scales
        # One decoupled head per scale
self.head_small = DecoupledHead(256, num_classes)
self.head_medium = DecoupledHead(512, num_classes)
self.head_large = DecoupledHead(1024, num_classes)
def forward(self, x):
p3, p4, p5 = x
pred_small = self.head_small(p3) # shape: (B, 4 + 1 + num_classes, H/8, W/8)
pred_medium = self.head_medium(p4) # (B, ..., H/16, W/16)
pred_large = self.head_large(p5) # (B, ..., H/32, W/32)
return [pred_small, pred_medium, pred_large]
# Full YOLOv8 model
class YOLOv8(nn.Module):
def __init__(self, num_classes=80):
super().__init__()
self.backbone = Backbone()
self.neck = Neck()
self.detect = Detect(num_classes)
    def forward(self, x):
        c2, c3, c4 = self.backbone(x)
        features = self.neck(c2, c3, c4)
        predictions = self.detect(features)
        return predictions
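# Quick shape sanity check (illustrative only; random 640x640 input, CPU is fine).
# Each head emits 4 box + 1 objectness + 20 class channels = 25.
with torch.no_grad():
    for _p in YOLOv8(num_classes=20)(torch.randn(1, 3, 640, 640)):
        print(_p.shape)  # expect (1, 25, 80, 80), (1, 25, 40, 40), (1, 25, 20, 20)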
#----------------------------------------------------------------------------------------------------------------------#
# Training
# Load PASCAL VOC 2012 from the Hugging Face Hub
ds = load_dataset("corypaik/pascal_voc")  # provides train/validation splits
voc_train = ds["train"]
voc_val = ds["validation"]
print(voc_train)
print(voc_train[0].keys())
# Dataset wrapper around the Hugging Face dataset
class HFDataset(Dataset):
def __init__(self, hf_dataset, img_size=640):
self.hf_dataset = hf_dataset
self.img_size = img_size
self.class_to_idx = {
'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3,
'bottle': 4, 'bus': 5, 'car': 6, 'cat': 7,
'chair': 8, 'cow': 9, 'diningtable': 10, 'dog': 11,
'horse': 12, 'motorbike': 13, 'person': 14, 'pottedplant': 15,
'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19
}
self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
        # Augmentation pipeline. Boxes are passed to Albumentations in
        # normalized YOLO format [cx, cy, w, h], which survives Resize/HorizontalFlip.
        self.transform = A.Compose([
            A.Resize(height=img_size, width=img_size),
            A.HorizontalFlip(p=0.5),
            A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ], bbox_params=A.BboxParams(
            format='yolo',  # normalized [cx, cy, w, h]
            label_fields=['class_labels'],
            min_visibility=0.1
        ))
def __len__(self):
return len(self.hf_dataset)
def __getitem__(self, idx):
item = self.hf_dataset[idx]
image = item["image"] # PIL Image
orig_width, orig_height = item["width"], item["height"]
img = np.array(image.convert("RGB"))
        # boxes: list of [x_min, y_min, w, h] (COCO format, absolute pixels)
        boxes_abs = torch.tensor(item["boxes"], dtype=torch.float32).reshape(-1, 4)  # (N, 4)
        labels = torch.tensor(item["labels"], dtype=torch.long)  # (N,)
        # Convert to normalized YOLO format [cx, cy, w, h] for Albumentations
        h, w = orig_height, orig_width
        boxes_yolo = torch.zeros_like(boxes_abs)
        boxes_yolo[:, 0] = (boxes_abs[:, 0] + boxes_abs[:, 2] / 2) / w  # cx
        boxes_yolo[:, 1] = (boxes_abs[:, 1] + boxes_abs[:, 3] / 2) / h  # cy
        boxes_yolo[:, 2] = boxes_abs[:, 2] / w
        boxes_yolo[:, 3] = boxes_abs[:, 3] / h
        # Drop degenerate boxes
        keep = (boxes_yolo[:, 2] > 0) & (boxes_yolo[:, 3] > 0)
        boxes_yolo = boxes_yolo[keep]
        labels = labels[keep]
        if len(boxes_yolo) == 0:
            boxes_yolo = torch.tensor([[0.5, 0.5, 0.1, 0.1]])  # dummy box
            labels = torch.tensor([0])
class_labels = labels.tolist()
        # Apply augmentations (resize + flip + color jitter + normalize)
        try:
            transformed = self.transform(
                image=img,
                bboxes=boxes_yolo.numpy(),
                class_labels=class_labels
            )
        except Exception as e:
            print(f"Augmentation error at index {idx}: {e}")
            # Fallback: skip augmentation but still resize + normalize + tensorize
            fallback = A.Compose([
                A.Resize(height=self.img_size, width=self.img_size),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
            img_tensor = fallback(image=img)["image"]
            boxes_tensor = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
            labels_tensor = torch.tensor([0])
        else:
            img_tensor = transformed["image"]
            boxes_tensor = torch.tensor(transformed["bboxes"], dtype=torch.float32).reshape(-1, 4)
            labels_tensor = torch.tensor(transformed["class_labels"], dtype=torch.long)
        # Convert augmented boxes from YOLO [cx, cy, w, h] back to xyxy
        # (still normalized) for compute_loss
        boxes_xyxy = torch.zeros_like(boxes_tensor)
        boxes_xyxy[:, 0] = boxes_tensor[:, 0] - boxes_tensor[:, 2] / 2  # x1
        boxes_xyxy[:, 1] = boxes_tensor[:, 1] - boxes_tensor[:, 3] / 2  # y1
        boxes_xyxy[:, 2] = boxes_tensor[:, 0] + boxes_tensor[:, 2] / 2  # x2
        boxes_xyxy[:, 3] = boxes_tensor[:, 1] + boxes_tensor[:, 3] / 2  # y2
        # Clip to [0, 1]
boxes_xyxy.clamp_(0, 1)
target = {
"boxes": boxes_xyxy,
"labels": labels_tensor,
"image_id": torch.tensor([idx])
}
return img_tensor, target
# Instantiate datasets
train_dataset = HFDataset(voc_train, img_size=640)
val_dataset = HFDataset(voc_val, img_size=640)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=8,
shuffle=True,
num_workers=4,
pin_memory=True,
    collate_fn=lambda x: tuple(zip(*x))  # keep per-image targets (variable box counts) un-stacked
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=8,
shuffle=False,
num_workers=4,
pin_memory=True,
collate_fn=lambda x: tuple(zip(*x))
)
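# Optional sanity check (commented out; assumes the dataset loaded correctly):
# imgs, tgts = next(iter(train_loader))
# print(len(imgs), imgs[0].shape, tgts[0]["boxes"].shape, tgts[0]["labels"])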
# Loss function
def compute_loss(outputs, targets, strides=[8, 16, 32], num_classes=20, img_size=640):
"""
Compute YOLOv8-style loss with simple positive sample assignment.
outputs: List[Tensor] -> [P3, P4, P5], each shape (B, C, H, W)
targets: List[dict] -> [{'boxes': (N,4), 'labels': (N,)}]
"""
device = outputs[0].device
criterion_cls = nn.BCEWithLogitsLoss(reduction='none')
criterion_obj = nn.BCEWithLogitsLoss(reduction='none')
total_loss_cls = torch.zeros(1, device=device)
total_loss_obj = torch.zeros(1, device=device)
total_loss_reg = torch.zeros(1, device=device)
num_positive = 0
    for i, pred in enumerate(outputs):
        bs, _, H, W = pred.shape  # 80/40/20 for a 640 input at strides 8/16/32
        stride = strides[i]
# Reshape & split predictions: [B, C, H, W] -> [B, H*W, *]
pred = pred.permute(0, 2, 3, 1).reshape(bs, -1, 4 + 1 + num_classes)
reg_pred = pred[..., :4] # tx, ty, tw, th
obj_pred = pred[..., 4] # objectness (flattened)
cls_pred = pred[..., 5:] # class logits
# Generate grid centers (center of each grid cell in original image space)
        yv, xv = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
grid_xy = torch.stack((xv, yv), dim=2).float().to(device) # (H, W, 2)
grid_xy = grid_xy.reshape(-1, 2) # (H*W, 2)
grid_xy = grid_xy.unsqueeze(0).expand(bs, -1, -1) # (B, H*W, 2)
anchor_points = (grid_xy + 0.5) * stride # (cx, cy) on original image
# Decode predicted boxes (in xywh format)
pred_xy = anchor_points + reg_pred[..., :2].sigmoid() * stride - 0.5 * stride
pred_wh = torch.exp(reg_pred[..., 2:]) * stride
pred_boxes = torch.cat([pred_xy, pred_wh], dim=-1) # (B, H*W, 4)
# Prepare targets
obj_target = torch.zeros_like(obj_pred)
cls_target = torch.zeros_like(cls_pred)
reg_target = torch.zeros_like(reg_pred)
fg_mask = torch.zeros_like(obj_pred, dtype=torch.bool)
        for b in range(bs):
            # Targets are normalized xyxy; scale to pixels and move to device
            tbox = targets[b]['boxes'].to(device) * img_size  # (N, 4), xyxy
            tlabel = targets[b]['labels'].to(device)          # (N,)
            if len(tbox) == 0:
                continue
            # Convert from xyxy to xywh (center format)
            tbox_xywh = torch.cat([
                (tbox[:, :2] + tbox[:, 2:]) / 2,
                tbox[:, 2:] - tbox[:, :2]
            ], dim=1)  # (N, 4)
            # Match: distance between gt centers and anchor points
            gt_centers = tbox_xywh[:, :2]  # (N, 2)
            distances = (anchor_points[b].unsqueeze(1) - gt_centers.unsqueeze(0)).pow(2).sum(dim=-1)  # (H*W, N)
            _, closest_grid_idx = distances.min(dim=0)  # each gt -> nearest grid
            _, closest_gt_idx = distances.min(dim=1)    # each grid -> nearest gt
            # Positive samples: mutual nearest pairs (keeps the match one-to-one)
            pos_grid, pos_gt = [], []
            for gt_i in range(len(tbox)):
                grid_i = closest_grid_idx[gt_i].item()
                if closest_gt_idx[grid_i] == gt_i:  # mutual match
                    pos_grid.append(grid_i)
                    pos_gt.append(gt_i)
            if not pos_grid:
                continue
            pos_grid = torch.tensor(pos_grid, device=device, dtype=torch.long)
            pos_gt = torch.tensor(pos_gt, device=device, dtype=torch.long)
            fg_mask[b, pos_grid] = True
            obj_target[b, pos_grid] = 1.0
            cls_target[b, pos_grid] = nn.functional.one_hot(tlabel[pos_gt], num_classes=num_classes).float()
            # Regression targets for the matched positives
            matched_gt = tbox_xywh[pos_gt]
            reg_target[b, pos_grid] = torch.cat([
                (matched_gt[:, :2] - anchor_points[b, pos_grid]) / stride,
                torch.log(matched_gt[:, 2:] / stride + 1e-8)
            ], dim=1)
        # Objectness loss over all cells; cls/reg losses only on positives
        total_loss_obj += criterion_obj(obj_pred, obj_target).mean()
        pos_obj = obj_target == 1
        if pos_obj.any():
            total_loss_cls += criterion_cls(cls_pred[pos_obj], cls_target[pos_obj]).sum()
            # Compare in target space: the decoded xy offset is sigmoid(txy) - 0.5
            reg_decoded = torch.cat([
                reg_pred[pos_obj][:, :2].sigmoid() - 0.5,
                reg_pred[pos_obj][:, 2:]
            ], dim=1)
            total_loss_reg += torch.abs(reg_decoded - reg_target[pos_obj]).sum()
            num_positive += pos_obj.sum().item()
    # Normalize: cls/reg by the number of positives, obj by the number of scales
    if num_positive > 0:
        total_loss_reg /= num_positive
        total_loss_cls /= num_positive
    total_loss_obj /= len(outputs)
total_loss = total_loss_obj + total_loss_cls + total_loss_reg * 5.0 # weight reg more
return total_loss, total_loss_obj.detach(), total_loss_cls.detach(), total_loss_reg.detach()
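# Encoding recap: a positive cell with anchor point (ax, ay) and stride s regresses
#   tx, ty -> (gt_cx - ax) / s, (gt_cy - ay) / s   (bounded in [-0.5, 0.5])
#   tw, th -> log(gt_w / s), log(gt_h / s)
# which mirrors the decode cx = ax + (sigmoid(tx) - 0.5) * s, w = exp(tw) * s.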
# Train the model
model = YOLOv8(num_classes=20).train().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(10):
print(f"\nEpoch {epoch + 1}/10")
for i, (images, targets) in enumerate(train_loader):
        # The collate_fn returns a tuple of per-image tensors; stack into a batch
        images = torch.stack(images).cuda(non_blocking=True)
        # Forward pass
        outputs = model(images)  # List[Tensor]: (B, 25, H/8, W/8), ...
        # Compute loss (targets stay on CPU; compute_loss moves them to the GPU)
        loss, loss_obj, loss_cls, loss_reg = compute_loss(outputs, targets, num_classes=20)
        # Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 10 == 0:
print(f"Iter {i}, Loss: {loss.item():.4f} "
f"(obj={loss_obj.item():.4f}, cls={loss_cls.item():.4f}, reg={loss_reg.item():.4f})")
#----------------------------------------------------------------------------------------------------------------------#
# Visualization
# Decode raw YOLO outputs
def decode_outputs(outputs, strides=[8, 16, 32], img_size=640, num_classes=20):
    """
    Decode YOLO outputs into normalized [cx, cy, w, h] boxes (coords in [0, 1]).
    Returns: (pred_boxes, pred_scores, pred_labels), each a tensor.
    """
device = outputs[0].device
pred_boxes = []
pred_scores = []
pred_labels = []
for i, pred in enumerate(outputs):
bs, _, ny, nx = pred.shape
stride = strides[i]
        pred = pred.permute(0, 2, 3, 1).reshape(bs, -1, 4 + 1 + num_classes)  # (B, H*W, 25)
reg_pred = pred[..., :4] # tx, ty, tw, th
obj_pred = pred[..., 4:5].sigmoid() # objectness
cls_pred = pred[..., 5:].sigmoid() # class scores
        # Grid cell coordinates
        yv, xv = torch.meshgrid(torch.arange(ny), torch.arange(nx), indexing='ij')
        grid_xy = torch.stack((xv, yv), dim=2).float().to(device)  # (ny, nx, 2)
        grid_xy = grid_xy.reshape(-1, 2).unsqueeze(0).expand(bs, -1, -1)  # (B, H*W, 2)
        # Decode centers; must mirror the training encoding:
        # cx = (grid + 0.5 + (sigmoid(tx) - 0.5)) * stride = (grid + sigmoid(tx)) * stride
        pred_xy = (grid_xy + reg_pred[..., :2].sigmoid()) * stride
        # Decode widths/heights
        pred_wh = torch.exp(reg_pred[..., 2:]) * stride
        boxes = torch.cat([pred_xy, pred_wh], dim=-1)  # (B, H*W, 4) in [cx, cy, w, h]
        # Confidence = objectness x best class probability
        scores, labels = torch.max(cls_pred, dim=-1)
        scores = scores * obj_pred.squeeze(-1)
        # Normalize box coordinates to [0, 1]
        boxes_normalized = boxes / img_size
pred_boxes.append(boxes_normalized)
pred_scores.append(scores)
pred_labels.append(labels)
    # Concatenate all scales
pred_boxes = torch.cat(pred_boxes, dim=1) # (B, L, 4)
pred_scores = torch.cat(pred_scores, dim=1) # (B, L)
pred_labels = torch.cat(pred_labels, dim=1) # (B, L)
return pred_boxes, pred_scores, pred_labels
# NMS post-processing
def postprocess(pred_boxes, pred_scores, pred_labels, conf_thresh=0.25, nms_thresh=0.5):
"""
对每张图片进行 NMS 后处理
"""
final_boxes, final_scores, final_labels = [], [], []
    for i in range(pred_boxes.shape[0]):  # each image in the batch
boxes = pred_boxes[i] # (L, 4)
scores = pred_scores[i] # (L,)
labels = pred_labels[i] # (L,)
        # Filter out low-confidence predictions
mask = scores > conf_thresh
boxes = boxes[mask]
scores = scores[mask]
labels = labels[mask]
        if len(boxes) == 0:
            final_boxes.append(torch.empty(0, 4))
            final_scores.append(torch.empty(0))
            final_labels.append(torch.empty(0, dtype=torch.long))
            continue
        # Convert to xyxy and run NMS; keep xyxy so the drawing code can use it directly
        boxes_xyxy = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
        keep = torchvision.ops.nms(boxes_xyxy, scores, nms_thresh)
        final_boxes.append(boxes_xyxy[keep])
        final_scores.append(scores[keep])
        final_labels.append(labels[keep])
return final_boxes, final_scores, final_labels
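# Pipeline recap: model(images) -> decode_outputs -> postprocess yields, per image,
# boxes in normalized xyxy plus scores and labels, ready for drawing or metrics.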
# Visualization helper
def visualize_predictions(model, dataloader, idx_to_class, device="cuda", num_images=4):
model.eval()
inv_normalize = transforms.Compose([
transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]),
transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
])
    figure, ax = plt.subplots(num_images, 2, figsize=(12, 6 * num_images))
    if num_images == 1:
        ax = ax[None, :]  # ax is a numpy array; keep a (rows, cols) layout
with torch.no_grad():
        for i, (images, targets) in enumerate(dataloader):
            images = torch.stack(images).to(device)  # collate yields a tuple of tensors
outputs = model(images)
pred_boxes, pred_scores, pred_labels = decode_outputs(outputs)
det_boxes, det_scores, det_labels = postprocess(pred_boxes, pred_scores, pred_labels)
for j in range(min(num_images, len(images))):
                # Un-normalize image j for display
img = images[j].cpu()
img = inv_normalize(img)
img = torch.clamp(img, 0, 1)
img = transforms.ToPILImage()(img)
                # Plot GT and predictions side by side
ax[j, 0].imshow(img); ax[j, 0].set_title("Ground Truth")
ax[j, 1].imshow(img); ax[j, 1].set_title("Predictions")
                # Draw ground-truth boxes (normalized xyxy, scaled to pixels)
gt_boxes = targets[j]['boxes'].cpu()
gt_labels = targets[j]['labels'].cpu()
for k in range(len(gt_boxes)):
box = gt_boxes[k].numpy()
label = idx_to_class[gt_labels[k].item()]
rect = plt.Rectangle(
(box[0]*640, box[1]*640), (box[2]-box[0])*640, (box[3]-box[1])*640,
fill=False, edgecolor='green', linewidth=2
)
ax[j, 0].add_patch(rect)
ax[j, 0].text(box[0]*640, box[1]*640, label, color='white', fontsize=10,
bbox=dict(facecolor='green', alpha=0.7))
                # Draw predicted boxes
                pred_box_img = det_boxes[j].cpu() * 640  # normalized xyxy -> pixel coords
pred_score_img = det_scores[j].cpu()
pred_label_img = det_labels[j].cpu()
for k in range(len(pred_box_img)):
box = pred_box_img[k].numpy()
score = pred_score_img[k].item()
label = idx_to_class[pred_label_img[k].item()]
x1, y1 = int(box[0]), int(box[1])
w, h = int(box[2] - box[0]), int(box[3] - box[1])
rect = plt.Rectangle(
(x1, y1), w, h, fill=False, edgecolor='red', linewidth=2
)
ax[j, 1].add_patch(rect)
ax[j, 1].text(x1, y1, f"{label}:{score:.2f}", color='white',
fontsize=10, bbox=dict(facecolor='red', alpha=0.7))
ax[j, 0].axis('off'); ax[j, 1].axis('off')
            break  # visualize only the first batch
plt.tight_layout()
plt.show()
    model.train()  # switch back to training mode
# Reverse lookup: class index -> class name
idx_to_class = {v: k for k, v in train_dataset.class_to_idx.items()}
# Inspect predictions after training (or after any epoch)
visualize_predictions(model, val_loader, idx_to_class, device="cuda", num_images=4)