import torch
import torch.nn as nn
from torchvision.datasets import VOCDetection
from torchvision import transforms
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# Basic building block
class Conv(nn.Module):
    """Convolution followed by batch normalisation and a SiLU activation.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size handed to ``nn.Conv2d``.
        stride: Stride handed to ``nn.Conv2d``.
        padding: Explicit padding. When ``None`` it is derived as
            ``(kernel_size - 1) // 2``.
        groups: Number of convolution groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, groups=1):
        super().__init__()
        # Only derive the padding when the caller did not supply one.
        pad = padding if padding is not None else (kernel_size - 1) // 2
        # The convolution has no bias; BatchNorm directly follows it.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU()  # YOLOv8 uses the SiLU activation

    def forward(self, x):
        """Apply conv -> batch norm -> activation to ``x``."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.act(normalised)
# Basic building block
class C2f(nn.Module):
    """Split / refine / concatenate feature block.

    ``conv1`` expands the input to ``2 * hidden`` channels, which are split
    into two halves. Each of the ``num_blocks`` 3x3 ``Conv`` layers is applied
    to the most recent tensor in the chain, and every intermediate tensor is
    kept. All tensors are concatenated on the channel axis and fused by the
    1x1 ``conv2``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        num_blocks: Number of chained 3x3 ``Conv`` blocks.
        shortcut: When True, each block adds its input to its output
            (residual connection). Previously this flag was stored but had no
            effect; the default ``False`` keeps the plain chain.
    """

    def __init__(self, in_channels, out_channels, num_blocks=2, shortcut=False):
        super().__init__()
        self.out_channels = out_channels
        hidden_channels = int(out_channels * 0.5)
        self.conv1 = Conv(in_channels, hidden_channels * 2, 1)
        # Two halves from the split plus one tensor per block are concatenated.
        self.conv2 = Conv(hidden_channels * (2 + num_blocks), out_channels, 1)
        self.blocks = nn.ModuleList(
            Conv(hidden_channels, hidden_channels, 3) for _ in range(num_blocks)
        )
        self.shortcut = shortcut

    def forward(self, x):
        """Return the fused feature map with ``out_channels`` channels."""
        y = list(self.conv1(x).chunk(2, 1))  # split into two halves
        for block in self.blocks:
            previous = y[-1]
            refined = block(previous)
            if self.shortcut:
                # Block input and output have identical shapes (3x3 conv with
                # derived padding, stride 1), so they can be summed directly.
                refined = refined + previous
            y.append(refined)
        return self.conv2(torch.cat(y, dim=1))
# Basic building block
class SPPF(nn.Module):
    """Spatial pyramid pooling built from three chained max-pools.

    The input is reduced to ``in_channels // 2`` channels, pooled three times
    in sequence with the same stride-1 max-pool, and the reduced map plus the
    three pooled maps are concatenated and projected to ``out_channels``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        k: Kernel size of the max-pool (padding is ``k // 2``).
    """

    def __init__(self, in_channels, out_channels, k=5):
        super().__init__()
        reduced = in_channels // 2
        self.conv1 = Conv(in_channels, reduced, 1)
        self.pool = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
        # Four tensors (reduced input + three pooled maps) are concatenated.
        self.conv2 = Conv(reduced * 4, out_channels, 1)

    def forward(self, x):
        """Return the pooled-and-fused feature map."""
        pyramid = [self.conv1(x)]
        for _ in range(3):
            # Each level pools the previous level's output.
            pyramid.append(self.pool(pyramid[-1]))
        return self.conv2(torch.cat(pyramid, dim=1))
# Backbone network
class Backbone(nn.Module):
    """Four-stage convolutional feature extractor.

    Every stage starts with stride-2 ``Conv`` layers (two in ``stage1``, one
    in each later stage), so the outputs of stages 1-4 are downsampled by 4,
    8, 16 and 32 relative to the input. ``forward`` returns the last two.
    """

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(
            Conv(3, 64, 3, 2),
            *self._down_stage(64, 128, 3),
        )
        self.stage2 = nn.Sequential(*self._down_stage(128, 256, 6))
        self.stage3 = nn.Sequential(*self._down_stage(256, 512, 6))
        self.stage4 = nn.Sequential(
            *self._down_stage(512, 1024, 3),
            SPPF(1024, 1024),
        )

    @staticmethod
    def _down_stage(c_in, c_out, depth):
        """Return ``[stride-2 Conv, C2f]`` mapping ``c_in`` to ``c_out`` channels."""
        return [Conv(c_in, c_out, 3, 2), C2f(c_out, c_out, depth)]

    def forward(self, x):
        """Return ``(c3, c4)``: the stage-3 (512-channel) and stage-4
        (1024-channel) feature maps consumed by the neck."""
        shallow = self.stage2(self.stage1(x))
        c3 = self.stage3(shallow)
        return c3, self.stage4(c3)  # two scales for the neck
# Feature fusion
class Neck(nn.Module):
    """Top-down / bottom-up feature fusion producing three detection scales.

    The layer definitions are unchanged; only ``forward`` was rewired. The
    previous wiring could not run: ``conv3`` (256 input channels) was fed a
    512-channel tensor, ``c2f3`` and ``c2f4`` received concatenations whose
    channel counts did not match their definitions, ``c2f2`` was never used,
    and the "high resolution" output had the same resolution as ``c3``.

    Inputs are the backbone maps ``c3`` (512 channels) and ``c4`` (1024
    channels at half the resolution of ``c3``). Outputs:

    * ``p3_out`` - 256 channels at twice the resolution of ``c3``
    * ``p4_out`` - 512 channels at the resolution of ``c3``
    * ``p5_out`` - 1024 channels at the resolution of ``c4``

    ``forward`` only receives ``c3`` and ``c4``, so there is no lateral
    feature at the finest scale; ``p3_out`` is derived purely from the
    upsampled fused ``c3``-resolution map.

    NOTE: the upsample/concatenate steps require the height and width of
    ``c3`` to be exactly twice those of ``c4`` (true when the network input
    size is a multiple of 32).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv(1024, 512, 1)       # reduce c4 to 512 channels
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.c2f1 = C2f(512 + 512, 512, 3)    # fuse upsampled c4 with c3
        self.conv2 = Conv(512, 256, 1)        # 256-channel lateral for the bottom-up path
        self.c2f2 = C2f(256 + 256, 256, 3)    # 512 -> 256 channels at the finest scale
        self.conv3 = Conv(256, 256, 3, 2)     # downsample the finest scale
        self.c2f3 = C2f(256 + 256, 512, 3)    # fuse with the 256-channel lateral
        self.conv4 = Conv(512, 512, 3, 2)     # downsample to c4 resolution
        self.c2f4 = C2f(512 + 512, 1024, 3)   # fuse with the reduced c4

    def forward(self, c3, c4):
        """Fuse ``c3`` and ``c4`` and return ``(p3_out, p4_out, p5_out)``."""
        # Top-down: reduce c4, upsample and fuse with c3.
        p5 = self.conv1(c4)                                        # 512 ch, c4 resolution
        fused4 = self.c2f1(torch.cat([self.upsample(p5), c3], dim=1))  # 512 ch, c3 resolution
        lateral4 = self.conv2(fused4)                              # 256 ch, c3 resolution

        # Finest scale: upsample the fused map (512 ch, matching c2f2's input).
        p3_out = self.c2f2(self.upsample(fused4))                  # 256 ch, 2x c3 resolution

        # Bottom-up: downsample and fuse with the laterals kept from above.
        p3_down = self.conv3(p3_out)                               # 256 ch, c3 resolution
        p4_out = self.c2f3(torch.cat([p3_down, lateral4], dim=1))  # 512 ch, c3 resolution

        # Deepest output: the reduced c4 (512 ch) is used, not raw c4 (1024 ch),
        # so the concatenation matches c2f4's 1024 input channels.
        p4_down = self.conv4(p4_out)                               # 512 ch, c4 resolution
        p5_out = self.c2f4(torch.cat([p4_down, p5], dim=1))        # 1024 ch, c4 resolution
        return p3_out, p4_out, p5_out
# Decoupled detection head
class DecoupledHead(nn.Module):
    """Detection head with separate classification and regression branches.

    Each branch is a 3x3 convolution followed by SiLU. The classification
    branch feeds the class predictor; the regression branch feeds both the
    box predictor and the objectness predictor.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of class logits to predict.
    """

    def __init__(self, in_channels, num_classes=80):
        super().__init__()
        # Separate 3x3 convolution branches
        self.cls_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.reg_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        # Prediction layers (1x1 convolutions)
        self.cls_pred = nn.Conv2d(in_channels, num_classes, 1)
        self.reg_pred = nn.Conv2d(in_channels, 4, 1)  # tx, ty, tw, th
        self.obj_pred = nn.Conv2d(in_channels, 1, 1)  # objectness
        self.act = nn.SiLU()

    def forward(self, x):
        """Return raw predictions with channels ordered
        ``[4 box values, 1 objectness, num_classes class logits]``."""
        reg_features = self.act(self.reg_conv(x))
        cls_features = self.act(self.cls_conv(x))
        box = self.reg_pred(reg_features)
        objectness = self.obj_pred(reg_features)
        class_logits = self.cls_pred(cls_features)
        return torch.cat((box, objectness, class_logits), dim=1)
# Multi-scale detection head
class Detect(nn.Module):
    """Apply one ``DecoupledHead`` per feature scale.

    The heads expect 256-, 512- and 1024-channel inputs respectively.

    Args:
        num_classes: Number of class logits predicted by every head.
    """

    def __init__(self, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.strides = [8, 16, 32]  # strides associated with the three output scales
        # One decoupled head per scale
        self.head_small = DecoupledHead(256, num_classes)
        self.head_medium = DecoupledHead(512, num_classes)
        self.head_large = DecoupledHead(1024, num_classes)

    def forward(self, x):
        """Return ``[pred_small, pred_medium, pred_large]``.

        ``x`` must unpack into exactly three feature maps; each prediction
        has ``4 + 1 + num_classes`` channels.
        """
        p3, p4, p5 = x
        heads = (self.head_small, self.head_medium, self.head_large)
        return [head(feature) for head, feature in zip(heads, (p3, p4, p5))]
# Complete YOLOv8-style model
class YOLOv8(nn.Module):
    """Backbone -> neck -> multi-scale detection head.

    Args:
        num_classes: Number of object classes predicted by the detection head.
    """

    def __init__(self, num_classes=80):
        super().__init__()
        self.backbone = Backbone()
        self.neck = Neck()
        self.detect = Detect(num_classes)

    def forward(self, x):
        """Return the list of per-scale raw prediction tensors for ``x``."""
        backbone_maps = self.backbone(x)       # (c3, c4)
        fused_maps = self.neck(*backbone_maps)
        return self.detect(fused_maps)
#----------------------------------------------------------------------------------------------------------------------#
# Download and extract PASCAL VOC 2012 (train split first, then val)
dataset_dir = "./VOCdevkit"
voc_train, voc_val = (
    VOCDetection(
        root=dataset_dir,
        year='2012',
        image_set=split,
        download=True,  # attempts a download when the data is not present
    )
    for split in ('train', 'val')
)

# Image transform: resize to the network input size, convert to a [0, 1]
# tensor, then apply ImageNet mean/std normalisation.
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Dataset definition
class VOCDataset(Dataset):
    """Wrap a sequence of ``(PIL image, VOC annotation dict)`` pairs for detection.

    Each item is ``(image_tensor, target)`` where ``image_tensor`` is the
    resized, normalised image and ``target`` is a dict with

    * ``"boxes"``    - float32 tensor of shape ``(N, 4)``, ``xmin, ymin, xmax, ymax``
      in pixels of the resized image,
    * ``"labels"``   - int64 tensor of shape ``(N,)``,
    * ``"image_id"`` - tensor holding the sample index.

    Images without any valid object yield empty ``(0, 4)`` / ``(0,)`` tensors
    rather than a fabricated box, so no false ground truth reaches the loss.

    Args:
        data_list: Indexable collection of ``(image, annotation)`` pairs, where
            ``annotation["annotation"]["object"]`` lists the objects.
        img_size: Side length of the square output image.
        augment: When True (default, the previous behaviour) random horizontal
            flip and colour jitter are applied; pass False for evaluation.
    """

    VOC_CLASSES = (
        'aeroplane', 'bicycle', 'bird', 'boat',
        'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog',
        'horse', 'motorbike', 'person', 'pottedplant',
        'sheep', 'sofa', 'train', 'tvmonitor',
    )

    def __init__(self, data_list, img_size=640, augment=True):
        self.data_list = data_list
        self.img_size = img_size
        self.augment = augment
        self.class_to_idx = {name: i for i, name in enumerate(self.VOC_CLASSES)}

        random_steps = []
        if augment:
            random_steps = [
                A.HorizontalFlip(p=0.5),
                A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.5),
            ]
        # Main pipeline: resize, optional augmentation, normalise, to tensor.
        self.transform = A.Compose([
            A.Resize(height=img_size, width=img_size),
            *random_steps,
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),  # converts to a [C, H, W] tensor
        ], bbox_params=A.BboxParams(
            format='pascal_voc',
            label_fields=['class_labels'],
            min_visibility=0.1
        ))
        # Image-only pipeline used when the main pipeline rejects a sample.
        # It still resizes, so every item has the same image size.
        self.fallback_transform = A.Compose([
            A.Resize(height=img_size, width=img_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.data_list)

    def _parse_annotation(self, ann, width, height):
        """Return ``(boxes, labels)`` lists for the known classes in ``ann``.

        Boxes are clipped to ``[0, width] x [0, height]``; boxes that are empty
        after clipping and objects of unknown classes are dropped.
        """
        objects = ann['annotation'].get('object', [])
        if isinstance(objects, dict):
            # Tolerate a single object stored as a dict instead of a list.
            objects = [objects]

        boxes = []
        labels = []
        for obj in objects:
            label = self.class_to_idx.get(obj['name'])
            if label is None:
                continue
            bbox = obj['bndbox']
            xmin = min(max(float(bbox['xmin']), 0.0), float(width))
            ymin = min(max(float(bbox['ymin']), 0.0), float(height))
            xmax = min(max(float(bbox['xmax']), 0.0), float(width))
            ymax = min(max(float(bbox['ymax']), 0.0), float(height))
            # Keep only boxes with positive area.
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(label)
        return boxes, labels

    def __getitem__(self, idx):
        image, ann = self.data_list[idx]
        img = np.array(image.convert("RGB"))  # H x W x 3 numpy array
        height, width = img.shape[:2]
        boxes, labels = self._parse_annotation(ann, width, height)

        try:
            transformed = self.transform(image=img, bboxes=boxes, class_labels=labels)
            img_tensor = transformed["image"]
            out_boxes = transformed["bboxes"]
            out_labels = transformed["class_labels"]
        except Exception as e:
            # Deliberate best-effort: the exception types raised by the
            # augmentation library are not pinned down here, so any failure
            # falls back to a plain resize with manually rescaled boxes.
            print(f"Augmentation error at index {idx}: {e}")
            img_tensor = self.fallback_transform(image=img)["image"]
            scale_x = self.img_size / width
            scale_y = self.img_size / height
            out_boxes = [
                [x0 * scale_x, y0 * scale_y, x1 * scale_x, y1 * scale_y]
                for x0, y0, x1, y1 in boxes
            ]
            out_labels = labels

        # reshape keeps a well-formed (0, 4) / (0,) shape when nothing is left.
        boxes_tensor = torch.as_tensor(
            np.asarray(out_boxes, dtype=np.float32).reshape(-1, 4)
        )
        labels_tensor = torch.as_tensor(
            np.asarray(out_labels, dtype=np.int64).reshape(-1)
        )
        target = {
            "boxes": boxes_tensor,
            "labels": labels_tensor,
            "image_id": torch.tensor([idx]),
        }
        return img_tensor, target
def detection_collate(batch):
    """Collate ``(image, target)`` samples into ``(stacked images, list of targets)``.

    The number of boxes differs between images, so the per-image target
    dicts cannot be stacked by the default collate function. They are kept
    as a list with one dict per image; only the images are stacked.
    """
    images, targets = zip(*batch)
    return torch.stack(images, dim=0), list(targets)


# Create dataset instances
train_dataset = VOCDataset(voc_train, img_size=640)
val_dataset = VOCDataset(voc_val, img_size=640)

# Batch loading with DataLoader. collate_fn is a module-level function so it
# can be pickled for the worker processes.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=detection_collate,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=detection_collate,
)
# Loss function
def compute_loss(outputs, targets, strides=(8, 16, 32), num_classes=20, alpha=0.5, gamma=0.5):
    """Compute a YOLO-style detection loss with nearest-cell positive assignment.

    For every scale, each ground-truth box is assigned to the grid cell whose
    centre is closest to the box centre, provided that box is also the closest
    box to that cell (mutual match). Matched cells are positives.

    Prediction channels are ``[tx, ty, tw, th, objectness, class logits]`` and
    are interpreted with this decode convention::

        centre = anchor + (sigmoid(tx, ty) - 0.5) * stride
        size   = exp(tw, th) * stride

    where ``anchor`` is the cell centre in input-image pixels. The regression
    loss is an L1 loss in that encoded space.

    Args:
        outputs: List of raw prediction tensors, one per scale, each of shape
            ``(B, 4 + 1 + num_classes, H, W)``. The grid size is read from the
            tensor, so any input resolution works.
        targets: Sequence with one dict per image containing ``"boxes"``
            (``(N, 4)``, xyxy in input-image pixels) and ``"labels"``
            (``(N,)``). Tensors may live on any device.
        strides: Stride of each scale, in the same order as ``outputs``.
        num_classes: Number of class logits per cell.
        alpha: Unused; kept for backward compatibility of the signature.
        gamma: Unused; kept for backward compatibility of the signature.

    Returns:
        ``(total, obj, cls, reg)`` - four tensors of shape ``(1,)``. ``total``
        carries gradients (``obj + cls + 5 * reg``); the other three are
        detached. ``obj`` is the mean objectness BCE over all cells, averaged
        over scales; ``cls`` and ``reg`` are per-element means over all
        positive cells of all scales (zero when there are no positives).

    Raises:
        ValueError: If ``outputs`` and ``strides`` differ in length or a
            prediction tensor does not have ``4 + 1 + num_classes`` channels.
    """
    if len(outputs) != len(strides):
        raise ValueError(
            f"expected one stride per output, got {len(outputs)} outputs and {len(strides)} strides"
        )

    device = outputs[0].device
    bce = nn.BCEWithLogitsLoss(reduction='none')
    channels = 4 + 1 + num_classes

    obj_sum = torch.zeros(1, device=device)
    cls_sum = torch.zeros(1, device=device)
    reg_sum = torch.zeros(1, device=device)
    num_positive = 0

    for pred, stride in zip(outputs, strides):
        bs, pred_channels, H, W = pred.shape
        if pred_channels != channels:
            raise ValueError(
                f"prediction has {pred_channels} channels, expected {channels} (4 + 1 + num_classes)"
            )

        # [B, C, H, W] -> [B, H*W, C]; cell index is y * W + x.
        pred = pred.float().permute(0, 2, 3, 1).reshape(bs, H * W, channels)
        reg_pred = pred[..., :4]   # tx, ty, tw, th
        obj_pred = pred[..., 4]    # objectness logit
        cls_pred = pred[..., 5:]   # class logits

        # Cell centres in input-image pixels, in the same row-major order.
        xs = torch.arange(W, device=device, dtype=pred.dtype).repeat(H)
        ys = torch.arange(H, device=device, dtype=pred.dtype).repeat_interleave(W)
        anchor_points = (torch.stack((xs, ys), dim=1) + 0.5) * stride  # (H*W, 2)

        obj_target = torch.zeros_like(obj_pred)
        cls_target = torch.zeros_like(cls_pred)
        reg_target = torch.zeros_like(reg_pred)
        fg_mask = torch.zeros_like(obj_pred, dtype=torch.bool)

        for b in range(bs):
            tbox = targets[b]['boxes'].to(device=device, dtype=pred.dtype).reshape(-1, 4)
            tlabel = targets[b]['labels'].to(device=device).long().reshape(-1)
            if tbox.shape[0] == 0:
                continue

            # xyxy -> centre / size
            centers = (tbox[:, :2] + tbox[:, 2:]) / 2
            sizes = (tbox[:, 2:] - tbox[:, :2]).clamp(min=1e-6)  # keeps log() finite

            # Squared distance between every cell centre and every box centre.
            distances = (anchor_points.unsqueeze(1) - centers.unsqueeze(0)).pow(2).sum(dim=-1)  # (H*W, N)
            nearest_cell = distances.argmin(dim=0)  # each box  -> nearest cell
            nearest_gt = distances.argmin(dim=1)    # each cell -> nearest box

            # Keep mutual matches only; this also guarantees each positive
            # cell is paired with exactly one box.
            gt_index = torch.arange(tbox.shape[0], device=device)
            mutual = nearest_gt[nearest_cell] == gt_index
            cells = nearest_cell[mutual]
            gts = gt_index[mutual]
            if cells.numel() == 0:
                continue

            fg_mask[b, cells] = True
            obj_target[b, cells] = 1.0
            # Index labels and boxes by the matched box id so every positive
            # cell receives the target of the box it was matched to.
            cls_target[b, cells] = nn.functional.one_hot(
                tlabel[gts], num_classes=num_classes
            ).to(cls_target.dtype)
            # sigmoid(t) - 0.5 can only express offsets in (-0.5, 0.5) cells.
            offsets = ((centers[gts] - anchor_points[cells]) / stride).clamp(-0.5, 0.5)
            reg_target[b, cells] = torch.cat(
                [offsets, torch.log(sizes[gts] / stride)], dim=1
            )

        # Objectness is supervised on every cell, including scales and
        # batches without any positive cell.
        obj_sum = obj_sum + bce(obj_pred, obj_target).mean()

        # Class and box losses use positive cells only. With no positives the
        # selections are empty and the sums are zero (still attached to the graph).
        cls_sum = cls_sum + bce(cls_pred[fg_mask], cls_target[fg_mask]).sum()
        pos_reg = reg_pred[fg_mask]
        pos_target = reg_target[fg_mask]
        xy_error = (pos_reg[:, :2].sigmoid() - 0.5 - pos_target[:, :2]).abs()
        wh_error = (pos_reg[:, 2:] - pos_target[:, 2:]).abs()
        reg_sum = reg_sum + xy_error.sum() + wh_error.sum()
        num_positive += int(fg_mask.sum())

    # Normalise exactly once: per-element mean over all positives.
    denom = max(num_positive, 1)
    total_loss_obj = obj_sum / len(outputs)
    total_loss_cls = cls_sum / (denom * num_classes)
    total_loss_reg = reg_sum / (denom * 4)
    total_loss = total_loss_obj + total_loss_cls + total_loss_reg * 5.0  # weight regression more
    return total_loss, total_loss_obj.detach(), total_loss_cls.detach(), total_loss_reg.detach()
# Model training (real loss)
# Guarded so that DataLoader worker processes started with the "spawn" method,
# which re-import this module, do not build the model or start training again.
if __name__ == "__main__":
    # Use the GPU when one is available instead of failing on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = YOLOv8(num_classes=20).to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(10):
        print(f"\nEpoch {epoch + 1}/10")
        for i, (images, targets) in enumerate(train_loader):
            images = images.to(device, non_blocking=True)
            if isinstance(targets, (list, tuple)):
                # Per-image target dicts: keep their tensors on the model's device.
                targets = [
                    {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in t.items()}
                    for t in targets
                ]

            # Forward pass: list of per-scale prediction tensors
            outputs = model(images)

            # Loss on the real targets
            loss, loss_obj, loss_cls, loss_reg = compute_loss(outputs, targets, num_classes=20)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                print(f"Iter {i}, Loss: {loss.item():.4f} "
                      f"(obj={loss_obj.item():.4f}, cls={loss_cls.item():.4f}, reg={loss_reg.item():.4f})")
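# Visualization ("可视化") - section heading only; no visualization code follows in this listing.
# "最新发布" ("Latest posts") appears to be leftover web-page text carried over with the listing, not code.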