Fixing the "from albumentations import torch as AT" import error

This post resolves a common import error encountered when using the Albumentations library for data augmentation: the statement from albumentations import torch as AT should be corrected to from albumentations import pytorch as AT. This change lets the code run correctly and avoids the error caused by the submodule's rename.

Problem

I've been reading code from top Kaggle competitors lately, and everyone seems to like albumentations for data augmentation. But running

from albumentations import torch as AT

throws an error:

ImportError: cannot import name 'torch' from 'albumentations'
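One quick way to confirm what happened (a diagnostic sketch of my own, not from the original post) is to list the subpackages your installed albumentations actually ships; on recent releases you should see pytorch, and the old torch submodule is gone:

import pkgutil
import albumentations

# Show which albumentations release is installed and which
# subpackages it bundles (recent versions list 'pytorch', not 'torch').
print(albumentations.__version__)
print([m.name for m in pkgutil.iter_modules(albumentations.__path__)])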

Solution

I took a look at the albumentations source code on GitHub, and the statement should now be written as

from albumentations import pytorch as AT

and everything works.
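For completeness, here is a minimal sketch of how the corrected import fits into a typical pipeline. The specific transforms and values are illustrative, and ToTensorV2 assumes a reasonably recent albumentations release (older versions exposed ToTensor instead):

import numpy as np
import albumentations as A
from albumentations import pytorch as AT  # note: 'pytorch', not 'torch'

# Illustrative pipeline: flip + ImageNet normalization, then numpy -> torch tensor
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    AT.ToTensorV2(),  # HWC numpy array -> CHW torch.Tensor
])

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # dummy image
tensor = transform(image=image)["image"]  # torch.Tensor, shape (3, 224, 224)

In most recent code you will also see the equivalent direct form from albumentations.pytorch import ToTensorV2; both rely on the renamed pytorch submodule.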
