I0309 16:09:47.720973 7599 solver.cpp:258] Train net output #0: loss = 86.9954 (* 1 = 86.9954 loss)

这是一个来自新浪博客的示例链接,具体细节未给出,通常包含作者的观点或教程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import traceback
import cv2
import json
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from albumentations import (
    Compose, Resize, Normalize, HorizontalFlip, VerticalFlip, Rotate, OneOf,
    RandomBrightnessContrast, GaussNoise, ElasticTransform, RandomGamma,
    HueSaturationValue, CoarseDropout, Perspective, KeypointParams, CLAHE,
    MotionBlur, ISONoise, Lambda
)
from albumentations.pytorch import ToTensorV2
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, StepLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.metrics import precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from tqdm import tqdm  # progress-bar library


class EnhancedTrainingLogger:
    """Record several loss components per batch and plot them live.

    A 2x2 matplotlib figure is opened in interactive mode on construction:
    loss curves (top-left), learning rate (top-right), validation metrics
    (bottom-left) and a text panel with the latest numbers (bottom-right).
    """

    def __init__(self):
        self.total_losses = []
        self.bin_losses = []
        self.thresh_losses = []
        self.db_losses = []
        self.timestamps = []
        self.start_time = time.time()
        self.lr_history = []
        self.val_metrics = {'precision': [], 'recall': [], 'f1': []}

        # Live visualisation setup.
        plt.ion()  # interactive mode so the figure refreshes while training
        self.fig, self.axs = plt.subplots(2, 2, figsize=(15, 10))
        self.fig.suptitle('Training Progress', fontsize=16)

        # Loss curves.
        self.loss_line, = self.axs[0, 0].plot([], [], 'r-', label='Total Loss')
        self.bin_line, = self.axs[0, 0].plot([], [], 'g-', label='Binary Loss')
        self.thresh_line, = self.axs[0, 0].plot([], [], 'b-', label='Threshold Loss')
        self.db_line, = self.axs[0, 0].plot([], [], 'm-', label='DB Loss')
        self.axs[0, 0].set_title('Training Loss Components')
        self.axs[0, 0].set_xlabel('Batch')
        self.axs[0, 0].set_ylabel('Loss')
        self.axs[0, 0].legend()
        self.axs[0, 0].grid(True)

        # Learning-rate curve.
        self.lr_line, = self.axs[0, 1].plot([], [], 'c-')
        self.axs[0, 1].set_title('Learning Rate Schedule')
        self.axs[0, 1].set_xlabel('Batch')
        self.axs[0, 1].set_ylabel('Learning Rate')
        self.axs[0, 1].grid(True)

        # Validation-metric curves.
        self.precision_line, = self.axs[1, 0].plot([], [], 'r-', label='Precision')
        self.recall_line, = self.axs[1, 0].plot([], [], 'g-', label='Recall')
        self.f1_line, = self.axs[1, 0].plot([], [], 'b-', label='F1 Score')
        self.axs[1, 0].set_title('Validation Metrics')
        self.axs[1, 0].set_xlabel('Epoch')
        self.axs[1, 0].set_ylabel('Score')
        self.axs[1, 0].legend()
        self.axs[1, 0].grid(True)

        # Text panel showing the most recent numbers.
        self.metrics_text = self.axs[1, 1].text(
            0.5, 0.5, "",
            horizontalalignment='center',
            verticalalignment='center',
            transform=self.axs[1, 1].transAxes,
            fontsize=12)
        self.axs[1, 1].axis('off')  # no axes for the text panel

        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.draw()
        plt.pause(0.1)

    def on_batch_end(self, batch_idx, total_loss, bin_loss, thresh_loss, db_loss, lr=None):
        """Record one batch's losses and refresh the live figure.

        Args:
            batch_idx: Index of the batch; every 10th batch also refreshes
                the text panel.
            total_loss, bin_loss, thresh_loss, db_loss: Scalar loss values.
            lr: Current learning rate, or None when it is not known.
        """
        elapsed = time.time() - self.start_time
        self.total_losses.append(total_loss)
        self.bin_losses.append(bin_loss)
        self.thresh_losses.append(thresh_loss)
        self.db_losses.append(db_loss)
        self.timestamps.append(elapsed)
        if lr is not None:
            self.lr_history.append(lr)

        # Update the live curves.
        self.update_plots(batch_idx)

        # Detailed text every 10 batches.
        if batch_idx % 10 == 0:
            avg_total = np.mean(self.total_losses[-10:]) if len(self.total_losses) >= 10 else total_loss
            avg_bin = np.mean(self.bin_losses[-10:]) if len(self.bin_losses) >= 10 else bin_loss
            avg_thresh = np.mean(self.thresh_losses[-10:]) if len(self.thresh_losses) >= 10 else thresh_loss
            avg_db = np.mean(self.db_losses[-10:]) if len(self.db_losses) >= 10 else db_loss

            # `lr` is optional; formatting None with ':.2e' would raise
            # TypeError, so fall back to a placeholder.
            lr_text = f"{lr:.2e}" if lr is not None else "n/a"

            metrics_text = (
                f"Batch: {batch_idx}\n"
                f"Total Loss: {total_loss:.4f} (Avg10: {avg_total:.4f})\n"
                f"Binary Loss: {bin_loss:.4f} (Avg10: {avg_bin:.4f})\n"
                f"Threshold Loss: {thresh_loss:.4f} (Avg10: {avg_thresh:.4f})\n"
                f"DB Loss: {db_loss:.4f} (Avg10: {avg_db:.4f})\n"
                f"Learning Rate: {lr_text}\n"
                f"Time: {int(elapsed // 3600):02d}:{int((elapsed % 3600) // 60):02d}:{int(elapsed % 60):02d}"
            )
            self.metrics_text.set_text(metrics_text)

        # Redraw the figure.
        plt.draw()
        plt.pause(0.01)

    def update_plots(self, batch_idx):
        """Push the recorded series into the line artists and rescale axes.

        `batch_idx` is accepted for interface compatibility and not used.
        """
        # Loss curves.
        x_data = np.arange(len(self.total_losses))
        self.loss_line.set_data(x_data, self.total_losses)
        self.bin_line.set_data(x_data, self.bin_losses)
        self.thresh_line.set_data(x_data, self.thresh_losses)
        self.db_line.set_data(x_data, self.db_losses)

        # Auto-scale the loss Y axis with a 10% margin.
        all_losses = self.total_losses + self.bin_losses + self.thresh_losses + self.db_losses
        if all_losses:
            min_loss = min(all_losses) * 0.9
            max_loss = max(all_losses) * 1.1
            self.axs[0, 0].set_ylim(min_loss, max_loss)

        # Learning-rate curve.
        if self.lr_history:
            self.lr_line.set_data(np.arange(len(self.lr_history)), self.lr_history)
            self.axs[0, 1].set_ylim(min(self.lr_history) * 0.9, max(self.lr_history) * 1.1)

        # Validation-metric curves.
        if self.val_metrics['precision']:
            x_epochs = np.arange(len(self.val_metrics['precision']))
            self.precision_line.set_data(x_epochs, self.val_metrics['precision'])
            self.recall_line.set_data(x_epochs, self.val_metrics['recall'])
            self.f1_line.set_data(x_epochs, self.val_metrics['f1'])

            all_metrics = self.val_metrics['precision'] + self.val_metrics['recall'] + self.val_metrics['f1']
            if all_metrics:
                min_metric = min(all_metrics) * 0.9
                max_metric = max(all_metrics) * 1.1
                self.axs[1, 0].set_ylim(min_metric, max_metric)

        # X-axis ranges.
        self.axs[0, 0].set_xlim(0, max(1, len(self.total_losses)))
        self.axs[0, 1].set_xlim(0, max(1, len(self.lr_history)))
        if self.val_metrics['precision']:
            self.axs[1, 0].set_xlim(0, max(1, len(self.val_metrics['precision'])))

    def on_epoch_end(self, epoch, optimizer=None):
        """Print an epoch summary, dump a CSV log and reset the batch series.

        Args:
            epoch: Zero-based epoch index (reported and saved as epoch + 1).
            optimizer: Optional optimizer; when given, the learning rate of
                its first param group is included in the summary.
        """
        # Guard every statistic against empty lists.
        total_min = min(self.total_losses) if self.total_losses else 0.0
        total_max = max(self.total_losses) if self.total_losses else 0.0
        total_avg = np.mean(self.total_losses) if self.total_losses else 0.0
        bin_min = min(self.bin_losses) if self.bin_losses else 0.0
        bin_avg = np.mean(self.bin_losses) if self.bin_losses else 0.0
        thresh_min = min(self.thresh_losses) if self.thresh_losses else 0.0
        thresh_avg = np.mean(self.thresh_losses) if self.thresh_losses else 0.0
        db_min = min(self.db_losses) if self.db_losses else 0.0
        db_avg = np.mean(self.db_losses) if self.db_losses else 0.0

        # Build the loss report.
        report = (
            f"\n{'=' * 70}\n"
            f"EPOCH {epoch + 1} SUMMARY:\n"
            f" - Total Loss: Min={total_min:.6f}, Max={total_max:.6f}, Avg={total_avg:.6f}\n"
            f" - Binary Loss: Min={bin_min:.6f}, Avg={bin_avg:.6f}\n"
            f" - Threshold Loss: Min={thresh_min:.6f}, Avg={thresh_avg:.6f}\n"
            f" - DB Loss: Min={db_min:.6f}, Avg={db_avg:.6f}\n"
        )

        if self.val_metrics['precision']:
            report += (
                f" - Val Metrics: Precision={self.val_metrics['precision'][-1]:.4f}, "
                f"Recall={self.val_metrics['recall'][-1]:.4f}, F1={self.val_metrics['f1'][-1]:.4f}\n"
            )

        if optimizer:
            report += f" - Learning Rate: {optimizer.param_groups[0]['lr']:.6e}\n"

        report += f"{'=' * 70}"
        print(report)

        # Save the per-batch log for this epoch as CSV.
        with open(f'training_log_epoch_{epoch + 1}.csv', 'w') as f:
            f.write("Timestamp,Total_Loss,Bin_Loss,Thresh_Loss,DB_Loss,Learning_Rate\n")
            for i, t in enumerate(self.timestamps):
                # lr_history may be shorter when some batches passed lr=None.
                lr_val = self.lr_history[i] if i < len(self.lr_history) else 0
                f.write(
                    f"{t:.2f},{self.total_losses[i]:.6f},{self.bin_losses[i]:.6f},{self.thresh_losses[i]:.6f},{self.db_losses[i]:.6f},{lr_val:.6e}\n")

        # Reset the series, keeping only the last batch's value.
        self.total_losses = [self.total_losses[-1]] if self.total_losses else []
        self.bin_losses = [self.bin_losses[-1]] if self.bin_losses else []
        self.thresh_losses = [self.thresh_losses[-1]] if self.thresh_losses else []
        self.db_losses = [self.db_losses[-1]] if self.db_losses else []
        self.timestamps = [self.timestamps[-1]] if self.timestamps else []
        self.lr_history = [self.lr_history[-1]] if self.lr_history else []

        # Refresh the figure.
        self.update_plots(0)
        plt.draw()
        plt.pause(0.1)

    def on_train_end(self):
        """Save the live figure and write the detailed report image."""
        plt.ioff()  # leave interactive mode

        # Save the final live figure.
        plt.savefig('training_summary.png')
        plt.close()

        # Produce the detailed report figure.
        self.generate_detailed_report()

    def generate_detailed_report(self):
        """Write a three-panel report figure to 'training_detailed_report.png'."""
        fig, axs = plt.subplots(3, 1, figsize=(12, 15))

        # Loss panel.
        axs[0].plot(self.total_losses, label='Total Loss')
        axs[0].plot(self.bin_losses, label='Binary Loss')
        axs[0].plot(self.thresh_losses, label='Threshold Loss')
        axs[0].plot(self.db_losses, label='DB Loss')
        axs[0].set_title('Training Loss Components')
        axs[0].set_xlabel('Batch')
        axs[0].set_ylabel('Loss')
        axs[0].legend()
        axs[0].grid(True)

        # Learning-rate panel.
        axs[1].plot(self.lr_history)
        axs[1].set_title('Learning Rate Schedule')
        axs[1].set_xlabel('Batch')
        axs[1].set_ylabel('Learning Rate')
        axs[1].grid(True)

        # Validation-metric panel (only when metrics were recorded).
        if self.val_metrics['precision']:
            axs[2].plot(self.val_metrics['precision'], 'o-', label='Precision')
            axs[2].plot(self.val_metrics['recall'], 'o-', label='Recall')
            axs[2].plot(self.val_metrics['f1'], 'o-', label='F1 Score')
            axs[2].set_title('Validation Metrics')
            axs[2].set_xlabel('Epoch')
            axs[2].set_ylabel('Score')
            axs[2].legend()
            axs[2].grid(True)

            # Mark the best F1 score.
            best_f1_idx = np.argmax(self.val_metrics['f1'])
            best_f1 = self.val_metrics['f1'][best_f1_idx]
            axs[2].plot(best_f1_idx, best_f1, 'ro', markersize=8)
            axs[2].annotate(f'Best F1: {best_f1:.4f}',
                            xy=(best_f1_idx, best_f1),
                            xytext=(best_f1_idx + 0.5, best_f1 - 0.05),
                            arrowprops=dict(facecolor='black', shrink=0.05))

        plt.tight_layout()
        plt.savefig('training_detailed_report.png')
        plt.close()


# Module-level function (defined outside the classes so it can be passed to
# albumentations.Lambda).
def suppress_water_meter_glare(img, **kwargs):
    """Glare suppression for water-meter images (extra kwargs are ignored).

    Converts RGB to LAB, applies CLAHE to the lightness channel with a clip
    limit that grows with mean lightness, blends the result into pixels whose
    lightness is at most 100, and converts back to RGB.
    """
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    l_chan, a_chan, b_chan = cv2.split(lab)

    # Dynamic CLAHE parameter: the brighter the image, the larger clipLimit.
    l_mean = np.mean(l_chan)
    clip_limit = 2.0 + (l_mean / 40)

    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8))
    l_clahe = clahe.apply(l_chan)

    # Only the darker region (L <= 100) receives the blended values.
    _, mask = cv2.threshold(l_chan, 100, 255, cv2.THRESH_BINARY_INV)
    blended = cv2.addWeighted(l_chan, 0.7, l_clahe, 0.3, 0)
    l_final = np.where(mask > 0, blended, l_chan)

    lab = cv2.merge((l_final, a_chan, b_chan))
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


# ----------------------------
# 1. Dataset loading and preprocessing (float-coordinate handling)
# ----------------------------
class WaterMeterDataset(Dataset):
    """Water-meter digit-region detection dataset with float polygon coordinates.

    Each item is `(image_tensor, binary_target, threshold_target)`; both
    targets are generated at 1/8 of `input_size`.
    """

    def __init__(self, image_dir, label_dir, input_size=(640, 640), augment=True):
        """
        Args:
            image_dir: Directory containing the images.
            label_dir: Directory with one `<image stem>.json` label per image
                (read via the keys 'shapes', 'shape_type', 'points' and the
                optional 'imageHeight' / 'imageWidth').
            input_size: `(height, width)` the image is resized to.
            augment: Whether to build and apply the augmentation pipeline.
        """
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.input_size = input_size
        self.augment = augment
        self.image_files = [f for f in os.listdir(image_dir)
                            if os.path.isfile(os.path.join(image_dir, f))]

        # Base preprocessing pipeline.
        self.base_transform = Compose([
            Resize(height=input_size[0], width=input_size[1]),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

        # NOTE(review): `water_meter_specific_aug` is not defined in the part
        # of the file visible here; it must exist at module level before a
        # dataset with augment=True is constructed.
        self.augmentation = Compose([
            # Water-meter specific augmentation.
            OneOf([
                # Simulate different shooting angles.
                Perspective(scale=(0.05, 0.1), p=0.3),
                # Simulate glare on the meter glass.
                RandomGamma(gamma_limit=(80, 120), p=0.2),
                # Simulate stains on the meter.
                CoarseDropout(max_holes=5, max_height=20, max_width=20, fill_value=0, p=0.2)
            ], p=0.8),

            # Glare suppression.
            Lambda(name='glare_reduction', image=suppress_water_meter_glare),
            Lambda(name='water_meter_aug', image=water_meter_specific_aug, p=0.7),

            OneOf([
                HorizontalFlip(p=0.3),
                VerticalFlip(p=0.2),
                Rotate(limit=15, p=0.5)
            ], p=0.7),

            OneOf([
                RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                CLAHE(clip_limit=2.0, p=0.3),
                GaussNoise(std_range=(0.15, 0.4),
                           mean_range=(0, 0),
                           per_channel=True,
                           p=0.3),
                ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5))
            ], p=0.7)
        ], p=0.8,
            # remove_invisible=False keeps every keypoint even when a
            # geometric transform pushes it outside the frame. __getitem__
            # regroups keypoints into polygons purely by count, so silently
            # dropped points would shift every following polygon; out-of-frame
            # points are clipped back afterwards instead.
            keypoint_params=KeypointParams(format='xyas', remove_invisible=False)
        ) if augment else None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, optionally augment and encode sample `idx`.

        Unreadable images fall through to the next index; samples whose label
        file contains no polygon are replaced by a random sample; a missing or
        malformed label file yields empty (all-zero) targets.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # Load the image.
        image = cv2.imread(img_path)
        if image is None:
            print(f"错误: 无法读取图像 {img_path}")
            return self[(idx + 1) % len(self)]  # skip the broken image

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Randomly apply CLAHE-based glare suppression.
        # NOTE(review): this runs regardless of self.augment, so samples
        # fetched with augment=False are also randomly altered — confirm
        # that this is intended.
        if np.random.rand() > 0.5:
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l_chan, a_chan, b_chan = cv2.split(lab)
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            l_chan = clahe.apply(l_chan)
            lab = cv2.merge([l_chan, a_chan, b_chan])
            image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        # Parse the annotation.
        base_name = os.path.splitext(img_name)[0]
        label_path = os.path.join(self.label_dir, base_name + '.json')

        try:
            with open(label_path) as f:
                label_data = json.load(f)

            polygons = []
            orig_h, orig_w = image.shape[:2]

            # Image size at annotation time, when recorded in the JSON.
            json_h = label_data.get('imageHeight', orig_h)
            json_w = label_data.get('imageWidth', orig_w)

            # Scale factors for labels made at a different resolution.
            scale_x = orig_w / json_w
            scale_y = orig_h / json_h

            for shape in label_data['shapes']:
                if shape['shape_type'] == 'polygon':
                    # Keep float coordinates; no integer rounding here.
                    poly = np.array(shape['points'], dtype=np.float32)

                    # Apply the scale factors.
                    poly[:, 0] = poly[:, 0] * scale_x
                    poly[:, 1] = poly[:, 1] * scale_y

                    # Clip to the actual image bounds.
                    poly[:, 0] = np.clip(poly[:, 0], 0, orig_w - 1)
                    poly[:, 1] = np.clip(poly[:, 1], 0, orig_h - 1)

                    polygons.append(poly)

            # Validate the annotation before generating targets.
            if len(polygons) == 0:
                print(f"警告: {img_name} 无有效标注,使用随机样本替代")
                return self[np.random.randint(0, len(self))]

            # === Debug visualisation for the first few indices ===
            if idx < 5:
                debug_img = image.copy()
                for poly in polygons:
                    int_poly = poly.astype(np.int32).reshape(-1, 1, 2)
                    cv2.polylines(debug_img, [int_poly], True, (0, 255, 0), 3)

                debug_info = f"Size: {orig_w}x{orig_h} | Polys: {len(polygons)}"
                cv2.putText(debug_img, debug_info, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

                debug_path = f"debug_{base_name}.jpg"
                cv2.imwrite(debug_path, cv2.cvtColor(debug_img, cv2.COLOR_RGB2BGR))
                print(f"保存调试图像: {debug_path}")

            # Flatten polygon vertices into keypoints (x, y, angle, scale).
            keypoints = []
            for poly in polygons:
                for point in poly:
                    # Float precision is preserved.
                    keypoints.append((point[0], point[1], 0, 0))

            if self.augment and self.augmentation:
                poly_lengths = [len(poly) for poly in polygons]

                # Apply the augmentation.
                augmented = self.augmentation(image=image, keypoints=keypoints)
                image = augmented['image']
                keypoints = augmented['keypoints']

                # Regroup the flat keypoint list into polygons.
                polygons = []
                start_idx = 0
                for poly_len in poly_lengths:
                    end_idx = start_idx + poly_len
                    if end_idx <= len(keypoints):
                        poly_points = keypoints[start_idx:end_idx]
                        new_poly = np.array([[p[0], p[1]] for p in poly_points],
                                            dtype=np.float32)
                        polygons.append(new_poly)
                    start_idx = end_idx

            # Clip every polygon to the (possibly augmented) image bounds.
            for poly in polygons:
                poly[:, 0] = np.clip(poly[:, 0], 0, image.shape[1] - 1)
                poly[:, 1] = np.clip(poly[:, 1], 0, image.shape[0] - 1)

        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"警告: 无法加载标注文件 {label_path} - {str(e)}")
            polygons = []

        # Image size after augmentation.
        aug_h, aug_w = image.shape[:2]

        # Base preprocessing (includes Resize).
        processed = self.base_transform(image=image)
        image_tensor = processed['image']

        # Scale polygon coordinates to input_size.
        scale_x = self.input_size[1] / aug_w
        scale_y = self.input_size[0] / aug_h

        scaled_polygons = []
        for poly in polygons:
            scaled_poly = poly.copy()
            scaled_poly[:, 0] = scaled_poly[:, 0] * scale_x
            scaled_poly[:, 1] = scaled_poly[:, 1] * scale_y
            scaled_poly[:, 0] = np.clip(scaled_poly[:, 0], 0, self.input_size[1] - 1)
            scaled_poly[:, 1] = np.clip(scaled_poly[:, 1], 0, self.input_size[0] - 1)
            scaled_polygons.append(scaled_poly)

        # Generate targets from the input_size-scaled polygons.
        binary_target = self.generate_binary_target(
            scaled_polygons, (self.input_size[0], self.input_size[1]))
        threshold_target = self.generate_threshold_target(
            scaled_polygons, (self.input_size[0], self.input_size[1]))

        return image_tensor, binary_target, threshold_target

    def generate_threshold_target(self, polygons, img_shape, ratio=0.4):
        """Build the threshold target map at 1/8 of `input_size`.

        For each polygon, the distance transform of its filled mask is
        normalised by `area * (1 - ratio**2) / perimeter` and clipped to
        [0, 1]; polygons are merged with an element-wise maximum.

        Args:
            polygons: List of (N, 2) float32 arrays in `img_shape` pixels.
                The arrays are clipped in place.
            img_shape: `(height, width)` the polygons are expressed in.
            ratio: Shrink ratio used in the normalisation distance.

        Returns:
            Float tensor of shape `(1, input_size[0] // 8, input_size[1] // 8)`.
        """
        # Output (feature-map) size as (height, width).
        output_size = (self.input_size[0] // 8, self.input_size[1] // 8)

        # Full-resolution distance map.
        full_size_map = np.zeros(img_shape[:2], dtype=np.float32)

        for poly in polygons:
            if len(poly) < 3:
                continue

            # Make sure coordinates are inside the image.
            poly[:, 0] = np.clip(poly[:, 0], 0, img_shape[1] - 1)
            poly[:, 1] = np.clip(poly[:, 1], 0, img_shape[0] - 1)

            # Normalisation distance (guards against division by zero).
            area = cv2.contourArea(poly)
            perimeter = cv2.arcLength(poly, True)
            if perimeter < 1e-3 or area < 10:
                continue

            max_dist = area * (1 - ratio ** 2) / max(perimeter, 1e-3)

            # Rasterise the polygon.
            mask = np.zeros(img_shape[:2], dtype=np.uint8)
            int_poly = poly.reshape((-1, 1, 2)).astype(np.int32)
            cv2.fillPoly(mask, [int_poly], 255)

            # Distance transform, normalised and merged into the full map.
            dist = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
            normalized = np.clip(dist / max(max_dist, 1e-6), 0, 1)
            full_size_map = np.maximum(full_size_map, normalized)

        # Down-sample to the feature-map size. cv2.resize takes dsize as
        # (width, height), so the (height, width) tuple must be swapped;
        # otherwise non-square inputs yield a transposed target.
        dist_map = cv2.resize(full_size_map, (output_size[1], output_size[0]),
                              interpolation=cv2.INTER_LINEAR)

        # Empty-target check.
        if np.max(dist_map) < 1e-6:
            return torch.zeros((1, *output_size), dtype=torch.float32)

        return torch.from_numpy(dist_map).unsqueeze(0).float()

    def generate_binary_target(self, polygons, img_shape):
        """Build the binary (filled-polygon) target map at 1/8 of `input_size`.

        Args:
            polygons: List of (N, 2) float32 arrays in `img_shape` pixels.
            img_shape: `(height, width)` the polygons are expressed in.

        Returns:
            Float tensor of shape `(1, input_size[0] // 8, input_size[1] // 8)`
            holding 1.0 inside any polygon and 0.0 elsewhere.
        """
        # Draw directly at the target size.
        output_size = (self.input_size[0] // 8, self.input_size[1] // 8)
        binary_map = np.zeros(output_size, dtype=np.float32)

        # Scale factors (image -> feature map).
        scale_x = output_size[1] / img_shape[1]
        scale_y = output_size[0] / img_shape[0]

        for poly in polygons:
            if len(poly) > 2:
                # Scale in float, round to integers only for rasterisation.
                scaled_poly = poly.copy()
                scaled_poly[:, 0] = scaled_poly[:, 0] * scale_x
                scaled_poly[:, 1] = scaled_poly[:, 1] * scale_y
                int_poly = scaled_poly.reshape((-1, 1, 2)).astype(np.int32)

                # Fill on a temporary canvas, then merge.
                temp_canvas = np.zeros(output_size, dtype=np.uint8)
                cv2.fillPoly(temp_canvas, [int_poly], 1)
                binary_map = np.maximum(binary_map, temp_canvas.astype(np.float32))

        return torch.from_numpy(binary_map).unsqueeze(0).float()


# ----------------------------
# 2.
# DBNet model definition (enhanced)
# ----------------------------
class DBNet(nn.Module):
    """DBNet text-detection model on a ResNet18 backbone."""

    def __init__(self, pretrained=True):
        """
        Args:
            pretrained: Load the default ImageNet weights for the backbone
                when True; use randomly initialised weights when False.
        """
        super(DBNet, self).__init__()
        # Honour `pretrained`; previously the default weights were always
        # loaded regardless of the argument.
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        base_model = resnet18(weights=weights)

        # Backbone stages.
        self.conv1 = base_model.conv1
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4

        # Feature-fusion layers: 512 -> 256 -> 128 -> 64 channels.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Detection head.
        self.db_head = DBHead(64)

    def forward(self, x):
        """Return `(binary_map, thresh_map)` for an image batch `x`."""
        # Backbone forward pass.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Feature fusion.
        fused = self.fusion_conv(x)

        # Detection head.
        binary_map, thresh_map = self.db_head(fused)

        return binary_map, thresh_map


class DBHead(nn.Module):
    """DBNet detection head with a residual block and spatial attention.

    Both output branches end in two stride-2 transposed convolutions, so the
    output maps are 4x the spatial size of the input feature map.
    """

    def __init__(self, in_channels):
        super(DBHead, self).__init__()

        # Residual block.
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.LeakyReLU(0.2, inplace=True),  # LeakyReLU against vanishing gradients
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels)
        )

        # Spatial attention.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Channel attention. Defined but not applied in forward(); kept so
        # the module's parameter set (state_dict keys) stays unchanged.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, 1),
            nn.Sigmoid()
        )

        # Binarisation (probability-map) branch.
        self.binarize = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 2, in_channels // 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

        # Threshold branch.
        self.thresh = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 2, in_channels // 4, 3, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 8, 4, stride=2, padding=1),
            nn.BatchNorm2d(in_channels // 8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 8, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return `(binary_map, thresh_map)`, each with one channel in [0, 1]."""
        # Residual connection.
        residual = x
        x = self.res_block(x) + residual

        # Spatial attention.
        attn_map = self.spatial_attn(x)
        x = x * attn_map

        binary_map = self.binarize(x)
        thresh_map = self.thresh(x)

        return binary_map, thresh_map


# ----------------------------
# 3. Loss definition (enhanced)
# ----------------------------
class DBLoss(nn.Module):
    """DBNet loss: probability-map, threshold-map and binary-map terms.

    total = prob_loss + alpha * thresh_loss + beta * bin_loss
    """

    def __init__(self, alpha=1.0, beta=10.0, k=50, ohem_ratio=3.0):
        """
        Args:
            alpha: Weight of the threshold-map (L1) loss.
            beta: Weight of the approximate-binary-map (Dice) loss.
            k: Steepness of the differentiable binarisation sigmoid.
            ohem_ratio: Maximum number of hard negatives per positive pixel.
        """
        super(DBLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.k = k
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, targets):
        """Compute the combined loss.

        Args:
            preds: `(binary_pred, thresh_pred)` tensors.
            targets: `(binary_target, thresh_target)` tensors.

        Returns:
            `(total_loss, prob_loss, thresh_loss, bin_loss)`.
        """
        binary_pred, thresh_pred = preds
        binary_target, thresh_target = targets

        # 1. Probability-map loss: per-pixel Dice-style loss with OHEM.
        prob_loss = self.dice_loss_with_ohem(binary_pred, binary_target)

        # 2. Threshold-map loss: L1.
        thresh_loss = F.l1_loss(thresh_pred, thresh_target, reduction='mean')

        # 3. Differentiable binarisation B = 1 / (1 + exp(-k (P - T))).
        # This must stay inside the autograd graph: wrapping it in
        # torch.no_grad() turns bin_loss into a constant, so the
        # beta-weighted term would never update the network.
        binary_map = torch.sigmoid(self.k * (binary_pred - thresh_pred))

        # 4. Binary-map loss: Dice.
        bin_loss = self.dice_loss(binary_map, binary_target)

        # 5. Combined loss.
        total_loss = prob_loss + self.alpha * thresh_loss + self.beta * bin_loss

        return total_loss, prob_loss, thresh_loss, bin_loss

    def dice_loss(self, pred, target):
        """Standard Dice loss over all elements, with smoothing term 1.0."""
        smooth = 1.0
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        return 1 - (2. * intersection + smooth) / (union + smooth)

    def dice_loss_with_ohem(self, pred, target):
        """Per-pixel Dice-style loss averaged over positives and hard negatives.

        All positive pixels are kept; at most `ohem_ratio` times as many
        negatives are kept, chosen by largest loss. Falls back to
        `dice_loss` when no negatives can be selected.
        """
        # Per-pixel loss.
        loss_map = 1 - (2 * pred * target + 1) / (pred + target + 1)

        # OHEM sampling masks.
        pos_mask = (target > 0.5).float()
        neg_mask = 1 - pos_mask

        # Counts must be Python ints: `.sum().item()` is a float, and
        # torch.topk rejects a float k.
        n_pos = int(pos_mask.sum().item())
        n_neg_available = int(neg_mask.sum().item())
        n_neg = min(int(n_pos * self.ohem_ratio), n_neg_available)

        if n_neg == 0:
            return self.dice_loss(pred, target)

        # Hardest negatives.
        neg_loss = (loss_map * neg_mask).reshape(-1)
        topk_neg_loss, _ = torch.topk(neg_loss, n_neg)

        # Combine positive and selected negative losses.
        pos_loss = (loss_map * pos_mask).sum()
        total_loss = (pos_loss + topk_neg_loss.sum()) / (n_pos + n_neg + 1e-6)

        return total_loss


# ----------------------------
# Helper functions
# ----------------------------
def calculate_metrics(pred, target, threshold=0.5):
    """Compute pixel-wise precision, recall and F1.

    Args:
        pred: Predicted probability tensor; binarised with `threshold`.
        target: Target tensor; binarised at 0.5.
        threshold: Decision threshold applied to `pred`.

    Returns:
        `(precision, recall, f1)`; all 0.0 when the target has no positives.
    """
    pred_bin = (pred > threshold).float()
    target_bin = (target > 0.5).float()

    pred_flat = pred_bin.reshape(-1).cpu().numpy()
    target_flat = target_bin.reshape(-1).cpu().numpy()

    # Avoid the all-zero target case.
    if np.sum(target_flat) == 0:
        return 0.0, 0.0, 0.0

    precision = precision_score(target_flat, pred_flat, zero_division=0)
    recall = recall_score(target_flat, pred_flat, zero_division=0)
    f1 = f1_score(target_flat, pred_flat, zero_division=0)

    return precision, recall, f1

# ... (unchanged) ...
def validate_model(model, dataloader, device):
    """Evaluate the model and return batch-averaged ``(precision, recall, f1)``.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this from a
    training loop does not leave the model stuck in eval mode.

    Args:
        model: Network returning ``(binary_preds, thresh_preds)``.
        dataloader: Yields ``(images, binary_targets, thresh_targets)``.
        device: Device the batches are moved to.

    Returns:
        Three floats; all zeros when the dataloader yields no batch.
    """
    was_training = model.training
    model.eval()

    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for images, binary_targets, _ in dataloader:
                images = images.to(device)
                binary_targets = binary_targets.to(device)

                binary_preds, _ = model(images)

                precision, recall, f1 = calculate_metrics(binary_preds, binary_targets)

                total_precision += precision
                total_recall += recall
                total_f1 += f1
                num_batches += 1
    finally:
        model.train(was_training)

    # An empty validation set would otherwise divide by zero
    if num_batches == 0:
        return 0.0, 0.0, 0.0

    avg_precision = total_precision / num_batches
    avg_recall = total_recall / num_batches
    avg_f1 = total_f1 / num_batches

    return avg_precision, avg_recall, avg_f1


# 2. Dynamic loss-weight calibration - extends the DBLoss class
class AdaptiveDBLoss(DBLoss):
    """DBLoss variant that periodically re-derives ``beta`` from recent
    binary-map loss values.

    Args:
        alpha: Forwarded to ``DBLoss``.
        beta: Initial value of the adaptive weight.
        gamma: Forwarded as the third positional argument of ``DBLoss``.
        adapt_step: ``beta`` is re-evaluated whenever the step counter is a
            multiple of this value.
    """

    def __init__(self, alpha=1.0, beta=5.0, gamma=2.0, adapt_step=100):
        # NOTE(review): the third positional parameter of DBLoss is ``k`` (the
        # binarization steepness), so ``gamma`` ends up as ``k`` here.
        # Kept as-is; confirm this is intended.
        super().__init__(alpha, beta, gamma)
        self.adapt_step = adapt_step
        self.beta_history = []  # recent binary-map loss values (floats)
        self._auto_step = 0  # used when the caller does not pass ``step``

    def forward(self, preds, targets, step=None):
        """Compute the loss, adapting ``beta`` on schedule.

        Args:
            preds: ``(binary_pred, thresh_pred)``.
            targets: ``(binary_target, thresh_target)``.
            step: Global step index. Optional so that the criterion can be
                called with the same two-argument form as ``DBLoss``; when
                omitted an internal call counter is used.

        Returns:
            ``(total_loss, bin_loss, thresh_loss, db_loss)`` as returned by
            ``DBLoss.forward``.
        """
        if step is None:
            step = self._auto_step
        self._auto_step += 1

        # Dynamically adjust the beta coefficient
        if step % self.adapt_step == 0 and len(self.beta_history) > 10:
            db_median = float(np.median(self.beta_history[-10:]))
            self.beta = max(1.0, min(db_median * 0.8, 10.0))

        total_loss, bin_loss, thresh_loss, db_loss = super().forward(preds, targets)

        # Record the loss observed under the current beta
        self.beta_history.append(db_loss.item())
        return total_loss, bin_loss, thresh_loss, db_loss

# 3.
# Model architecture enhancement - replaces the original DBHead
class EnhancedDBHead(DBHead):
    """DBHead with a wider residual block, a depthwise-separable feature
    branch and a channel gate in front of the inherited output branches.

    Args:
        in_channels: Number of input feature channels. ``nn.GroupNorm(8, C)``
            requires ``C`` to be divisible by 8.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels)
        # Increase channel capacity
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, 3, padding=1),
            nn.GroupNorm(8, in_channels * 2),
            nn.GELU(),
            nn.Conv2d(in_channels * 2, in_channels, 3, padding=1),
            nn.GroupNorm(8, in_channels)
        )
        # Depthwise-separable feature branch. The 1x1 expansion to 4x channels
        # is projected back to ``in_channels``: the gate below and the
        # residual sum in forward() both operate on ``in_channels`` channels,
        # so leaving the 4x output in place fails with a shape mismatch.
        self.depthwise = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, in_channels * 4, 1),
            nn.GELU(),
            nn.Conv2d(in_channels * 4, in_channels, 1)
        )
        # Self-gating channel attention
        self.gate_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 4, 1),
            nn.GELU(),
            nn.Conv2d(in_channels // 4, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``(binary_map, thresh_map)`` for feature map ``x``."""
        residual = x
        x = self.res_block(x) + residual

        # Depth feature extraction
        depth_feat = self.depthwise(x)

        # Gated feature fusion; the gate has one value per channel and is
        # broadcast over the spatial dimensions
        gate = self.gate_attn(depth_feat)
        x = x * gate + depth_feat

        # NOTE(review): DBHead.forward applies self.res_block again before the
        # attention and output branches, so the residual block runs twice.
        # Kept to preserve the original call structure - confirm intent.
        return super().forward(x)


# ----------------------------
# 4. Training function (enhanced, with progress bar)
# ----------------------------
def _save_training_checkpoint(path, completed_epoch, model, optimizer, scheduler,
                              logger, best_loss, best_f1):
    """Write a resumable checkpoint; ``completed_epoch`` is the index of the
    last fully finished epoch (``-1`` if none)."""
    torch.save({
        'epoch': completed_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'best_f1': best_f1,
        'logger': logger,
        'scheduler_state': scheduler.state_dict()
    }, path)


def enhanced_train_model(model, train_loader, val_loader, criterion, optimizer, device,
                         epochs=200, checkpoint_path='dbnet_checkpoint.pth', lr_init=5e-5):
    """Train ``model`` with mixed precision, validation and checkpointing.

    Args:
        model: Network returning ``(binary_preds, thresh_preds)``.
        train_loader: Yields ``(images, binary_targets, thresh_targets)``.
        val_loader: Loader passed to ``validate_model`` after every epoch.
        criterion: Called as ``criterion((binary_preds, thresh_preds),
            (binary_targets, thresh_targets))``; must return
            ``(total_loss, bin_loss, thresh_loss, db_loss)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device used for the batches and for checkpoint loading.
        epochs: Total number of epochs (including already completed ones
            when resuming).
        checkpoint_path: Resumable checkpoint file; loaded if it exists.
        lr_init: Learning rate written to the first parameter group when
            starting a fresh run (not when resuming).

    Returns:
        The trained model. Exceptions raised during training are printed,
        not propagated; the final weights are always written to
        ``dbnet_final.pth``.
    """
    # Initialization
    start_epoch = 0
    best_loss = float('inf')
    best_f1 = 0.0
    logger = EnhancedTrainingLogger()

    # Learning-rate scheduler: halves the LR when the epoch loss plateaus
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # Mixed precision is only meaningful on CUDA; disabled scaler/autocast
    # are pass-throughs on other devices
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Checkpoint recovery
    resumed = False
    if os.path.exists(checkpoint_path):
        print(f"发现检查点文件 {checkpoint_path}, 尝试恢复训练...")
        # SECURITY NOTE: the checkpoint embeds a pickled logger object, so it
        # must be loaded with full unpickling. Only load checkpoints written
        # by this function.
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        except TypeError:  # torch versions without the weights_only argument
            checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['best_loss']
        best_f1 = checkpoint.get('best_f1', best_f1)
        logger = checkpoint['logger']
        resumed = True
        print(f"成功恢复训练状态: 从第 {start_epoch} 轮开始, 最佳损失: {best_loss:.6f}")
        if not logger.total_losses:  # restored log is empty
            logger = EnhancedTrainingLogger()  # start a fresh logger

    # Only a fresh run gets lr_init; on resume the restored optimizer state
    # already carries the (possibly reduced) learning rate.
    if not resumed:
        optimizer.param_groups[0]['lr'] = lr_init

    try:
        for epoch in range(start_epoch, epochs):
            # validate_model() runs in eval mode, so training mode has to be
            # re-asserted at the start of every epoch.
            model.train()

            epoch_total_loss = 0.0
            epoch_bin_loss = 0.0
            epoch_thresh_loss = 0.0
            epoch_db_loss = 0.0
            epoch_start = time.time()

            # Progress bar over the batches
            pbar = tqdm(enumerate(train_loader), total=len(train_loader),
                        desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")

            for batch_idx, (images, binary_targets, thresh_targets) in pbar:
                images = images.to(device)
                binary_targets = binary_targets.to(device)
                thresh_targets = thresh_targets.to(device)

                # Mixed-precision forward pass
                with torch.cuda.amp.autocast(enabled=use_amp):
                    binary_preds, thresh_preds = model(images)
                    total_loss, bin_loss, thresh_loss, db_loss = criterion(
                        (binary_preds, thresh_preds),
                        (binary_targets, thresh_targets)
                    )

                # Read each scalar once; every .item() synchronizes the device
                total_value = total_loss.item()
                bin_value = bin_loss.item()
                thresh_value = thresh_loss.item()
                db_value = db_loss.item()

                # Accumulate losses
                epoch_total_loss += total_value
                epoch_bin_loss += bin_value
                epoch_thresh_loss += thresh_value
                epoch_db_loss += db_value

                # Log
                current_lr = optimizer.param_groups[0]['lr']
                logger.on_batch_end(
                    batch_idx, total_value, bin_value,
                    thresh_value, db_value, current_lr
                )

                # Update the progress-bar description
                pbar.set_postfix({
                    'Loss': f"{total_value:.4f}",
                    'Bin': f"{bin_value:.4f}",
                    'Thresh': f"{thresh_value:.4f}",
                    'DB': f"{db_value:.4f}",
                    'LR': f"{current_lr:.2e}"
                })

                # Backward pass
                optimizer.zero_grad()
                scaler.scale(total_loss).backward()

                # Gradients are still multiplied by the loss scale here;
                # unscale them so max_norm applies to the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                scaler.step(optimizer)
                scaler.update()

                # Emergency checkpoint every 100 batches. The epoch is not
                # finished yet, so the previous epoch is recorded as the last
                # completed one and a resume restarts the current epoch.
                if batch_idx > 0 and batch_idx % 100 == 0:
                    _save_training_checkpoint(checkpoint_path, epoch - 1, model, optimizer,
                                              scheduler, logger, best_loss, best_f1)

            # Average losses (guard against an empty loader)
            num_batches = max(len(train_loader), 1)
            avg_total_loss = epoch_total_loss / num_batches
            avg_bin_loss = epoch_bin_loss / num_batches
            avg_thresh_loss = epoch_thresh_loss / num_batches
            avg_db_loss = epoch_db_loss / num_batches

            # Validate
            precision, recall, f1 = validate_model(model, val_loader, device)
            logger.val_metrics['precision'].append(precision)
            logger.val_metrics['recall'].append(recall)
            logger.val_metrics['f1'].append(f1)

            # ReduceLROnPlateau expects the monitored metric, once per epoch
            scheduler.step(avg_total_loss)

            epoch_time = time.time() - epoch_start
            print(f"Epoch [{epoch + 1}/{epochs}] completed in {epoch_time:.2f}s")
            print(
                f" - Avg Loss: {avg_total_loss:.6f} (Bin:{avg_bin_loss:.6f}, Thresh:{avg_thresh_loss:.6f}, DB:{avg_db_loss:.6f})")
            print(f" - Val Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")

            logger.on_epoch_end(epoch, optimizer)

            # Save the best model
            if f1 > best_f1 or (f1 == best_f1 and avg_total_loss < best_loss):
                best_f1 = f1
                best_loss = avg_total_loss
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_total_loss,
                    'f1': best_f1
                }, 'dbnet_best.pth')
                print(f"🔥 发现新的最佳模型! F1: {best_f1:.4f}, 损失: {best_loss:.6f}")

            # Regular checkpoint: this epoch is now complete, so a resume
            # continues with epoch + 1
            _save_training_checkpoint(checkpoint_path, epoch, model, optimizer,
                                      scheduler, logger, best_loss, best_f1)

    except KeyboardInterrupt:
        print("\n训练被用户中断!")
    except Exception as e:
        print(f"\n❌ 训练中断! 原因: {str(e)}")
        traceback.print_exc()
    finally:
        print("训练完成! 保存最终模型...")
        torch.save(model.state_dict(), 'dbnet_final.pth')
        logger.on_train_end()

    return model


# ----------------------------
# 5.
# Inference and region cropping (enhanced)
# ----------------------------
def enhanced_detect_text_regions(image, model, device, threshold=0.3, input_size=(640, 640)):
    """Run the detector on one image and return text-region contours.

    Args:
        image: BGR image array (it is converted with ``COLOR_BGR2RGB``).
        model: Network returning ``(binary_map, thresh_map)``.
        device: Device the input tensor is moved to.
        threshold: Cut-off applied to the predicted binary map.
        input_size: ``(width, height)`` the image is resized to before
            inference.

    Returns:
        ``(contours, orig_h, orig_w)``. Contour coordinates are in the
        coordinate system of the predicted map, not of the original image.
    """
    # Pre-processing
    orig_h, orig_w = image.shape[:2]
    input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    input_img = cv2.resize(input_img, tuple(input_size))
    input_img = input_img.astype(np.float32) / 255.0
    # float32 constants keep the normalized image in float32
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    input_img = (input_img - mean) / std
    input_tensor = torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).to(device)
    input_tensor = input_tensor.to(torch.float32)

    # Inference
    with torch.no_grad():
        binary_map, _ = model(input_tensor)

    # Post-processing
    binary_map = binary_map.squeeze().cpu().numpy()
    binary_output = (binary_map > threshold).astype(np.uint8) * 255

    # Morphological closing merges nearby fragments
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binary_output = cv2.morphologyEx(binary_output, cv2.MORPH_CLOSE, kernel)

    # findContours returns 2 values in OpenCV 4.x and 3 in 3.x; the contours
    # are the second-to-last element in both.
    contours = cv2.findContours(binary_output, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    return contours, orig_h, orig_w  # detected text-region contours

# ... (unchanged) ...


def perspective_transform(image, contour):
    """Warp the quadrilateral described by ``contour`` to an upright rectangle.

    The contour is approximated by a polygon; when that polygon does not have
    exactly four vertices the minimum-area bounding rectangle is used instead.

    Args:
        image: Source image.
        contour: Contour in the coordinate system of ``image``.

    Returns:
        The perspective-corrected crop.
    """
    # Polygon approximation of the contour
    epsilon = 0.02 * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)

    # Make sure we work with a quadrilateral
    if len(approx) != 4:
        # Fall back to the minimum-area rectangle
        rect = cv2.minAreaRect(contour)
        box = cv2.boxPoints(rect)
        # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased
        approx = box.astype(np.intp)

    # Order the vertices (top-left, top-right, bottom-right, bottom-left)
    pts = approx.reshape(4, 2)
    rect_pts = np.zeros((4, 2), dtype="float32")

    # x + y is smallest at the top-left and largest at the bottom-right
    s = pts.sum(axis=1)
    rect_pts[0] = pts[np.argmin(s)]  # top-left
    rect_pts[2] = pts[np.argmax(s)]  # bottom-right

    # y - x is smallest at the top-right and largest at the bottom-left
    diff = np.diff(pts, axis=1)
    rect_pts[1] = pts[np.argmin(diff)]  # top-right
    rect_pts[3] = pts[np.argmax(diff)]  # bottom-left

    # Size of the target rectangle
    (tl, tr, br, bl) = rect_pts
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    maxWidth = max(int(widthA), int(widthB))

    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    maxHeight = max(int(heightA), int(heightB))

    # Destination corner coordinates
    dst = np.array([
        [0, 0],
        [maxWidth - 1, 0],
        [maxWidth - 1, maxHeight - 1],
        [0, maxHeight - 1]], dtype="float32")

    # Compute and apply the perspective transform
    M = cv2.getPerspectiveTransform(rect_pts, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))

    return warped


def crop_text_regions(image, contours, orig_h, orig_w, map_size=(640, 640)):
    """Crop the detected regions from ``image`` and perspective-correct them.

    Args:
        image: Original image the contours refer to after rescaling.
        contours: Contours in detection-map coordinates.
        orig_h: Height of ``image``.
        orig_w: Width of ``image``.
        map_size: ``(width, height)`` of the map the contours were found on.
            Assumes the predicted map has the same size as the network input
            - TODO confirm against the model.

    Returns:
        List of cropped image arrays. Contours with area below 100 are
        skipped; when the perspective correction fails the plain bounding-box
        crop is used instead.
    """
    cropped_regions = []

    # Scale factors from the detection map to the original image
    scale_x = orig_w / float(map_size[0])
    scale_y = orig_h / float(map_size[1])

    for contour in contours:
        # Filter out small regions
        if cv2.contourArea(contour) < 100:
            continue

        # Scale the contour to the original image size
        scaled_contour = contour.copy()
        scaled_contour[:, :, 0] = scaled_contour[:, :, 0] * scale_x
        scaled_contour[:, :, 1] = scaled_contour[:, :, 1] * scale_y

        # Bounding box of the contour
        x, y, w, h = cv2.boundingRect(scaled_contour)

        # Expand the bounding box by a 10% margin, clipped to the image
        margin_x = int(w * 0.1)
        margin_y = int(h * 0.1)
        x = max(0, x - margin_x)
        y = max(0, y - margin_y)
        w = min(orig_w - x, w + 2 * margin_x)
        h = min(orig_h - y, h + 2 * margin_y)

        # Crop
        roi = image[y:y + h, x:x + w]

        # Perspective-correct the crop
        try:
            # Move the contour into the ROI coordinate system
            roi_contour = scaled_contour.copy()
            roi_contour[:, :, 0] -= x
            roi_contour[:, :, 1] -= y

            warped_roi = perspective_transform(roi, roi_contour)

            # Enforce a minimum size
            if warped_roi.shape[0] > 10 and warped_roi.shape[1] > 10:
                cropped_regions.append(warped_roi)
        except Exception as e:
            # Best effort: fall back to the plain ROI, but never return an
            # empty array (it cannot be written to disk later)
            print(f"透视变换失败: {str(e)},使用原始ROI")
            if roi.size > 0:
                cropped_regions.append(roi)

    return cropped_regions


# ----------------------------
# 7. Model loading and inference interface (new)
# ----------------------------
def load_trained_model(model_path, device='cuda'):
    """Load trained weights into a fresh ``DBNet`` and return it in eval mode.

    Accepts both checkpoint dictionaries that store the weights under
    ``'model_state_dict'`` and files that contain a bare state dict.
    """
    model = DBNet(pretrained=False).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model

# 5.
# Water-meter image enhancement improvements
def water_meter_specific_aug(image, **kwargs):
    """Water-meter specific enhancement chain.

    Blurs the image, equalizes the lightness channel with CLAHE and stretches
    both chroma channels to the full 0-255 range.

    Args:
        image: RGB image array (converted with ``COLOR_RGB2LAB``); presumably
            uint8 - verify against the caller.
        **kwargs: Ignored; accepted so the function can be used as a
            transform callback.

    Returns:
        The enhanced RGB image.
    """
    # Suppress high-frequency glare with a blur whose kernel is about 1% of
    # the shorter image side; GaussianBlur needs an odd kernel size
    kernel_size = int(min(image.shape[:2]) * 0.01)
    if kernel_size % 2 == 0:
        kernel_size += 1
    blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)

    # Adaptive histogram equalization on the lightness channel
    lab = cv2.cvtColor(blurred, cv2.COLOR_RGB2LAB)
    lightness, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_eq = clahe.apply(lightness)

    # Color-cast correction: min-max normalize both chroma channels
    a_balanced = cv2.normalize(a, None, 0, 255, cv2.NORM_MINMAX)
    b_balanced = cv2.normalize(b, None, 0, 255, cv2.NORM_MINMAX)

    return cv2.cvtColor(cv2.merge([l_eq, a_balanced, b_balanced]), cv2.COLOR_LAB2RGB)


def suppress_glare(image):
    """Reduce the influence of glare by CLAHE-equalizing the lightness channel.

    Args:
        image: RGB image array (converted with ``COLOR_RGB2LAB``).

    Returns:
        The processed RGB image.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    lightness, a, b = cv2.split(lab)

    # CLAHE on the lightness channel only
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_clahe = clahe.apply(lightness)

    # Merge the channels back
    lab_clahe = cv2.merge((l_clahe, a, b))
    return cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)


def detect_and_crop(image_path, model, device='cuda', output_dir='cropped_regions'):
    """Detect and crop the digit regions of one water-meter image.

    The crops are written to ``output_dir`` as
    ``<image name>_region_<index>.jpg``.

    Args:
        image_path: Path of the image to process.
        model: Trained detector.
        device: Device used for inference.
        output_dir: Directory for the crops; created if missing.

    Returns:
        List of cropped regions; empty when the image cannot be read.
    """
    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Read the image
    image = cv2.imread(image_path)
    if image is None:
        print(f"错误: 无法读取图像 {image_path}")
        return []

    # cv2.imread yields BGR while suppress_glare converts with RGB codes, so
    # convert around the call; the rest of the pipeline (detection, imwrite)
    # expects BGR.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.cvtColor(suppress_glare(rgb_image), cv2.COLOR_RGB2BGR)

    # Detect text regions
    contours, orig_h, orig_w = enhanced_detect_text_regions(image, model, device)

    # Crop text regions
    cropped_regions = crop_text_regions(image, contours, orig_h, orig_w)

    # Save the results
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    for i, region in enumerate(cropped_regions):
        output_path = os.path.join(output_dir, f'{base_name}_region_{i}.jpg')
        cv2.imwrite(output_path, region)

    print(f"成功裁剪 {len(cropped_regions)} 个文本区域到 {output_dir}")
    return cropped_regions


# ----------------------------
# 8.
# Main program (optimized)
# ----------------------------
if __name__ == "__main__":
    # Tuned parameters
    INPUT_SIZE = (512, 512)  # smaller input size suited to water meters

    # Configuration
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    DATA_DIR = 'images_train'
    LABEL_DIR = 'labels_train'
    VAL_DATA_DIR = 'images_val'
    VAL_LABEL_DIR = 'labels_val'
    BATCH_SIZE = 16
    EPOCHS = 100
    LR = 1e-4
    CHECKPOINT_PATH = 'dbnet_checkpoint.pth'
    TRAINED_MODEL_PATH = 'dbnet_best.pth'

    # Mode selection: 'train' or 'inference'
    MODE = 'train'

    if MODE == 'train':
        # 1. Prepare the datasets
        print("准备训练数据集...")
        train_dataset = WaterMeterDataset(
            image_dir=DATA_DIR,
            label_dir=LABEL_DIR,
            input_size=INPUT_SIZE,
            augment=True
        )

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  shuffle=True, num_workers=4, pin_memory=True)

        print("准备验证数据集...")
        val_dataset = WaterMeterDataset(
            image_dir=VAL_DATA_DIR,
            label_dir=VAL_LABEL_DIR,
            input_size=INPUT_SIZE,
            augment=False
        )
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                shuffle=False, num_workers=2)

        # 2. Initialize the model
        print("初始化模型...")
        model = DBNet(pretrained=True).to(DEVICE)

        # 3. Loss function and optimizer (the fixed-weight DBLoss is used)
        criterion = DBLoss(alpha=1.0, beta=8.0)

        # LR is the single source of truth for the learning rate: the trainer
        # writes lr_init into the optimizer when a fresh run starts, so a
        # different value here would be overwritten anyway.
        optimizer = optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=1e-4
        )

        # 4. Train the model
        print("开始训练...")
        model = enhanced_train_model(
            model, train_loader, val_loader, criterion, optimizer, DEVICE,
            epochs=EPOCHS, checkpoint_path=CHECKPOINT_PATH, lr_init=LR
        )
        print(f"✅ 训练完成! 最佳模型已保存到 {TRAINED_MODEL_PATH}")

    elif MODE == 'inference':
        # Load the trained model
        print(f"加载训练好的模型: {TRAINED_MODEL_PATH}")
        model = load_trained_model(TRAINED_MODEL_PATH, DEVICE)

        # Process a single image
        test_image_path = 'test_images/test_1.jpg'
        print(f"处理测试图像: {test_image_path}")
        detect_and_crop(test_image_path, model, DEVICE)

        # Process a whole directory
        input_dir = 'test_images'
        output_dir = 'cropped_results'
        print(f"批量处理目录: {input_dir}")

        if not os.path.isdir(input_dir):
            # os.listdir would raise on a missing directory
            print(f"错误: 目录不存在 {input_dir}")
        else:
            # Sorted for a reproducible processing order
            for img_file in sorted(os.listdir(input_dir)):
                if img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                    img_path = os.path.join(input_dir, img_file)
                    print(f"处理图像: {img_file}")
                    detect_and_crop(img_path, model, DEVICE, output_dir)
最新发布
06-07
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值