第64题 3Sum Closest

本文介绍了一个算法问题:给定一个整数数组和一个目标值,找到数组中三个整数的和最接近目标值的方法。提供了一种解决方案,通过先排序数组,然后使用双指针技巧遍历数组并调整指针位置来逐步逼近目标值。

Given an array S of n integers, find three integers in S such that the sum is closest to a given number, target. Return the sum of the three integers. You may assume that each input would have exactly one solution.

    For example, given array S = {-1, 2, 1, -4}, and target = 1.

    The sum that is closest to the target is 2. (-1 + 2 + 1 = 2).

Hide Tags
  Array Two Pointers










Solution in C++:
class Solution {
public:
    /**
     * Return the sum of three elements of `nums` that is closest to `target`.
     *
     * The vector is sorted in place. For every choice of the first element a
     * two-pointer sweep over the remaining suffix moves the sum toward
     * `target`, giving O(n^2) time after the sort and O(1) extra space.
     *
     * @param nums   Input values; reordered (sorted ascending) by this call.
     * @param target Value the triple sum should approach.
     * @return The closest triple sum. When two sums are equally close, the
     *         one found first is kept. If `nums` has three or fewer elements,
     *         the sum of all of them is returned (0 for an empty vector).
     */
    int threeSumClosest(std::vector<int>& nums, int target) {
        const std::size_t n = nums.size();

        // Too few elements to choose from: the only candidate is "everything".
        if (n <= 3) {
            int total = 0;
            for (int value : nums) total += value;
            return total;
        }

        std::sort(nums.begin(), nums.end());

        // Distances are taken in 64-bit so that (sum - target) cannot
        // overflow for extreme int inputs.
        const long long goal = target;
        auto distance = [goal](long long s) {
            return s > goal ? s - goal : goal - s;
        };

        long long best = static_cast<long long>(nums[0]) + nums[1] + nums[2];

        for (std::size_t i = 0; i + 2 < n; ++i) {
            std::size_t lo = i + 1;
            std::size_t hi = n - 1;

            while (lo < hi) {
                const long long sum =
                    static_cast<long long>(nums[i]) + nums[lo] + nums[hi];

                if (sum == goal) return target;  // cannot get any closer

                // Strict comparison keeps the earlier candidate on ties.
                if (distance(sum) < distance(best)) best = sum;

                if (sum < goal) {
                    ++lo;   // need a larger sum
                } else {
                    --hi;   // need a smaller sum
                }
            }

            // Skip repeated first elements. The bounds check must come
            // before the element access, otherwise a trailing run of equal
            // values reads one past the end of the vector.
            while (i + 1 < n && nums[i + 1] == nums[i]) ++i;
        }

        return static_cast<int>(best);
    }
};


import torch import torch.nn as nn from torchvision.datasets import VOCDetection from torchvision import transforms from torch.utils.data import Dataset import albumentations as A from albumentations.pytorch import ToTensorV2 import numpy as np #基础组件 class Conv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, groups=1): super().__init__() if padding is None: padding = (kernel_size - 1) // 2 self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, groups=groups, bias=False) self.bn = nn.BatchNorm2d(out_channels) self.act = nn.SiLU() # YOLOv8 使用 SiLU 激活函数 def forward(self, x): return self.act(self.bn(self.conv(x))) #基础组件 class C2f(nn.Module): def __init__(self, in_channels, out_channels, num_blocks=2, shortcut=False): super().__init__() self.out_channels = out_channels hidden_channels = int(out_channels * 0.5) self.conv1 = Conv(in_channels, hidden_channels * 2, 1) self.conv2 = Conv((hidden_channels * 2 + hidden_channels * num_blocks), out_channels, 1) self.blocks = nn.ModuleList() for _ in range(num_blocks): self.blocks.append(Conv(hidden_channels, hidden_channels, 3)) self.shortcut = shortcut def forward(self, x): y = list(self.conv1(x).chunk(2, 1)) # Split into two halves for block in self.blocks: y.append(block(y[-1])) return self.conv2(torch.cat(y, dim=1)) #基础组件 class SPPF(nn.Module): def __init__(self, in_channels, out_channels, k=5): super().__init__() hidden_channels = in_channels // 2 self.conv1 = Conv(in_channels, hidden_channels, 1) self.pool = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) self.conv2 = Conv(hidden_channels * 4, out_channels, 1) def forward(self, x): x = self.conv1(x) pool1 = self.pool(x) pool2 = self.pool(pool1) pool3 = self.pool(pool2) return self.conv2(torch.cat([x, pool1, pool2, pool3], dim=1)) #主干网络 class Backbone(nn.Module): def __init__(self): super().__init__() self.stage1 = nn.Sequential( Conv(3, 64, 3, 2), Conv(64, 128, 3, 2), C2f(128, 128, 3) ) 
self.stage2 = nn.Sequential( Conv(128, 256, 3, 2), C2f(256, 256, 6) ) self.stage3 = nn.Sequential( Conv(256, 512, 3, 2), C2f(512, 512, 6) ) self.stage4 = nn.Sequential( Conv(512, 1024, 3, 2), C2f(1024, 1024, 3), SPPF(1024, 1024) ) def forward(self, x): x = self.stage1(x) x = self.stage2(x) c3 = self.stage3(x) c4 = self.stage4(c3) return c3, c4 # 输出两个尺度用于 Neck #特征融合 class Neck(nn.Module): def __init__(self): super().__init__() self.conv1 = Conv(1024, 512, 1) self.upsample = nn.Upsample(scale_factor=2, mode=&#39;nearest&#39;) self.c2f1 = C2f(512 + 512, 512, 3) self.conv2 = Conv(512, 256, 1) self.c2f2 = C2f(256 + 256, 256, 3) self.conv3 = Conv(256, 256, 3, 2) self.c2f3 = C2f(256 + 256, 512, 3) self.conv4 = Conv(512, 512, 3, 2) self.c2f4 = C2f(512 + 512, 1024, 3) def forward(self, c3, c4): # 自顶向下:上采样 + 融合 p4 = self.conv1(c4) p4_up = self.upsample(p4) p4_cat = torch.cat([p4_up, c3], dim=1) p3 = self.c2f1(p4_cat) # 更高分辨率输出 p3_out = self.conv2(p3) p3_down = self.conv3(p3) p3_down_cat = torch.cat([p3_down, p4], dim=1) p4_out = self.c2f3(p3_down_cat) # 最深层输出 p4_down = self.conv4(p4_out) p4_down_cat = torch.cat([p4_down, c4], dim=1) p5_out = self.c2f4(p4_down_cat) return p3_out, p4_out, p5_out #解耦检测头 class DecoupledHead(nn.Module): def __init__(self, in_channels, num_classes=80): super().__init__() # 分离的 3×3 卷积分支 self.cls_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1) self.reg_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1) # 预测层 self.cls_pred = nn.Conv2d(in_channels, num_classes, 1) self.reg_pred = nn.Conv2d(in_channels, 4, 1) # tx, ty, tw, th self.obj_pred = nn.Conv2d(in_channels, 1, 1) # objectness self.act = nn.SiLU() def forward(self, x): c = self.act(self.cls_conv(x)) r = self.act(self.reg_conv(x)) cls = self.cls_pred(c) reg = self.reg_pred(r) obj = self.obj_pred(r) return torch.cat([reg, obj, cls], dim=1) #多尺度检测头 class Detect(nn.Module): def __init__(self, num_classes=80): super().__init__() self.num_classes = num_classes self.strides = [8, 16, 32] 
# 对应输出尺度 # 为每个尺度创建一个解耦头 self.head_small = DecoupledHead(256, num_classes) self.head_medium = DecoupledHead(512, num_classes) self.head_large = DecoupledHead(1024, num_classes) def forward(self, x): p3, p4, p5 = x pred_small = self.head_small(p3) # shape: (B, 4 + 1 + num_classes, H/8, W/8) pred_medium = self.head_medium(p4) # (B, ..., H/16, W/16) pred_large = self.head_large(p5) # (B, ..., H/32, W/32) return [pred_small, pred_medium, pred_large] #YOLO-V8整体模型 class YOLOv8(nn.Module): def __init__(self, num_classes=80): super().__init__() self.backbone = Backbone() self.neck = Neck() self.detect = Detect(num_classes) def forward(self, x): c3, c4 = self.backbone(x) features = self.neck(c3, c4) predictions = self.detect(features) return predictions #----------------------------------------------------------------------------------------------------------------------# # 下载并解压PASCAL VOC 2012 dataset_dir = "./VOCdevkit" voc_train = VOCDetection( root=dataset_dir, year=&#39;2012&#39;, image_set=&#39;train&#39;, download=True # 如果没有会尝试下载 ) voc_val = VOCDetection( root=dataset_dir, year=&#39;2012&#39;, image_set=&#39;val&#39;, download=True ) # 图像变换 transform = transforms.Compose([ transforms.Resize((640, 640)), # 调整为网络输入大小 transforms.ToTensor(), # 转为 [0,1] 归一化张量 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 标准化 ]) #定义Dataset类 class VOCDataset(Dataset): def __init__(self, data_list, img_size=640): self.data_list = data_list self.img_size = img_size self.class_to_idx = { &#39;aeroplane&#39;: 0, &#39;bicycle&#39;: 1, &#39;bird&#39;: 2, &#39;boat&#39;: 3, &#39;bottle&#39;: 4, &#39;bus&#39;: 5, &#39;car&#39;: 6, &#39;cat&#39;: 7, &#39;chair&#39;: 8, &#39;cow&#39;: 9, &#39;diningtable&#39;: 10, &#39;dog&#39;: 11, &#39;horse&#39;: 12, &#39;motorbike&#39;: 13, &#39;person&#39;: 14, &#39;pottedplant&#39;: 15, &#39;sheep&#39;: 16, &#39;sofa&#39;: 17, &#39;train&#39;: 18, &#39;tvmonitor&#39;: 19 } # 定义增强 pipeline(包含 resize, flip, color jitter, 
normalize, to_tensor) self.transform = A.Compose([ A.Resize(height=img_size, width=img_size), A.HorizontalFlip(p=0.5), A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.5), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2() # 转为 [C,H,W] tensor ], bbox_params=A.BboxParams( format=&#39;pascal_voc&#39;, label_fields=[&#39;class_labels&#39;], min_visibility=0.1 )) def __len__(self): return len(self.data_list) def __getitem__(self, idx): image, ann = self.data_list[idx] img = np.array(image.convert("RGB")) # 转为 numpy array boxes = [] labels = [] for obj in ann[&#39;annotation&#39;][&#39;object&#39;]: cls_name = obj[&#39;name&#39;] if cls_name not in self.class_to_idx: continue label = self.class_to_idx[cls_name] bbox = obj[&#39;bndbox&#39;] xmin = float(bbox[&#39;xmin&#39;]) ymin = float(bbox[&#39;ymin&#39;]) xmax = float(bbox[&#39;xmax&#39;]) ymax = float(bbox[&#39;ymax&#39;]) # 确保坐标合法 if xmax > xmin and ymax > ymin: boxes.append([xmin, ymin, xmax, ymax]) labels.append(label) # 若无有效标注,添加虚拟目标避免崩溃 if len(boxes) == 0: boxes = [[0, 0, 10, 10]] labels = [0] # 应用增强(自动完成 resize + flip + normalize + to_tensor) try: transformed = self.transform(image=img, bboxes=boxes, class_labels=labels) except Exception as e: print(f"Augmentation error at index {idx}: {e}") # 回退到最小变换 transformed = { "image": ToTensorV2()(image=transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] )(transforms.ToTensor()(img)))["image"], "bboxes": torch.tensor(boxes, dtype=torch.float32), "class_labels": torch.tensor(labels, dtype=torch.long) } img_tensor = transformed["image"] # 已经是 [3,640,640] 归一化张量 boxes_tensor = torch.tensor(transformed["bboxes"], dtype=torch.float32) labels_tensor = torch.tensor(transformed["class_labels"], dtype=torch.long) target = { "boxes": boxes_tensor, "labels": labels_tensor, "image_id": torch.tensor([idx]), } return img_tensor, target # 创建数据集实例 train_dataset = VOCDataset(voc_train, img_size=640) 
val_dataset = VOCDataset(voc_val, img_size=640) # 使用 DataLoader 加载批次 train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True ) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True ) #损失函数 def compute_loss(outputs, targets, strides=[8, 16, 32], num_classes=20, alpha=0.5, gamma=0.5): """ Compute YOLOv8-style loss with simple positive sample assignment. outputs: List[Tensor] -> [P3, P4, P5], each shape (B, C, H, W) targets: List[dict] -> [{&#39;boxes&#39;: (N,4), &#39;labels&#39;: (N,)}] """ device = outputs[0].device criterion_cls = nn.BCEWithLogitsLoss(reduction=&#39;none&#39;) criterion_obj = nn.BCEWithLogitsLoss(reduction=&#39;none&#39;) total_loss_cls = torch.zeros(1, device=device) total_loss_obj = torch.zeros(1, device=device) total_loss_reg = torch.zeros(1, device=device) num_positive = 0 # 特征图尺寸对应关系 img_size = 640 feature_sizes = [img_size // s for s in strides] # [80, 40, 20] for i, pred in enumerate(outputs): H, W = feature_sizes[i], feature_sizes[i] stride = strides[i] bs = pred.shape[0] # Reshape & split predictions: [B, C, H, W] -> [B, H*W, *] pred = pred.permute(0, 2, 3, 1).reshape(bs, -1, 4 + 1 + num_classes) reg_pred = pred[..., :4] # tx, ty, tw, th obj_pred = pred[..., 4] # objectness (flattened) cls_pred = pred[..., 5:] # class logits # Generate grid centers (center of each grid cell in original image space) yv, xv = torch.meshgrid([torch.arange(H), torch.arange(W)]) grid_xy = torch.stack((xv, yv), dim=2).float().to(device) # (H, W, 2) grid_xy = grid_xy.reshape(-1, 2) # (H*W, 2) grid_xy = grid_xy.unsqueeze(0).expand(bs, -1, -1) # (B, H*W, 2) anchor_points = (grid_xy + 0.5) * stride # (cx, cy) on original image # Decode predicted boxes (in xywh format) pred_xy = anchor_points + reg_pred[..., :2].sigmoid() * stride - 0.5 * stride pred_wh = torch.exp(reg_pred[..., 2:]) * stride pred_boxes = torch.cat([pred_xy, pred_wh], dim=-1) # 
(B, H*W, 4) # Prepare targets obj_target = torch.zeros_like(obj_pred) cls_target = torch.zeros_like(cls_pred) reg_target = torch.zeros_like(reg_pred) fg_mask = torch.zeros_like(obj_pred, dtype=torch.bool) for b in range(bs): tbox = targets[b][&#39;boxes&#39;] # (N, 4), xyxy format tlabel = targets[b][&#39;labels&#39;] # (N,) if len(tbox) == 0: continue # Convert tbox from xyxy to xywh tbox_xyxy = tbox tbox_xywh = torch.cat([ (tbox_xyxy[:, :2] + tbox_xyxy[:, 2:]) / 2, tbox_xyxy[:, 2:] - tbox_xyxy[:, :2] ], dim=1) # (N, 4) # Match: find best overlap between gt centers and anchor points gt_centers = tbox_xywh[:, :2] # (N, 2) distances = (anchor_points[b].unsqueeze(1) - gt_centers.unsqueeze(0)).pow(2).sum(dim=-1) # (H*W, N) _, closest_grid_idx = distances.min(dim=0) # each gt → nearest grid _, closest_gt_idx = distances.min(dim=1) # each grid → nearest gt # Positive samples: grids whose closest gt is itself pos_mask = torch.zeros(H * W, dtype=torch.bool, device=device) for gt_i in range(len(tbox)): grid_i = closest_grid_idx[gt_i] if closest_gt_idx[grid_i] == gt_i: # mutual match pos_mask[grid_i] = True fg_mask[b][pos_mask] = True obj_target[b][pos_mask] = 1.0 cls_target[b][pos_mask] = nn.functional.one_hot(tlabel.long(), num_classes=num_classes).float() # Regression target for positive samples matched_gt = tbox_xywh[pos_mask.sum(dim=0).nonzero(as_tuple=True)[0]] # tricky fix; assume one-to-one if len(matched_gt) > 0: reg_target[b][pos_mask] = torch.cat([ (matched_gt[:, :2] - anchor_points[b][pos_mask]) / stride, torch.log(matched_gt[:, 2:] / stride + 1e-8) ], dim=1) # Only compute loss on positive samples if fg_mask.any(): loss_obj = criterion_obj(obj_pred, obj_target) total_loss_obj += loss_obj.mean() pos_obj = obj_target == 1 if pos_obj.any(): total_loss_cls += criterion_cls(cls_pred[pos_obj], cls_target[pos_obj]).mean() total_loss_reg += torch.abs(reg_pred[pos_obj] - reg_target[pos_obj]).mean() num_positive += pos_obj.sum().item() else: total_loss_obj += 
obj_pred.sum() * 0 total_loss_cls += cls_pred.sum() * 0 total_loss_reg += reg_pred.sum() * 0 # Normalize losses if num_positive > 0: total_loss_reg /= num_positive total_loss_cls /= num_positive total_loss_obj /= len(outputs) total_loss = total_loss_obj + total_loss_cls + total_loss_reg * 5.0 # weight reg more return total_loss, total_loss_obj.detach(), total_loss_cls.detach(), total_loss_reg.detach() # 模型训练(真实损失) model = YOLOv8(num_classes=20).train().cuda() optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) for epoch in range(10): print(f"\nEpoch {epoch + 1}/10") for i, (images, targets) in enumerate(train_loader): images = images.cuda(non_blocking=True) # 前向传播 outputs = model(images) # List[Tensor]: [B, 25, H/8, W/8], ... # 计算真实损失 loss, loss_obj, loss_cls, loss_reg = compute_loss(outputs, targets, num_classes=20) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() if i % 10 == 0: print(f"Iter {i}, Loss: {loss.item():.4f} " f"(obj={loss_obj.item():.4f}, cls={loss_cls.item():.4f}, reg={loss_reg.item():.4f})") 可视化
最新发布
11-27
你已经成功实现了 **YOLOv8 风格模型 + VOC 数据加载 + 损失计算与训练流程**,现在可以进入关键的下一步:**可视化训练过程中的预测结果和 GT 框对比图**。 这不仅能帮助你调试模型是否“学会检测”,还能直观判断数据增强、正样本匹配、解码逻辑是否正确。 --- ## ✅ 目标:实现一个 `visualize_predictions` 函数 功能: - 取一批验证图像(或训练图像) - 前向推理得到输出 - 解码网络输出为边界框(xyxy 格式) - 应用 NMS 抑制重复框 - 在原图上绘制 GT 和 Pred 框 - 使用 `matplotlib` 显示/保存图像 --- ### ✅ 1. 添加:将模型输出解码为真实坐标框 ```python import torch import torchvision import cv2 import numpy as np import matplotlib.pyplot as plt def decode_outputs(outputs, strides=[8, 16, 32], img_size=640): """ 将 YOLO 输出解码为 [cx, cy, w, h] 形式的归一化框(相对坐标) 返回: (pred_boxes, pred_scores, pred_labels), each as tensor """ device = outputs[0].device pred_boxes = [] pred_scores = [] pred_labels = [] for i, pred in enumerate(outputs): bs, _, ny, nx = pred.shape stride = strides[i] pred = pred.permute(0, 2, 3, 1).reshape(bs, -1, 4 + 1 + 20) # [B, H*W, ...] reg_pred = pred[..., :4] # tx, ty, tw, th obj_pred = pred[..., 4:5].sigmoid() # objectness cls_pred = pred[..., 5:].sigmoid() # class scores # 获取 grid 中心点 yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) grid_xy = torch.stack((xv, yv), dim=2).float().to(device) # (ny, nx, 2) grid_xy = grid_xy.reshape(-1, 2).unsqueeze(0).expand(bs, -1, -1) # (B, H*W, 2) # 解码中心点 pred_xy = (grid_xy + 0.5 + reg_pred[..., :2]) * stride # 解码宽高 pred_wh = torch.exp(reg_pred[..., 2:]) * stride # 拼接为 xywh 框 boxes = torch.cat([pred_xy, pred_wh], dim=-1) # (B, H*W, 4) # 计算置信度:obj * max(cls_prob) scores, labels = torch.max(cls_pred, dim=-1) scores = scores * obj_pred.squeeze(-1) # confidence = obj × cls_score # 转换为绝对坐标并归一化到 [0,1] img_w, img_h = img_size, img_size boxes_normalized = boxes / torch.tensor([img_w, img_h, img_w, img_h], device=device) pred_boxes.append(boxes_normalized) pred_scores.append(scores) pred_labels.append(labels) # 合并所有尺度 pred_boxes = torch.cat(pred_boxes, dim=1) # (B, L, 4) pred_scores = torch.cat(pred_scores, dim=1) # (B, L) pred_labels = torch.cat(pred_labels, dim=1) # (B, L) return pred_boxes, pred_scores, pred_labels ``` --- ### ✅ 2. 
实现 NMS 并过滤低分框 ```python def postprocess(pred_boxes, pred_scores, pred_labels, conf_thresh=0.25, nms_thresh=0.5): """ 对每张图片进行 NMS 后处理 """ final_boxes, final_scores, final_labels = [], [], [] for i in range(pred_boxes.shape[0]): # 遍历 batch 中每张图 boxes = pred_boxes[i] # (L, 4) scores = pred_scores[i] # (L,) labels = pred_labels[i] # (L,) # 过滤低置信度 mask = scores > conf_thresh boxes = boxes[mask] scores = scores[mask] labels = labels[mask] if len(boxes) == 0: final_boxes.append(torch.empty(0, 4)) final_scores.append(torch.empty(0)) final_labels.append(torch.empty(0)) continue # 转 xyxy 并应用 NMS boxes_xyxy = torchvision.ops.box_convert(boxes, in_fmt=&#39;cxcywh&#39;, out_fmt=&#39;xyxy&#39;) keep = torchvision.ops.nms(boxes_xyxy, scores, nms_thresh) final_boxes.append(boxes[keep]) final_scores.append(scores[keep]) final_labels.append(labels[keep]) return final_boxes, final_scores, final_labels ``` --- ### ✅ 3. 可视化函数:绘制 GT 和 Pred 框 ```python def visualize_predictions(model, dataloader, idx_to_class, device="cuda", num_images=4): model.eval() inv_normalize = transforms.Compose([ transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]), transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]), ]) figure, ax = plt.subplots(num_images, 2, figsize=(12, 6 * num_images)) if num_images == 1: ax = ax.unsqueeze(0) with torch.no_grad(): for i, (images, targets) in enumerate(dataloader): images = images.to(device) outputs = model(images) pred_boxes, pred_scores, pred_labels = decode_outputs(outputs) det_boxes, det_scores, det_labels = postprocess(pred_boxes, pred_scores, pred_labels) for j in range(min(num_images, len(images))): # 图像 j img = images[j].cpu() img = inv_normalize(img) img = torch.clamp(img, 0, 1) img = transforms.ToPILImage()(img) # 绘图 ax[j, 0].imshow(img); ax[j, 0].set_title("Ground Truth") ax[j, 1].imshow(img); ax[j, 1].set_title("Predictions") # 绘制 GT gt_boxes = targets[j][&#39;boxes&#39;].cpu() gt_labels = targets[j][&#39;labels&#39;].cpu() 
for k in range(len(gt_boxes)): box = gt_boxes[k].numpy() label = idx_to_class[gt_labels[k].item()] rect = plt.Rectangle( (box[0]*640, box[1]*640), (box[2]-box[0])*640, (box[3]-box[1])*640, fill=False, edgecolor=&#39;green&#39;, linewidth=2 ) ax[j, 0].add_patch(rect) ax[j, 0].text(box[0]*640, box[1]*640, label, color=&#39;white&#39;, fontsize=10, bbox=dict(facecolor=&#39;green&#39;, alpha=0.7)) # 绘制 Pred pred_box_img = det_boxes[j].cpu() * 640 # 转为像素坐标 pred_score_img = det_scores[j].cpu() pred_label_img = det_labels[j].cpu() for k in range(len(pred_box_img)): box = pred_box_img[k].numpy() score = pred_score_img[k].item() label = idx_to_class[pred_label_img[k].item()] x1, y1 = int(box[0]), int(box[1]) w, h = int(box[2] - box[0]), int(box[3] - box[1]) rect = plt.Rectangle( (x1, y1), w, h, fill=False, edgecolor=&#39;red&#39;, linewidth=2 ) ax[j, 1].add_patch(rect) ax[j, 1].text(x1, y1, f"{label}:{score:.2f}", color=&#39;white&#39;, fontsize=10, bbox=dict(facecolor=&#39;red&#39;, alpha=0.7)) ax[j, 0].axis(&#39;off&#39;); ax[j, 1].axis(&#39;off&#39;) break # 只看第一批次 plt.tight_layout() plt.show() model.train() # 回到训练模式 ``` --- ### ✅ 4. 调用可视化(在训练中或训练后) ```python # 类别索引反查表 idx_to_class = {v: k for k, v in train_dataset.class_to_idx.items()} # 在某个 epoch 后调用 visualize_predictions(model, val_loader, idx_to_class, device="cuda", num_images=4) ``` --- ### ✅ 效果示例(想象图) | Ground Truth | Predictions | |-------------|------------| | 绿框 + 标签 | 红框 + 标签 + 分数 | | 正确标注物体位置 | 显示模型预测的框和类别 | > 初期可能很多误检,随着训练会逐渐改善。 --- ###
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值