import albumentations as A
import numpy as np
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import VOCDetection
#基础组件
class Conv(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> SiLU building block.

    When ``padding`` is None it defaults to ``(kernel_size - 1) // 2``, which
    keeps the spatial size unchanged for odd kernels at stride 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, groups=1):
        super().__init__()
        pad = (kernel_size - 1) // 2 if padding is None else padding
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad,
            groups=groups,
            bias=False,  # the following BatchNorm provides the shift term
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU()  # YOLOv8 uses the SiLU activation

    def forward(self, x):
        """Apply convolution, batch normalization and SiLU, in that order."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
#基础组件
class C2f(nn.Module):
    """CSP-style block ("C2f" in YOLOv8 naming).

    ``conv1`` (1x1) projects the input to ``2 * hidden`` channels, which are
    split into two halves. Each of the ``num_blocks`` 3x3 Convs is applied to
    the most recent chunk and its output is appended. All chunks are then
    concatenated and fused by ``conv2`` (1x1).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (hidden = out_channels // 2).
        num_blocks: number of chained 3x3 Conv blocks.
        shortcut: if True, each block's input is added to its output
            (residual connection). Previously this flag was stored but ignored.
    """

    def __init__(self, in_channels, out_channels, num_blocks=2, shortcut=False):
        super().__init__()
        self.out_channels = out_channels
        hidden_channels = int(out_channels * 0.5)
        self.conv1 = Conv(in_channels, hidden_channels * 2, 1)
        # conv2 fuses the 2 initial halves plus one output per block.
        self.conv2 = Conv(hidden_channels * (2 + num_blocks), out_channels, 1)
        self.blocks = nn.ModuleList(
            Conv(hidden_channels, hidden_channels, 3) for _ in range(num_blocks)
        )
        self.shortcut = shortcut

    def forward(self, x):
        y = list(self.conv1(x).chunk(2, 1))  # split into two halves along channels
        for block in self.blocks:
            out = block(y[-1])
            if self.shortcut:
                # Same channel count and spatial size, so a residual add is valid.
                out = out + y[-1]
            y.append(out)
        return self.conv2(torch.cat(y, dim=1))
#基础组件
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast.

    A 1x1 Conv halves the channel count. The same stride-1 max pool
    (kernel ``k``, padding ``k // 2``, size-preserving for odd ``k``) is then
    applied three times in sequence. The pre-pool map and the three pooled
    maps are concatenated and fused by a 1x1 Conv.
    """

    def __init__(self, in_channels, out_channels, k=5):
        super().__init__()
        mid = in_channels // 2
        self.conv1 = Conv(in_channels, mid, 1)
        self.pool = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
        self.conv2 = Conv(mid * 4, out_channels, 1)

    def forward(self, x):
        pyramid = [self.conv1(x)]
        for _ in range(3):
            pyramid.append(self.pool(pyramid[-1]))
        return self.conv2(torch.cat(pyramid, dim=1))
#主干网络
class Backbone(nn.Module):
    """Feature extractor producing two feature maps for the Neck.

    Each stage opens with a stride-2 Conv (stage1 has two). Relative to the
    input, the returned maps are therefore:
        c3: 512 channels at stride 16 (output of stage3)
        c4: 1024 channels at stride 32 (output of stage4, ending in SPPF)
    """

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(Conv(3, 64, 3, 2), Conv(64, 128, 3, 2), C2f(128, 128, 3))
        self.stage2 = nn.Sequential(Conv(128, 256, 3, 2), C2f(256, 256, 6))
        self.stage3 = nn.Sequential(Conv(256, 512, 3, 2), C2f(512, 512, 6))
        self.stage4 = nn.Sequential(
            Conv(512, 1024, 3, 2),
            C2f(1024, 1024, 3),
            SPPF(1024, 1024),
        )

    def forward(self, x):
        x = self.stage2(self.stage1(x))
        c3 = self.stage3(x)
        c4 = self.stage4(c3)
        return c3, c4  # two scales consumed by the Neck
#特征融合
class Neck(nn.Module):
    """PAN-style fusion of the two Backbone feature maps.

    Inputs (as produced by ``Backbone`` in this file):
        c3: (B, 512, H/16, W/16)
        c4: (B, 1024, H/32, W/32)

    Returns ``(p3_out, p4_out, p5_out)`` with 256, 512 and 1024 channels at
    strides 16, 32 and 64. These are the channel counts the detection heads
    expect. H and W should be multiples of 32 so the 2x upsample of the
    c4-level map matches c3's spatial size.

    NOTE(review): the Backbone exposes no stride-8 map, so these levels are one
    octave coarser than the standard YOLOv8 P3/P4/P5 (8/16/32) layout.
    """

    def __init__(self):
        super().__init__()
        # Top-down path
        self.conv1 = Conv(1024, 512, 1)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.c2f1 = C2f(512 + 512, 512, 3)
        self.conv2 = Conv(512, 256, 1)
        # Bottom-up path. conv3 consumes the 256-channel p3_out (it was
        # previously fed the 512-channel p3). The unused c2f2 was removed.
        self.conv3 = Conv(256, 256, 3, 2)
        self.c2f3 = C2f(256 + 512, 512, 3)
        self.conv4 = Conv(512, 512, 3, 2)
        # c4 is brought down to stride 64 so the deepest level can still fuse
        # with the backbone's deepest feature (c4 itself is at stride 32).
        self.conv5 = Conv(1024, 512, 3, 2)
        self.c2f4 = C2f(512 + 512, 1024, 3)

    def forward(self, c3, c4):
        # Top-down: reduce c4, upsample to c3's resolution and fuse.
        p4 = self.conv1(c4)                                        # (B, 512, H/32, W/32)
        p3 = self.c2f1(torch.cat([self.upsample(p4), c3], dim=1))  # (B, 512, H/16, W/16)
        # Higher-resolution output
        p3_out = self.conv2(p3)                                    # (B, 256, H/16, W/16)

        # Bottom-up: downsample the fine output and fuse with the top-down p4.
        p3_down = self.conv3(p3_out)                               # (B, 256, H/32, W/32)
        p4_out = self.c2f3(torch.cat([p3_down, p4], dim=1))        # (B, 512, H/32, W/32)

        # Deepest output: both inputs are brought to stride 64 before fusing.
        p4_down = self.conv4(p4_out)                               # (B, 512, H/64, W/64)
        c4_down = self.conv5(c4)                                   # (B, 512, H/64, W/64)
        p5_out = self.c2f4(torch.cat([p4_down, c4_down], dim=1))   # (B, 1024, H/64, W/64)
        return p3_out, p4_out, p5_out
#解耦检测头
class DecoupledHead(nn.Module):
    """Decoupled prediction head for a single feature level.

    Classification and regression run through separate 3x3 conv + SiLU
    branches. Objectness is predicted from the regression branch.

    Output: (B, 4 + 1 + num_classes, H, W), channels ordered as
    [box (4), objectness (1), class logits (num_classes)].
    """

    def __init__(self, in_channels, num_classes=80):
        super().__init__()
        c = in_channels
        # Separate 3x3 convolution branches
        self.cls_conv = nn.Conv2d(c, c, 3, padding=1)
        self.reg_conv = nn.Conv2d(c, c, 3, padding=1)
        # 1x1 prediction layers
        self.cls_pred = nn.Conv2d(c, num_classes, 1)
        self.reg_pred = nn.Conv2d(c, 4, 1)  # tx, ty, tw, th
        self.obj_pred = nn.Conv2d(c, 1, 1)  # objectness
        self.act = nn.SiLU()

    def forward(self, x):
        cls_feat = self.act(self.cls_conv(x))
        reg_feat = self.act(self.reg_conv(x))
        parts = [
            self.reg_pred(reg_feat),
            self.obj_pred(reg_feat),
            self.cls_pred(cls_feat),
        ]
        return torch.cat(parts, dim=1)
#多尺度检测头
class Detect(nn.Module):
    """Multi-scale detection: one DecoupledHead per feature level.

    ``forward`` expects a 3-sequence ``(p3, p4, p5)`` with 256, 512 and 1024
    channels. It returns a list of three tensors, each shaped
    (B, 4 + 1 + num_classes, H_i, W_i).

    NOTE(review): ``strides`` is not read by any code visible here. With the
    Backbone defined in this file the finest feature map is at stride 16, not
    8. Confirm these values before using them for box decoding.
    """

    def __init__(self, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.strides = [8, 16, 32]  # intended output strides (see NOTE above)
        # One decoupled head per scale
        self.head_small = DecoupledHead(256, num_classes)
        self.head_medium = DecoupledHead(512, num_classes)
        self.head_large = DecoupledHead(1024, num_classes)

    def forward(self, x):
        p3, p4, p5 = x  # exactly three levels, finest first
        heads = (self.head_small, self.head_medium, self.head_large)
        return [head(feat) for head, feat in zip(heads, (p3, p4, p5))]
#YOLO-V8整体模型
class YOLOv8(nn.Module):
    """Complete detector: Backbone -> Neck -> Detect.

    ``forward`` returns the list of per-level prediction tensors produced by
    the Detect module.
    """

    def __init__(self, num_classes=80):
        super().__init__()
        self.backbone = Backbone()
        self.neck = Neck()
        self.detect = Detect(num_classes)

    def forward(self, x):
        c3, c4 = self.backbone(x)
        return self.detect(self.neck(c3, c4))
#----------------------------------------------------------------------------------------------------------------------#
# Download and extract PASCAL VOC 2012 (download=True fetches it if missing).
dataset_dir = "./VOCdevkit"
voc_train, voc_val = (
    VOCDetection(root=dataset_dir, year='2012', image_set=split, download=True)
    for split in ('train', 'val')
)
# Image transform: resize to the network input size, convert to a [0, 1]
# tensor, then apply ImageNet mean/std normalization.
transform = transforms.Compose(
    [
        transforms.Resize((640, 640)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
#定义Dataset类
class VOCDataset(Dataset):
    """Wrap (image, annotation) pairs from ``VOCDetection`` as detection samples.

    Each element of ``data_list`` is expected to be ``(PIL image, annotation)``
    where objects live under ``ann['annotation']['object']``.

    Args:
        data_list: indexable sequence of (image, annotation) pairs.
        transform: when not None, image and boxes go through a box-aware
            albumentations pipeline: resize to 640x640, random horizontal flip,
            colour jitter, ImageNet normalization, and conversion to a CHW
            tensor. The object passed here is NOT applied itself, because a
            torchvision transform would resize the image without updating
            the box coordinates.
        aug_transform: optional albumentations ``Compose`` to use instead of
            the default pipeline. It must take ``image``, ``bboxes``
            (pascal_voc) and ``class_labels``.

    Each item is ``(img, target)``. ``target["boxes"]`` is (N, 4) float32
    xmin/ymin/xmax/ymax, ``target["labels"]`` is (N,) int64, and
    ``target["image_id"]`` is (1,). Without any transform, ``img`` is the RGB
    PIL image and boxes are in its original pixel coordinates.
    """

    def __init__(self, data_list, transform=None, aug_transform=None):
        self.data_list = data_list
        self.transform = transform
        self.class_to_idx = {
            'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3,
            'bottle': 4, 'bus': 5, 'car': 6, 'cat': 7,
            'chair': 8, 'cow': 9, 'diningtable': 10, 'dog': 11,
            'horse': 12, 'motorbike': 13, 'person': 14, 'pottedplant': 15,
            'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19
        }
        # Build the augmentation pipeline once, not on every __getitem__ call.
        self.aug_transform = aug_transform
        if self.aug_transform is None and transform is not None:
            self.aug_transform = self._default_aug_transform()

    @staticmethod
    def _default_aug_transform():
        """Return the default box-aware resize/flip/jitter/normalize pipeline."""
        return A.Compose([
            A.Resize(640, 640),
            A.HorizontalFlip(p=0.5),
            A.ColorJitter(p=0.3),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

    def __len__(self):
        return len(self.data_list)

    def _parse_objects(self, ann):
        """Return ``(boxes, labels)`` lists for the known classes in ``ann``."""
        objects = ann['annotation'].get('object', [])
        # Guard against a single object parsed as a bare dict instead of a
        # one-element list (possible with some torchvision versions).
        if isinstance(objects, dict):
            objects = [objects]
        boxes, labels = [], []
        for obj in objects:
            label = self.class_to_idx.get(obj['name'])
            if label is None:  # skip classes outside the 20 VOC classes
                continue
            bbox = obj['bndbox']
            boxes.append([float(bbox[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')])
            labels.append(label)
        return boxes, labels

    def __getitem__(self, idx):
        image, ann = self.data_list[idx]
        img = image.convert("RGB")
        boxes, labels = self._parse_objects(ann)

        # Data augmentation; the pipeline also rescales/flips the box coordinates.
        if self.aug_transform is not None:
            width, height = img.size
            # Clip to the image so albumentations' bbox validation does not
            # reject annotations that touch or slightly exceed the border.
            boxes = [
                [min(max(x0, 0.0), width), min(max(y0, 0.0), height),
                 min(max(x1, 0.0), width), min(max(y1, 0.0), height)]
                for x0, y0, x1, y1 in boxes
            ]
            out = self.aug_transform(image=np.asarray(img), bboxes=boxes, class_labels=labels)
            img = out['image']
            boxes = out['bboxes']
            labels = out['class_labels']

        target = {
            # reshape keeps (0, 4) / (0,) shapes for images without objects
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.long).reshape(-1),
            "image_id": torch.tensor([idx])
        }
        return img, target
# Build the dataset instances
train_dataset = VOCDataset(voc_train, transform=transform)
val_dataset = VOCDataset(voc_val, transform=transform)


def _detection_collate(batch):
    """Regroup a list of (img, target) pairs into (imgs, targets) tuples.

    Images can carry different numbers of objects, so targets cannot be
    stacked into a single tensor.
    """
    return tuple(zip(*batch))


# Batch loaders: shuffled for training, in order for validation
train_loader, val_loader = (
    torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=shuffle,
        collate_fn=_detection_collate,
    )
    for dataset, shuffle in ((train_dataset, True), (val_dataset, False))
)
#损失函数
def compute_loss(predictions, targets, num_classes=20):
    """Placeholder detection loss for the anchor-free heads of ``YOLOv8``.

    Each prediction tensor is (B, 4 + 1 + num_classes, H, W). The layout is
    one prediction per grid cell: there are no anchors, unlike the
    3-anchor reshape this used before, which failed for 25 channels.

    Args:
        predictions: list of per-level tensors, channels ordered
            [box (4), objectness (1), class logits (num_classes)].
        targets: per-image target dicts. Unused until a positive-sample
            assigner is implemented.
        num_classes: number of object classes.

    Returns:
        A scalar zero tensor attached to the autograd graph, so
        ``loss.backward()`` runs (every gradient is zero).

    Raises:
        ValueError: if a prediction does not have ``5 + num_classes`` channels.
    """
    if not predictions:
        return torch.zeros(())
    expected_channels = 4 + 1 + num_classes
    total_loss = predictions[0].new_zeros(())
    for pred in predictions:  # one tensor per output scale
        bs, channels, ny, nx = pred.shape
        if channels != expected_channels:
            raise ValueError(
                f"expected {expected_channels} channels (4 box + 1 obj + "
                f"{num_classes} cls), got {channels}"
            )
        # (B, C, H, W) -> (B, H*W, C). Along the last dim, [:4] is the box,
        # [4:5] is objectness and [5:] are the class logits.
        pred = pred.permute(0, 2, 3, 1).reshape(bs, ny * nx, channels)
        # TODO: positive-sample assignment (Task-Aligned Assigner or SimOTA),
        # then BCE for objectness/classes and a box loss (L1 or CIoU) on
        # positives. Until then contribute zero while keeping the graph connected.
        total_loss = total_loss + pred.sum() * 0
    return total_loss
#模型训练
# Use the GPU when available instead of hard-coding .cuda(), which crashes on
# CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
model = YOLOv8(num_classes=20).train().to(device)  # 20 VOC classes
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    for i, (images, targets) in enumerate(train_loader):
        images = torch.stack(images).to(device)
        # Forward pass: one (B, 4 + 1 + 20, H_i, W_i) tensor per output level
        outputs = model(images)
        # Placeholder loss (always zero): only checks that gradients flow
        loss = sum(out.sum() * 0 for out in outputs)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            print(f"Iter {i}, Loss: {loss.item():.4f}")
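# "最新发布" ("latest release"): stray text from the web page this code was copied from; commented out so the file parses.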