import traceback
import cv2
import json
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from albumentations import (
Compose, Resize, Normalize, HorizontalFlip, VerticalFlip, Rotate, OneOf,
RandomBrightnessContrast, GaussNoise, ElasticTransform, RandomGamma, HueSaturationValue, CoarseDropout, Perspective,
KeypointParams,
CLAHE, MotionBlur, ISONoise,Lambda
)
from albumentations.pytorch import ToTensorV2
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, StepLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from sklearn.metrics import precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from tqdm import tqdm # 添加进度条库
class EnhancedTrainingLogger:
    """Training logger that records several loss components and plots them live.

    Per-batch values (total / binary / threshold / DB loss, learning rate and
    elapsed seconds) are stored in plain lists; validation precision, recall
    and F1 are stored per epoch in ``val_metrics``.  A 2x2 matplotlib figure
    is created in interactive mode and refreshed as new values arrive.
    """

    def __init__(self):
        self.total_losses = []
        self.bin_losses = []
        self.thresh_losses = []
        self.db_losses = []
        self.timestamps = []
        self.start_time = time.time()
        self.lr_history = []
        self.val_metrics = {'precision': [], 'recall': [], 'f1': []}

        # Live visualisation: interactive mode keeps plt.pause() non-blocking.
        plt.ion()
        self.fig, self.axs = plt.subplots(2, 2, figsize=(15, 10))
        self.fig.suptitle('Training Progress', fontsize=16)

        # Top-left panel: the four loss curves.
        self.loss_line, = self.axs[0, 0].plot([], [], 'r-', label='Total Loss')
        self.bin_line, = self.axs[0, 0].plot([], [], 'g-', label='Binary Loss')
        self.thresh_line, = self.axs[0, 0].plot([], [], 'b-', label='Threshold Loss')
        self.db_line, = self.axs[0, 0].plot([], [], 'm-', label='DB Loss')
        self.axs[0, 0].set_title('Training Loss Components')
        self.axs[0, 0].set_xlabel('Batch')
        self.axs[0, 0].set_ylabel('Loss')
        self.axs[0, 0].legend()
        self.axs[0, 0].grid(True)

        # Top-right panel: learning rate.
        self.lr_line, = self.axs[0, 1].plot([], [], 'c-')
        self.axs[0, 1].set_title('Learning Rate Schedule')
        self.axs[0, 1].set_xlabel('Batch')
        self.axs[0, 1].set_ylabel('Learning Rate')
        self.axs[0, 1].grid(True)

        # Bottom-left panel: validation metrics per epoch.
        self.precision_line, = self.axs[1, 0].plot([], [], 'r-', label='Precision')
        self.recall_line, = self.axs[1, 0].plot([], [], 'g-', label='Recall')
        self.f1_line, = self.axs[1, 0].plot([], [], 'b-', label='F1 Score')
        self.axs[1, 0].set_title('Validation Metrics')
        self.axs[1, 0].set_xlabel('Epoch')
        self.axs[1, 0].set_ylabel('Score')
        self.axs[1, 0].legend()
        self.axs[1, 0].grid(True)

        # Bottom-right panel: text box showing the current numbers.
        self.metrics_text = self.axs[1, 1].text(0.5, 0.5, "",
                                                horizontalalignment='center',
                                                verticalalignment='center',
                                                transform=self.axs[1, 1].transAxes,
                                                fontsize=12)
        self.axs[1, 1].axis('off')

        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.draw()
        plt.pause(0.1)

    @staticmethod
    def _padded_limits(values):
        """Return (low, high) axis limits with ~10% padding around ``values``.

        Padding always widens the range (also for negative values) and a
        constant series still gets a non-empty range.
        """
        low, high = min(values), max(values)
        low = low * 0.9 if low >= 0 else low * 1.1
        high = high * 1.1 if high >= 0 else high * 0.9
        if low == high:
            low, high = low - 0.5, high + 0.5
        return low, high

    def on_batch_end(self, batch_idx, total_loss, bin_loss, thresh_loss, db_loss, lr=None):
        """Record one batch and refresh the live figure.

        Args:
            batch_idx: index of the batch inside the current epoch.
            total_loss, bin_loss, thresh_loss, db_loss: scalar loss values.
            lr: current learning rate, or None when it is not tracked.
        """
        elapsed = time.time() - self.start_time
        self.total_losses.append(total_loss)
        self.bin_losses.append(bin_loss)
        self.thresh_losses.append(thresh_loss)
        self.db_losses.append(db_loss)
        self.timestamps.append(elapsed)
        if lr is not None:
            self.lr_history.append(lr)

        self.update_plots(batch_idx)

        # Only every 10th batch refreshes the text panel and redraws.
        if batch_idx % 10 == 0:
            avg_total = np.mean(self.total_losses[-10:]) if len(self.total_losses) >= 10 else total_loss
            avg_bin = np.mean(self.bin_losses[-10:]) if len(self.bin_losses) >= 10 else bin_loss
            avg_thresh = np.mean(self.thresh_losses[-10:]) if len(self.thresh_losses) >= 10 else thresh_loss
            avg_db = np.mean(self.db_losses[-10:]) if len(self.db_losses) >= 10 else db_loss
            # lr is optional; formatting None with ':.2e' would raise TypeError.
            lr_text = f"{lr:.2e}" if lr is not None else "n/a"
            metrics_text = (
                f"Batch: {batch_idx}\n"
                f"Total Loss: {total_loss:.4f} (Avg10: {avg_total:.4f})\n"
                f"Binary Loss: {bin_loss:.4f} (Avg10: {avg_bin:.4f})\n"
                f"Threshold Loss: {thresh_loss:.4f} (Avg10: {avg_thresh:.4f})\n"
                f"DB Loss: {db_loss:.4f} (Avg10: {avg_db:.4f})\n"
                f"Learning Rate: {lr_text}\n"
                f"Time: {int(elapsed // 3600):02d}:{int((elapsed % 3600) // 60):02d}:{int(elapsed % 60):02d}"
            )
            self.metrics_text.set_text(metrics_text)
            plt.draw()
            plt.pause(0.01)

    def update_plots(self, batch_idx):
        """Push the stored series into the line artists and rescale the axes.

        ``batch_idx`` is accepted for interface compatibility and is not used.
        """
        x_data = np.arange(len(self.total_losses))
        self.loss_line.set_data(x_data, self.total_losses)
        self.bin_line.set_data(x_data, self.bin_losses)
        self.thresh_line.set_data(x_data, self.thresh_losses)
        self.db_line.set_data(x_data, self.db_losses)

        all_losses = self.total_losses + self.bin_losses + self.thresh_losses + self.db_losses
        if all_losses:
            self.axs[0, 0].set_ylim(*self._padded_limits(all_losses))

        if self.lr_history:
            self.lr_line.set_data(np.arange(len(self.lr_history)), self.lr_history)
            self.axs[0, 1].set_ylim(*self._padded_limits(self.lr_history))

        if self.val_metrics['precision']:
            x_epochs = np.arange(len(self.val_metrics['precision']))
            self.precision_line.set_data(x_epochs, self.val_metrics['precision'])
            self.recall_line.set_data(x_epochs, self.val_metrics['recall'])
            self.f1_line.set_data(x_epochs, self.val_metrics['f1'])
            all_metrics = self.val_metrics['precision'] + self.val_metrics['recall'] + self.val_metrics['f1']
            if all_metrics:
                self.axs[1, 0].set_ylim(*self._padded_limits(all_metrics))

        self.axs[0, 0].set_xlim(0, max(1, len(self.total_losses)))
        self.axs[0, 1].set_xlim(0, max(1, len(self.lr_history)))
        if self.val_metrics['precision']:
            self.axs[1, 0].set_xlim(0, max(1, len(self.val_metrics['precision'])))

    def on_epoch_end(self, epoch, optimizer=None):
        """Print an epoch summary, dump the batch log to CSV and reset the series.

        Args:
            epoch: zero-based epoch index (reported and saved as ``epoch + 1``).
            optimizer: optional optimizer whose first param group's learning
                rate is added to the printed report.
        """
        # Empty-list guards so a summary can be printed even without batches.
        total_min = min(self.total_losses) if self.total_losses else 0.0
        total_max = max(self.total_losses) if self.total_losses else 0.0
        total_avg = np.mean(self.total_losses) if self.total_losses else 0.0
        bin_min = min(self.bin_losses) if self.bin_losses else 0.0
        bin_avg = np.mean(self.bin_losses) if self.bin_losses else 0.0
        thresh_min = min(self.thresh_losses) if self.thresh_losses else 0.0
        thresh_avg = np.mean(self.thresh_losses) if self.thresh_losses else 0.0
        db_min = min(self.db_losses) if self.db_losses else 0.0
        db_avg = np.mean(self.db_losses) if self.db_losses else 0.0

        report = (
            f"\n{'=' * 70}\n"
            f"EPOCH {epoch + 1} SUMMARY:\n"
            f" - Total Loss: Min={total_min:.6f}, Max={total_max:.6f}, Avg={total_avg:.6f}\n"
            f" - Binary Loss: Min={bin_min:.6f}, Avg={bin_avg:.6f}\n"
            f" - Threshold Loss: Min={thresh_min:.6f}, Avg={thresh_avg:.6f}\n"
            f" - DB Loss: Min={db_min:.6f}, Avg={db_avg:.6f}\n"
        )
        if self.val_metrics['precision']:
            report += (
                f" - Val Metrics: Precision={self.val_metrics['precision'][-1]:.4f}, "
                f"Recall={self.val_metrics['recall'][-1]:.4f}, F1={self.val_metrics['f1'][-1]:.4f}\n"
            )
        if optimizer:
            report += f" - Learning Rate: {optimizer.param_groups[0]['lr']:.6e}\n"
        report += f"{'=' * 70}"
        print(report)

        # One CSV per epoch; a missing learning-rate entry is written as 0.
        with open(f'training_log_epoch_{epoch + 1}.csv', 'w') as f:
            f.write("Timestamp,Total_Loss,Bin_Loss,Thresh_Loss,DB_Loss,Learning_Rate\n")
            for i, t in enumerate(self.timestamps):
                lr_val = self.lr_history[i] if i < len(self.lr_history) else 0
                f.write(
                    f"{t:.2f},{self.total_losses[i]:.6f},{self.bin_losses[i]:.6f},"
                    f"{self.thresh_losses[i]:.6f},{self.db_losses[i]:.6f},{lr_val:.6e}\n")

        # Reset the per-batch series, keeping only the last value of each.
        self.total_losses = self.total_losses[-1:]
        self.bin_losses = self.bin_losses[-1:]
        self.thresh_losses = self.thresh_losses[-1:]
        self.db_losses = self.db_losses[-1:]
        self.timestamps = self.timestamps[-1:]
        self.lr_history = self.lr_history[-1:]

        self.update_plots(0)
        plt.draw()
        plt.pause(0.1)

    def on_train_end(self):
        """Save the live figure, close it and write the detailed report."""
        plt.ioff()
        # Save this logger's own figure rather than whatever pyplot considers
        # "current", which may be a different figure.
        self.fig.savefig('training_summary.png')
        plt.close(self.fig)
        self.generate_detailed_report()

    def generate_detailed_report(self):
        """Write ``training_detailed_report.png`` from the currently stored series.

        Note that ``on_epoch_end`` trims the per-batch series, so the loss and
        learning-rate panels only contain what has been recorded since then.
        """
        fig, axs = plt.subplots(3, 1, figsize=(12, 15))

        axs[0].plot(self.total_losses, label='Total Loss')
        axs[0].plot(self.bin_losses, label='Binary Loss')
        axs[0].plot(self.thresh_losses, label='Threshold Loss')
        axs[0].plot(self.db_losses, label='DB Loss')
        axs[0].set_title('Training Loss Components')
        axs[0].set_xlabel('Batch')
        axs[0].set_ylabel('Loss')
        axs[0].legend()
        axs[0].grid(True)

        axs[1].plot(self.lr_history)
        axs[1].set_title('Learning Rate Schedule')
        axs[1].set_xlabel('Batch')
        axs[1].set_ylabel('Learning Rate')
        axs[1].grid(True)

        if self.val_metrics['precision']:
            axs[2].plot(self.val_metrics['precision'], 'o-', label='Precision')
            axs[2].plot(self.val_metrics['recall'], 'o-', label='Recall')
            axs[2].plot(self.val_metrics['f1'], 'o-', label='F1 Score')
            axs[2].set_title('Validation Metrics')
            axs[2].set_xlabel('Epoch')
            axs[2].set_ylabel('Score')
            axs[2].legend()
            axs[2].grid(True)
            # Highlight the epoch with the best F1 score.
            best_f1_idx = int(np.argmax(self.val_metrics['f1']))
            best_f1 = self.val_metrics['f1'][best_f1_idx]
            axs[2].plot(best_f1_idx, best_f1, 'ro', markersize=8)
            axs[2].annotate(f'Best F1: {best_f1:.4f}',
                            xy=(best_f1_idx, best_f1),
                            xytext=(best_f1_idx + 0.5, best_f1 - 0.05),
                            arrowprops=dict(facecolor='black', shrink=0.05))

        fig.tight_layout()
        fig.savefig('training_detailed_report.png')
        plt.close(fig)
# Module-level helper functions (defined outside the classes).
def suppress_water_meter_glare(img, **kwargs):
    """Glare suppression for water-meter images (extra keyword arguments are ignored).

    The image is converted RGB -> LAB, the L channel is equalised with CLAHE
    whose clip limit grows with the mean L value, and a 70/30 blend of the
    original and equalised L is written only where the original L is at most
    100.  The result is converted back to RGB.
    """
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))

    # Brighter images get a larger CLAHE clip limit.
    adaptive_clip = 2.0 + (np.mean(lightness) / 40)
    equalizer = cv2.createCLAHE(clipLimit=adaptive_clip, tileGridSize=(8, 8))
    equalized = equalizer.apply(lightness)

    # Inverse binary threshold: the mask is non-zero only for the darker pixels.
    dark_mask = cv2.threshold(lightness, 100, 255, cv2.THRESH_BINARY_INV)[1]
    mixed = cv2.addWeighted(lightness, 0.7, equalized, 0.3, 0)
    merged_lightness = np.where(dark_mask > 0, mixed, lightness)

    return cv2.cvtColor(cv2.merge((merged_lightness, chroma_a, chroma_b)), cv2.COLOR_LAB2RGB)
# ----------------------------
# 1. 数据集加载与预处理 (优化浮点坐标处理)
# ----------------------------
class WaterMeterDataset(Dataset):
    """Dataset for water-meter digit-region detection.

    Images are read from ``image_dir``; polygon annotations are read from
    ``<label_dir>/<image stem>.json`` (``shapes`` entries whose ``shape_type``
    is ``'polygon'``).  Each item is ``(image_tensor, binary_target,
    threshold_target)`` where both targets have shape
    ``(1, input_size[0] // 8, input_size[1] // 8)``.

    Args:
        image_dir: directory containing the images.
        label_dir: directory containing the JSON annotation files.
        input_size: ``(height, width)`` the image is resized to.
        augment: whether to build and apply the augmentation pipeline.
    """

    def __init__(self, image_dir, label_dir, input_size=(640, 640), augment=True):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.input_size = input_size
        self.augment = augment
        self.image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]

        # Base preprocessing applied to every sample.
        self.base_transform = Compose([
            Resize(height=input_size[0], width=input_size[1]),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

        # Augmentation pipeline; polygon vertices travel through it as keypoints.
        self.augmentation = Compose([
            OneOf([
                # Different shooting angles.
                Perspective(scale=(0.05, 0.1), p=0.3),
                # Glass reflections.
                RandomGamma(gamma_limit=(80, 120), p=0.2),
                # Stains on the meter.
                CoarseDropout(max_holes=5, max_height=20, max_width=20,
                              fill_value=0, p=0.2)
            ], p=0.8),
            Lambda(name='glare_reduction', image=suppress_water_meter_glare),
            Lambda(name='water_meter_aug', image=water_meter_specific_aug, p=0.7),
            OneOf([
                HorizontalFlip(p=0.3),
                VerticalFlip(p=0.2),
                Rotate(limit=15, p=0.5)
            ], p=0.7),
            OneOf([
                RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                CLAHE(clip_limit=2.0, p=0.3),
                GaussNoise(std_range=(0.15, 0.4),
                           mean_range=(0, 0),
                           per_channel=True,
                           p=0.3),
                ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5))
            ], p=0.7)
            # remove_invisible=False: vertices pushed outside the frame must be
            # kept (they are clipped afterwards); dropping them would shift every
            # following vertex into the wrong polygon when regrouping.
        ], p=0.8, keypoint_params=KeypointParams(format='xyas', remove_invisible=False)) if augment else None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, optionally augment and preprocess sample ``idx``.

        Unreadable images fall through to the next index; images whose label
        file contains no usable polygon are replaced by a random sample; a
        missing or malformed label file yields all-zero targets.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)

        image = cv2.imread(img_path)
        if image is None:
            print(f"错误: 无法读取图像 {img_path}")
            return self[(idx + 1) % len(self)]  # skip the broken image
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Random CLAHE on the L channel (applied to roughly half of the samples).
        if np.random.rand() > 0.5:
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l, a, b = cv2.split(lab)
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            l = clahe.apply(l)
            lab = cv2.merge([l, a, b])
            image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        base_name = os.path.splitext(img_name)[0]
        label_path = os.path.join(self.label_dir, base_name + '.json')
        try:
            with open(label_path) as f:
                label_data = json.load(f)
            polygons = []
            orig_h, orig_w = image.shape[:2]
            # Annotation-time image size, if recorded; used to rescale points.
            json_h = label_data.get('imageHeight', orig_h)
            json_w = label_data.get('imageWidth', orig_w)
            scale_x = orig_w / json_w
            scale_y = orig_h / json_h
            for shape in label_data['shapes']:
                if shape['shape_type'] == 'polygon' and len(shape['points']) >= 3:
                    # Keep float coordinates; no integer rounding here.
                    poly = np.array(shape['points'], dtype=np.float32)
                    poly[:, 0] = poly[:, 0] * scale_x
                    poly[:, 1] = poly[:, 1] * scale_y
                    poly[:, 0] = np.clip(poly[:, 0], 0, orig_w - 1)
                    poly[:, 1] = np.clip(poly[:, 1], 0, orig_h - 1)
                    polygons.append(poly)

            if len(polygons) == 0:
                print(f"警告: {img_name} 无有效标注,使用随机样本替代")
                return self[np.random.randint(0, len(self))]

            # Debug visualisation for the first few indices.
            if idx < 5:
                debug_img = image.copy()
                for poly in polygons:
                    int_poly = poly.astype(np.int32).reshape(-1, 1, 2)
                    cv2.polylines(debug_img, [int_poly], True, (0, 255, 0), 3)
                debug_info = f"Size: {orig_w}x{orig_h} | Polys: {len(polygons)}"
                cv2.putText(debug_img, debug_info, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
                debug_path = f"debug_{base_name}.jpg"
                cv2.imwrite(debug_path, cv2.cvtColor(debug_img, cv2.COLOR_RGB2BGR))
                print(f"保存调试图像: {debug_path}")

            if self.augment and self.augmentation:
                keypoints = [(point[0], point[1], 0, 0) for poly in polygons for point in poly]
                poly_lengths = [len(poly) for poly in polygons]
                augmented = self.augmentation(image=image, keypoints=keypoints)
                aug_keypoints = augmented['keypoints']
                if len(aug_keypoints) == len(keypoints):
                    image = augmented['image']
                    # Regroup the flat keypoint list into the original polygons.
                    polygons = []
                    start_idx = 0
                    for poly_len in poly_lengths:
                        poly_points = aug_keypoints[start_idx:start_idx + poly_len]
                        polygons.append(
                            np.array([[p[0], p[1]] for p in poly_points], dtype=np.float32))
                        start_idx += poly_len
                else:
                    # A transform dropped vertices, so the vertex -> polygon
                    # mapping is lost; keep the un-augmented image and polygons.
                    print(f"警告: {img_name} 增强后关键点数量不一致,使用未增强样本")

            # Clip every polygon to the (possibly augmented) image bounds.
            for poly in polygons:
                poly[:, 0] = np.clip(poly[:, 0], 0, image.shape[1] - 1)
                poly[:, 1] = np.clip(poly[:, 1], 0, image.shape[0] - 1)
        except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
            print(f"警告: 无法加载标注文件 {label_path} - {str(e)}")
            polygons = []

        # Image size after augmentation, before the final resize.
        aug_h, aug_w = image.shape[:2]
        processed = self.base_transform(image=image)
        image_tensor = processed['image']

        # Bring polygon coordinates to input_size.
        scale_x = self.input_size[1] / aug_w
        scale_y = self.input_size[0] / aug_h
        scaled_polygons = []
        for poly in polygons:
            scaled_poly = poly.copy()
            scaled_poly[:, 0] = np.clip(scaled_poly[:, 0] * scale_x, 0, self.input_size[1] - 1)
            scaled_poly[:, 1] = np.clip(scaled_poly[:, 1] * scale_y, 0, self.input_size[0] - 1)
            scaled_polygons.append(scaled_poly)

        target_shape = (self.input_size[0], self.input_size[1])
        binary_target = self.generate_binary_target(scaled_polygons, target_shape)
        threshold_target = self.generate_threshold_target(scaled_polygons, target_shape)
        return image_tensor, binary_target, threshold_target

    def generate_threshold_target(self, polygons, img_shape, ratio=0.4):
        """Build the threshold target from per-polygon distance transforms.

        For each polygon the distance transform of its filled mask is divided
        by ``area * (1 - ratio**2) / perimeter`` and clipped to [0, 1]; the
        per-pixel maximum over polygons is then resized to 1/8 resolution.

        Args:
            polygons: list of ``(N, 2)`` float arrays in ``img_shape`` pixels.
            img_shape: ``(height, width)`` of the full-resolution canvas.
            ratio: shrink ratio used in the normalising distance.

        Returns:
            Float tensor of shape ``(1, input_size[0] // 8, input_size[1] // 8)``.
        """
        output_size = (self.input_size[0] // 8, self.input_size[1] // 8)
        full_size_map = np.zeros(img_shape[:2], dtype=np.float32)
        for poly in polygons:
            if len(poly) < 3:
                continue
            # Work on a copy so the caller's polygons are never modified.
            poly = np.array(poly, dtype=np.float32)
            poly[:, 0] = np.clip(poly[:, 0], 0, img_shape[1] - 1)
            poly[:, 1] = np.clip(poly[:, 1], 0, img_shape[0] - 1)

            area = cv2.contourArea(poly)
            perimeter = cv2.arcLength(poly, True)
            # Skip degenerate polygons (also avoids a division by ~0).
            if perimeter < 1e-3 or area < 10:
                continue
            max_dist = area * (1 - ratio ** 2) / max(perimeter, 1e-3)

            mask = np.zeros(img_shape[:2], dtype=np.uint8)
            int_poly = poly.reshape((-1, 1, 2)).astype(np.int32)
            cv2.fillPoly(mask, [int_poly], 255)

            dist = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
            normalized = np.clip(dist / max(max_dist, 1e-6), 0, 1)
            full_size_map = np.maximum(full_size_map, normalized)

        # cv2.resize expects dsize as (width, height); output_size is (h, w).
        dist_map = cv2.resize(full_size_map, (output_size[1], output_size[0]),
                              interpolation=cv2.INTER_LINEAR)
        if np.max(dist_map) < 1e-6:
            return torch.zeros((1, *output_size), dtype=torch.float32)
        return torch.from_numpy(dist_map).unsqueeze(0).float()

    def generate_binary_target(self, polygons, img_shape):
        """Rasterise the polygons into a 0/1 mask at 1/8 of ``input_size``.

        Args:
            polygons: list of ``(N, 2)`` float arrays in ``img_shape`` pixels.
            img_shape: ``(height, width)`` the polygon coordinates refer to.

        Returns:
            Float tensor of shape ``(1, input_size[0] // 8, input_size[1] // 8)``.
        """
        output_size = (self.input_size[0] // 8, self.input_size[1] // 8)
        binary_map = np.zeros(output_size, dtype=np.float32)
        # Scale factors from img_shape to the target map.
        scale_x = output_size[1] / img_shape[1]
        scale_y = output_size[0] / img_shape[0]
        for poly in polygons:
            if len(poly) <= 2:
                continue
            scaled_poly = poly.copy()
            scaled_poly[:, 0] = scaled_poly[:, 0] * scale_x
            scaled_poly[:, 1] = scaled_poly[:, 1] * scale_y
            # fillPoly needs integer vertices; coordinates are truncated here.
            int_poly = scaled_poly.reshape((-1, 1, 2)).astype(np.int32)
            temp_canvas = np.zeros(output_size, dtype=np.uint8)
            cv2.fillPoly(temp_canvas, [int_poly], 1)
            binary_map = np.maximum(binary_map, temp_canvas.astype(np.float32))
        return torch.from_numpy(binary_map).unsqueeze(0).float()
# ----------------------------
# 2. DBNet模型定义 (增强版)
# ----------------------------
class DBNet(nn.Module):
    """DBNet-style text detector built on a ResNet18 backbone.

    Args:
        pretrained: when True the backbone is initialised with
            ``ResNet18_Weights.DEFAULT``; when False it is randomly
            initialised (no weights are loaded), which is what is wanted
            when a trained state dict is loaded afterwards.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Honour the flag: previously the default weights were always loaded.
        backbone_weights = ResNet18_Weights.DEFAULT if pretrained else None
        base_model = resnet18(weights=backbone_weights)

        # Backbone stages taken from ResNet18.
        self.conv1 = base_model.conv1
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4

        # Channel reduction 512 -> 256 -> 128 -> 64 on the last backbone stage.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Detection head producing the binary and threshold maps.
        self.db_head = DBHead(64)

    def forward(self, x):
        """Return ``(binary_map, thresh_map)`` for the image batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        fused = self.fusion_conv(x)
        binary_map, thresh_map = self.db_head(fused)
        return binary_map, thresh_map
class DBHead(nn.Module):
    """DBNet detection head with a residual block and spatial attention.

    ``forward`` returns two sigmoid maps (binary and threshold), each
    upsampled 4x by two stride-2 transposed convolutions.  The ``attention``
    (channel attention) module is constructed but not used in ``forward``.
    """

    @staticmethod
    def _make_branch(c_in, c_first, c_second, c_up):
        """Build one output branch: two 3x3 convs, then two 2x up-convs + sigmoid."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_first, 3, padding=1),
            nn.BatchNorm2d(c_first),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_first, c_second, 3, padding=1),
            nn.BatchNorm2d(c_second),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(c_second, c_up, 4, stride=2, padding=1),
            nn.BatchNorm2d(c_up),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(c_up, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def __init__(self, in_channels):
        super(DBHead, self).__init__()
        half, quarter, eighth = in_channels // 2, in_channels // 4, in_channels // 8

        # Residual block (LeakyReLU between the two convolutions).
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels)
        )

        # Spatial attention: a single-channel sigmoid mask.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Channel attention (squeeze-and-excite style layout).
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, eighth, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(eighth, in_channels, 1),
            nn.Sigmoid()
        )

        # Binarisation branch and threshold branch.
        self.binarize = self._make_branch(in_channels, half, half, quarter)
        self.thresh = self._make_branch(in_channels, half, quarter, eighth)

    def forward(self, x):
        """Return ``(binary_map, thresh_map)`` for the feature map ``x``."""
        shortcut = x
        features = self.res_block(x) + shortcut
        # Re-weight every spatial location with the attention mask.
        attended = features * self.spatial_attn(features)
        return self.binarize(attended), self.thresh(attended)
# ----------------------------
# 3. 损失函数定义 (增强版)
# ----------------------------
class DBLoss(nn.Module):
    """Combined DBNet loss: probability-map, threshold-map and DB terms.

    ``total = prob_loss + alpha * thresh_loss + beta * bin_loss`` where

    * ``prob_loss``   - per-pixel Dice-style loss on the predicted probability
      map with hard-negative mining (OHEM),
    * ``thresh_loss`` - mean L1 loss on the threshold map,
    * ``bin_loss``    - Dice loss on the approximate binary map
      ``sigmoid(k * (P - T))``.

    Args:
        alpha: weight of ``thresh_loss``.
        beta: weight of ``bin_loss``.
        k: steepness of the differentiable binarisation.
        ohem_ratio: maximum number of mined negatives per positive pixel.
    """

    def __init__(self, alpha=1.0, beta=10.0, k=50, ohem_ratio=3.0):
        super().__init__()
        self.alpha = alpha  # weight applied to the threshold-map L1 loss
        self.beta = beta  # weight applied to the differentiable-binarisation loss
        self.k = k  # amplification factor inside the sigmoid
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, targets):
        """Compute the loss.

        Args:
            preds: ``(binary_pred, thresh_pred)`` tensors.
            targets: ``(binary_target, thresh_target)`` tensors.

        Returns:
            ``(total_loss, prob_loss, thresh_loss, bin_loss)``.
        """
        binary_pred, thresh_pred = preds
        binary_target, thresh_target = targets

        prob_loss = self.dice_loss_with_ohem(binary_pred, binary_target)
        thresh_loss = F.l1_loss(thresh_pred, thresh_target, reduction='mean')

        # Approximate binary map B = 1 / (1 + exp(-k (P - T))).  This must stay
        # inside the autograd graph: computing it under no_grad would turn
        # bin_loss into a constant that contributes no gradient.
        binary_map = torch.sigmoid(self.k * (binary_pred - thresh_pred))
        bin_loss = self.dice_loss(binary_map, binary_target)

        total_loss = prob_loss + self.alpha * thresh_loss + self.beta * bin_loss
        return total_loss, prob_loss, thresh_loss, bin_loss

    def dice_loss(self, pred, target):
        """Standard Dice loss over the whole tensor (smoothing constant 1)."""
        smooth = 1.0
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        return 1 - (2. * intersection + smooth) / (union + smooth)

    def dice_loss_with_ohem(self, pred, target):
        """Per-pixel Dice-style loss averaged over positives and mined negatives.

        All positive pixels (``target > 0.5``) are used together with the
        ``n_neg`` negative pixels that have the largest loss, where ``n_neg``
        is at most ``ohem_ratio`` times the number of positives.  Falls back
        to :meth:`dice_loss` when no negative can be selected.
        """
        loss_map = 1 - (2 * pred * target + 1) / (pred + target + 1)

        pos_mask = (target > 0.5).float()
        neg_mask = 1 - pos_mask

        # Both counts must be Python ints: torch.topk rejects a float k.
        n_pos = int(pos_mask.sum().item())
        n_neg = min(int(n_pos * self.ohem_ratio), int(neg_mask.sum().item()))
        if n_neg == 0:
            return self.dice_loss(pred, target)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        neg_loss = (loss_map * neg_mask).reshape(-1)
        topk_neg_loss, _ = torch.topk(neg_loss, n_neg)

        pos_loss = (loss_map * pos_mask).sum()
        return (pos_loss + topk_neg_loss.sum()) / (n_pos + n_neg + 1e-6)
# ----------------------------
# 辅助函数 (保持不变)
# ----------------------------
def calculate_metrics(pred, target, threshold=0.5):
    """Compute pixel-wise precision, recall and F1 between two maps.

    Args:
        pred: tensor of predicted scores; values above ``threshold`` count as
            positive.
        target: tensor of ground-truth values; values above 0.5 count as
            positive.
        threshold: binarisation threshold for ``pred``.

    Returns:
        ``(precision, recall, f1)``.  ``(0.0, 0.0, 0.0)`` is returned when the
        target contains no positive pixel.
    """
    # reshape instead of view: view(-1) raises for non-contiguous tensors
    # (e.g. transposed or sliced inputs).
    pred_flat = (pred > threshold).float().reshape(-1).cpu().numpy()
    target_flat = (target > 0.5).float().reshape(-1).cpu().numpy()

    # Nothing to detect: report zeros instead of undefined scores.
    if np.sum(target_flat) == 0:
        return 0.0, 0.0, 0.0

    precision = precision_score(target_flat, pred_flat, zero_division=0)
    recall = recall_score(target_flat, pred_flat, zero_division=0)
    f1 = f1_score(target_flat, pred_flat, zero_division=0)
    return precision, recall, f1
# ... (保持不变) ...
def validate_model(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return averaged metrics.

    The model is switched to eval mode for the duration of the call and put
    back into training mode afterwards if it was training before, so calling
    this in the middle of a training loop does not silently freeze
    BatchNorm/Dropout behaviour for the following epochs.

    Args:
        model: module returning ``(binary_pred, thresh_pred)``.
        dataloader: iterable of ``(images, binary_targets, thresh_targets)``.
        device: device the batches are moved to.

    Returns:
        ``(avg_precision, avg_recall, avg_f1)`` averaged over batches, or
        ``(0.0, 0.0, 0.0)`` when the dataloader yields no batch.
    """
    was_training = model.training
    model.eval()
    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for images, binary_targets, _ in dataloader:
                images = images.to(device)
                binary_targets = binary_targets.to(device)
                binary_preds, _ = model(images)
                precision, recall, f1 = calculate_metrics(binary_preds, binary_targets)
                total_precision += precision
                total_recall += recall
                total_f1 += f1
                num_batches += 1
    finally:
        if was_training:
            model.train()

    # Empty loader: avoid a ZeroDivisionError.
    if num_batches == 0:
        return 0.0, 0.0, 0.0
    return total_precision / num_batches, total_recall / num_batches, total_f1 / num_batches
# 2. 动态损失权重校准 - 修改DBLoss类
class AdaptiveDBLoss(DBLoss):
    """DBLoss variant that periodically re-calibrates ``beta``.

    Every ``adapt_step`` steps (once more than 10 DB-loss values have been
    recorded) ``beta`` is set to 0.8 times the median of the last 10 DB-loss
    values, clamped to the range [1.0, 10.0].

    Args:
        alpha: forwarded to ``DBLoss`` as ``alpha``.
        beta: initial ``beta`` forwarded to ``DBLoss``.
        gamma: forwarded positionally as the third ``DBLoss`` argument.
        adapt_step: interval (in steps) between re-calibrations.
    """

    def __init__(self, alpha=1.0, beta=5.0, gamma=2.0, adapt_step=100):
        # NOTE(review): the third positional parameter of DBLoss is ``k``, so
        # ``gamma`` ends up as the binarisation steepness (2.0 by default
        # instead of DBLoss's own default of 50) - confirm this is intended.
        super().__init__(alpha, beta, gamma)
        self.adapt_step = adapt_step
        self.beta_history = []
        self._auto_step = 0  # internal counter used when no step is passed

    def forward(self, preds, targets, step=None):
        """Compute the loss, adapting ``beta`` on schedule.

        Args:
            preds: ``(binary_pred, thresh_pred)`` tensors.
            targets: ``(binary_target, thresh_target)`` tensors.
            step: global step index.  Optional so the loss can be called like
                a plain ``DBLoss`` (``criterion(preds, targets)``); an
                internal call counter is used when it is omitted.

        Returns:
            ``(total_loss, bin_loss, thresh_loss, db_loss)`` in the order
            produced by ``DBLoss.forward``.
        """
        if step is None:
            step = getattr(self, '_auto_step', 0)
            self._auto_step = step + 1

        # Re-calibrate beta from the recent DB-loss history.
        if step % self.adapt_step == 0 and len(self.beta_history) > 10:
            db_median = np.median(self.beta_history[-10:])
            self.beta = max(1.0, min(db_median * 0.8, 10.0))

        total_loss, bin_loss, thresh_loss, db_loss = super().forward(preds, targets)
        # Record the DB loss observed with the current beta.
        self.beta_history.append(db_loss.item())
        return total_loss, bin_loss, thresh_loss, db_loss
# 3. 模型架构增强 - 替换原始DBHead
class EnhancedDBHead(DBHead):
    """DBHead with a wider GroupNorm residual block and gated depthwise features.

    ``in_channels`` must be divisible by 8 because of the 8-group GroupNorm
    layers.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels)
        # Replaces the parent's residual block with a wider GroupNorm/GELU one.
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, 3, padding=1),
            nn.GroupNorm(8, in_channels * 2),
            nn.GELU(),
            nn.Conv2d(in_channels * 2, in_channels, 3, padding=1),
            nn.GroupNorm(8, in_channels)
        )
        # Depthwise-separable expansion.  The final 1x1 convolution projects
        # the 4x-expanded features back to ``in_channels``; without it the
        # output could neither be fed to ``gate_attn`` nor added to ``x`` in
        # ``forward`` (channel mismatch).
        self.depthwise = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, in_channels * 4, 1),
            nn.GELU(),
            nn.Conv2d(in_channels * 4, in_channels, 1)
        )
        # Self-gating channel attention.
        self.gate_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 4, 1),
            nn.GELU(),
            nn.Conv2d(in_channels // 4, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``(binary_map, thresh_map)`` for the feature map ``x``."""
        residual = x
        x = self.res_block(x) + residual
        depth_feat = self.depthwise(x)
        # Per-channel gate computed from the depthwise features.
        gate = self.gate_attn(depth_feat)
        x = x * gate + depth_feat
        # The parent forward applies ``self.res_block`` (the block defined
        # above) once more before the attention and output branches.
        return super().forward(x)
# ----------------------------
# 4. 训练函数 (增强版,添加进度条)
# ----------------------------
def _build_checkpoint(last_completed_epoch, model, optimizer, scheduler,
                      best_loss, best_f1, logger):
    """Assemble the resumable training-state dictionary written to disk."""
    return {
        'format_version': 2,
        'epoch': last_completed_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'best_f1': best_f1,
        'logger': logger,
        'scheduler_state': scheduler.state_dict(),
    }


def enhanced_train_model(model, train_loader, val_loader, criterion, optimizer, device,
                         epochs=200, checkpoint_path='dbnet_checkpoint.pth', lr_init=5e-5):
    """Train ``model`` with mixed precision, live logging and checkpointing.

    Args:
        model: network returning ``(binary_pred, thresh_pred)``.
        train_loader: yields ``(images, binary_targets, thresh_targets)``.
        val_loader: loader passed to ``validate_model`` after every epoch.
        criterion: callable ``criterion(preds, targets)`` returning
            ``(total_loss, bin_loss, thresh_loss, db_loss)``.
        optimizer: optimizer for ``model``'s parameters.
        device: device the batches are moved to.
        epochs: total number of epochs.
        checkpoint_path: resumable checkpoint file (read if it exists).
        lr_init: learning rate written into the first param group when
            training starts (or when an old-format checkpoint is resumed).

    Returns:
        The trained model.  ``dbnet_final.pth`` (state dict) is always written
        at the end, ``dbnet_best.pth`` whenever validation F1 improves.
    """
    start_epoch = 0
    best_loss = float('inf')
    best_f1 = 0.0
    logger = None
    keep_restored_lr = False

    # ReduceLROnPlateau is stepped once per epoch with the average training
    # loss.  (It must not be stepped per batch with an epoch fraction: that
    # ever-growing "metric" makes it halve the learning rate every few steps.)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # Mixed precision only makes sense on CUDA.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    if os.path.exists(checkpoint_path):
        print(f"发现检查点文件 {checkpoint_path}, 尝试恢复训练...")
        # SECURITY: the checkpoint embeds a pickled logger object, so it cannot
        # be read in weights-only mode.  Only resume from files you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['best_loss']
        best_f1 = checkpoint.get('best_f1', best_f1)
        logger = checkpoint.get('logger')
        # Scheduler state of older checkpoints was produced by per-batch
        # stepping and is not meaningful; only restore the current format.
        if checkpoint.get('format_version', 1) >= 2 and 'scheduler_state' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            keep_restored_lr = True
        print(f"成功恢复训练状态: 从第 {start_epoch} 轮开始, 最佳损失: {best_loss:.6f}")

    # Create a logger only when none (or an empty one) was restored, so a
    # fresh run does not open two figures.
    if logger is None or not logger.total_losses:
        logger = EnhancedTrainingLogger()

    if not keep_restored_lr:
        optimizer.param_groups[0]['lr'] = lr_init

    try:
        for epoch in range(start_epoch, epochs):
            # validate_model() may leave the model in eval mode.
            model.train()
            epoch_total_loss = 0.0
            epoch_bin_loss = 0.0
            epoch_thresh_loss = 0.0
            epoch_db_loss = 0.0
            epoch_start = time.time()

            pbar = tqdm(enumerate(train_loader), total=len(train_loader),
                        desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")
            for batch_idx, (images, binary_targets, thresh_targets) in pbar:
                images = images.to(device)
                binary_targets = binary_targets.to(device)
                thresh_targets = thresh_targets.to(device)

                with torch.cuda.amp.autocast(enabled=use_amp):
                    binary_preds, thresh_preds = model(images)
                    total_loss, bin_loss, thresh_loss, db_loss = criterion(
                        (binary_preds, thresh_preds),
                        (binary_targets, thresh_targets)
                    )

                # Read each scalar once (every .item() synchronises with the device).
                total_val = total_loss.item()
                bin_val = bin_loss.item()
                thresh_val = thresh_loss.item()
                db_val = db_loss.item()
                epoch_total_loss += total_val
                epoch_bin_loss += bin_val
                epoch_thresh_loss += thresh_val
                epoch_db_loss += db_val

                current_lr = optimizer.param_groups[0]['lr']
                logger.on_batch_end(batch_idx, total_val, bin_val, thresh_val, db_val, current_lr)
                pbar.set_postfix({
                    'Loss': f"{total_val:.4f}",
                    'Bin': f"{bin_val:.4f}",
                    'Thresh': f"{thresh_val:.4f}",
                    'DB': f"{db_val:.4f}",
                    'LR': f"{current_lr:.2e}"
                })

                optimizer.zero_grad()
                scaler.scale(total_loss).backward()
                # Gradients must be unscaled before clipping, otherwise the
                # norm threshold is applied to loss-scaled values.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                # Emergency checkpoint every 100 batches.  It records the last
                # *completed* epoch so a resume repeats the interrupted one.
                if batch_idx % 100 == 0:
                    torch.save(
                        _build_checkpoint(epoch - 1, model, optimizer, scheduler,
                                          best_loss, best_f1, logger),
                        checkpoint_path)

            num_batches = max(1, len(train_loader))
            avg_total_loss = epoch_total_loss / num_batches
            avg_bin_loss = epoch_bin_loss / num_batches
            avg_thresh_loss = epoch_thresh_loss / num_batches
            avg_db_loss = epoch_db_loss / num_batches

            precision, recall, f1 = validate_model(model, val_loader, device)
            # Plain floats keep the logger and the checkpoints free of numpy scalars.
            precision, recall, f1 = float(precision), float(recall), float(f1)
            logger.val_metrics['precision'].append(precision)
            logger.val_metrics['recall'].append(recall)
            logger.val_metrics['f1'].append(f1)

            scheduler.step(avg_total_loss)

            epoch_time = time.time() - epoch_start
            print(f"Epoch [{epoch + 1}/{epochs}] completed in {epoch_time:.2f}s")
            print(
                f" - Avg Loss: {avg_total_loss:.6f} (Bin:{avg_bin_loss:.6f}, Thresh:{avg_thresh_loss:.6f}, DB:{avg_db_loss:.6f})")
            print(f" - Val Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")
            logger.on_epoch_end(epoch, optimizer)

            # Best model: higher F1, ties broken by lower training loss.
            if f1 > best_f1 or (f1 == best_f1 and avg_total_loss < best_loss):
                best_f1 = f1
                best_loss = avg_total_loss
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_total_loss,
                    'f1': best_f1
                }, 'dbnet_best.pth')
                print(f"🔥 发现新的最佳模型! F1: {best_f1:.4f}, 损失: {best_loss:.6f}")

            # Regular end-of-epoch checkpoint: this epoch is now complete.
            torch.save(
                _build_checkpoint(epoch, model, optimizer, scheduler,
                                  best_loss, best_f1, logger),
                checkpoint_path)
    except KeyboardInterrupt:
        print("\n训练被用户中断!")
    except Exception as e:
        # Best effort: report the failure but still save the final weights below.
        print(f"\n❌ 训练中断! 原因: {str(e)}")
        traceback.print_exc()
    finally:
        print("训练完成! 保存最终模型...")
        torch.save(model.state_dict(), 'dbnet_final.pth')
        logger.on_train_end()
    return model
# ----------------------------
# 5. 推理与区域裁剪 (增强版)
# ----------------------------
def enhanced_detect_text_regions(image, model, device, threshold=0.3):
    """Run the detector on a BGR image and return text-region contours.

    The image is resized to 640x640, normalised with the ImageNet mean/std
    and passed through ``model``.  The predicted probability map is brought
    to 640x640 before thresholding, so the returned contour coordinates live
    in the 640x640 network-input frame.

    Args:
        image: BGR image array (as returned by ``cv2.imread``).
        model: network returning ``(binary_map, thresh_map)``.
        device: device used for inference.
        threshold: probability above which a pixel counts as text.

    Returns:
        ``(contours, orig_h, orig_w)`` - external contours in 640x640
        coordinates and the original image height and width.
    """
    net_size = 640
    orig_h, orig_w = image.shape[:2]

    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (net_size, net_size)).astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized = (resized - mean) / std
    input_tensor = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0).to(device)
    input_tensor = input_tensor.to(torch.float32)

    with torch.no_grad():
        binary_map, _ = model(input_tensor)

    prob_map = binary_map.squeeze().float().cpu().numpy()
    # The network output can be smaller than its input; upsample it so the
    # contours are expressed in 640x640 coordinates, which is what the
    # rescaling to the original image size assumes.
    if prob_map.shape != (net_size, net_size):
        prob_map = cv2.resize(prob_map, (net_size, net_size), interpolation=cv2.INTER_LINEAR)
    binary_output = (prob_map > threshold).astype(np.uint8) * 255

    # Morphological closing to bridge small gaps.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binary_output = cv2.morphologyEx(binary_output, cv2.MORPH_CLOSE, kernel)

    contours, _ = cv2.findContours(binary_output, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return contours, orig_h, orig_w
# ... (保持不变) ...
def perspective_transform(image, contour):
    """Rectify the region described by ``contour`` with a perspective warp.

    The contour is approximated by a polygon; if that polygon is not a
    quadrilateral, the minimum-area bounding rectangle is used instead.

    Args:
        image: source image.
        contour: contour in ``image`` coordinates.

    Returns:
        The warped (rectified) image patch.

    Raises:
        ValueError: if the quadrilateral is degenerate (zero width or height).
    """
    epsilon = 0.02 * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)

    if len(approx) != 4:
        rect = cv2.minAreaRect(contour)
        box = cv2.boxPoints(rect)
        # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased.
        approx = box.astype(np.intp)

    # Order the corners: top-left, top-right, bottom-right, bottom-left.
    pts = approx.reshape(4, 2)
    rect_pts = np.zeros((4, 2), dtype="float32")
    s = pts.sum(axis=1)
    rect_pts[0] = pts[np.argmin(s)]  # top-left: smallest x + y
    rect_pts[2] = pts[np.argmax(s)]  # bottom-right: largest x + y
    diff = np.diff(pts, axis=1)
    rect_pts[1] = pts[np.argmin(diff)]  # top-right: smallest y - x
    rect_pts[3] = pts[np.argmax(diff)]  # bottom-left: largest y - x

    (tl, tr, br, bl) = rect_pts
    max_width = max(int(np.hypot(br[0] - bl[0], br[1] - bl[1])),
                    int(np.hypot(tr[0] - tl[0], tr[1] - tl[1])))
    max_height = max(int(np.hypot(tr[0] - br[0], tr[1] - br[1])),
                     int(np.hypot(tl[0] - bl[0], tl[1] - bl[1])))
    if max_width < 1 or max_height < 1:
        raise ValueError(
            f"degenerate quadrilateral: target size {max_width}x{max_height}")

    dst = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect_pts, dst)
    return cv2.warpPerspective(image, M, (max_width, max_height))
def crop_text_regions(image, contours, orig_h, orig_w):
    """Crop every detected region from ``image`` and rectify it.

    Contours are taken to be in 640x640 coordinates and are rescaled to the
    original image size.  Each region's bounding box is enlarged by 10% per
    side, cropped, and passed to ``perspective_transform``; if that fails the
    plain crop is kept instead.

    Returns:
        List of cropped image arrays.
    """
    results = []
    # Scale factors from the 640x640 frame to the original image.
    factors = np.array([orig_w / 640.0, orig_h / 640.0])

    for cnt in contours:
        # Ignore tiny regions.
        if cv2.contourArea(cnt) < 100:
            continue

        # Rescale, casting back to the contour's own dtype.
        scaled_contour = (cnt * factors).astype(cnt.dtype)

        bx, by, bw, bh = cv2.boundingRect(scaled_contour)
        pad_x, pad_y = int(bw * 0.1), int(bh * 0.1)
        left = max(0, bx - pad_x)
        top = max(0, by - pad_y)
        width = min(orig_w - left, bw + 2 * pad_x)
        height = min(orig_h - top, bh + 2 * pad_y)

        roi = image[top:top + height, left:left + width]

        try:
            # Express the contour relative to the crop's origin.
            local_contour = scaled_contour - np.array([left, top], dtype=scaled_contour.dtype)
            warped_roi = perspective_transform(roi, local_contour)
            # Keep only results larger than 10 pixels in both dimensions.
            if warped_roi.shape[0] > 10 and warped_roi.shape[1] > 10:
                results.append(warped_roi)
        except Exception as e:
            # Rectification failed: fall back to the plain crop.
            print(f"透视变换失败: {str(e)},使用原始ROI")
            results.append(roi)
    return results
# ----------------------------
# 7. 模型加载与推理接口 (新增功能)
# ----------------------------
def load_trained_model(model_path, device='cuda'):
    """Build a ``DBNet`` and load trained weights into it for inference.

    Args:
        model_path: Path of the file passed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The ``DBNet`` instance, with weights loaded and switched to eval mode.

    The file may be either a checkpoint dict holding the weights under
    ``'model_state_dict'`` or a bare state dict; both layouts are accepted.
    """
    model = DBNet(pretrained=False).to(device)

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both a full training checkpoint and a weights-only file instead
    # of failing with KeyError on the latter.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.eval()
    return model
# 5. 水表图像增强改进
def water_meter_specific_aug(image, **kwargs):
    """Water-meter specific enhancement chain: blur, CLAHE, chroma stretch.

    Args:
        image: Image array that is treated as RGB (it is converted with
            ``COLOR_RGB2LAB``).
        **kwargs: Accepted and ignored.

    Returns:
        The enhanced image, converted back with ``COLOR_LAB2RGB``.
    """
    # Gaussian blur to damp high-frequency glare; the kernel is about 1% of
    # the shorter side and forced odd, as GaussianBlur requires.
    shortest_side = min(image.shape[:2])
    ksize = int(shortest_side * 0.01)
    ksize |= 1
    smoothed = cv2.GaussianBlur(image, (ksize, ksize), 0)

    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(smoothed, cv2.COLOR_RGB2LAB))

    # Adaptive histogram equalisation on the lightness channel.
    equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    lightness_eq = equalizer.apply(lightness)

    # Colour-cast correction: min-max stretch each chroma channel to 0..255.
    chroma = [
        cv2.normalize(channel, None, 0, 255, cv2.NORM_MINMAX)
        for channel in (chan_a, chan_b)
    ]

    merged = cv2.merge([lightness_eq, chroma[0], chroma[1]])
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
def suppress_glare(image):
    """Reduce the impact of glare by equalising lightness with CLAHE.

    Args:
        image: Image array that is treated as RGB (it is converted with
            ``COLOR_RGB2LAB``).

    Returns:
        The image with its LAB lightness channel CLAHE-equalised and the two
        chroma channels untouched, converted back with ``COLOR_LAB2RGB``.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2LAB))

    # Equalise only the lightness channel.
    equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    lightness_eq = equalizer.apply(lightness)

    # Recombine the channels and leave LAB space.
    recombined = cv2.merge((lightness_eq, chan_a, chan_b))
    return cv2.cvtColor(recombined, cv2.COLOR_LAB2RGB)
def detect_and_crop(image_path, model, device='cuda', output_dir='cropped_regions'):
    """Detect water-meter digit regions in one image, crop them and save them.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        model: Trained model handed to ``enhanced_detect_text_regions``.
        device: Device handed to ``enhanced_detect_text_regions``.
        output_dir: Directory (created if missing) that receives the crops,
            written as ``<image stem>_region_<i>.jpg``.

    Returns:
        The list of cropped regions, or an empty list when the image cannot
        be read.
    """
    os.makedirs(output_dir, exist_ok=True)

    image = cv2.imread(image_path)
    if image is None:
        print(f"错误: 无法读取图像 {image_path}")
        return []

    # cv2.imread yields BGR, while suppress_glare converts with RGB2LAB.
    # Feeding BGR straight in makes CLAHE act on a lightness computed with the
    # red and blue weights swapped, so convert to RGB for the call and back to
    # BGR afterwards; everything downstream (incl. cv2.imwrite) still sees BGR.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.cvtColor(suppress_glare(rgb_image), cv2.COLOR_RGB2BGR)

    # Locate the text regions, then cut them out of the image.
    contours, orig_h, orig_w = enhanced_detect_text_regions(image, model, device)
    cropped_regions = crop_text_regions(image, contours, orig_h, orig_w)

    # Persist every crop next to each other, named after the source image.
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    for i, region in enumerate(cropped_regions):
        output_path = os.path.join(output_dir, f'{base_name}_region_{i}.jpg')
        cv2.imwrite(output_path, region)

    print(f"成功裁剪 {len(cropped_regions)} 个文本区域到 {output_dir}")
    return cropped_regions
# ----------------------------
# 8. 主程序 (优化版)
# ----------------------------
if __name__ == "__main__":
    # Reduced input resolution, chosen to suit water-meter images.
    INPUT_SIZE = (512, 512)

    # Run configuration.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    DATA_DIR, LABEL_DIR = 'images_train', 'labels_train'
    VAL_DATA_DIR, VAL_LABEL_DIR = 'images_val', 'labels_val'
    BATCH_SIZE, EPOCHS, LR = 16, 100, 1e-4
    CHECKPOINT_PATH = 'dbnet_checkpoint.pth'
    TRAINED_MODEL_PATH = 'dbnet_best.pth'

    # Mode switch: 'train' or 'inference'.
    MODE = 'train'

    if MODE == 'train':
        # 1. Datasets and loaders (augmentation only on the training split).
        print("准备训练数据集...")
        train_dataset = WaterMeterDataset(
            image_dir=DATA_DIR, label_dir=LABEL_DIR, input_size=INPUT_SIZE, augment=True
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )

        print("准备验证数据集...")
        val_dataset = WaterMeterDataset(
            image_dir=VAL_DATA_DIR, label_dir=VAL_LABEL_DIR, input_size=INPUT_SIZE, augment=False
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=2,
        )

        # 2. Model.
        print("初始化模型...")
        model = DBNet(pretrained=True).to(DEVICE)

        # 3. Loss and optimizer (AdamW with a base learning rate of 3e-4).
        criterion = DBLoss(alpha=1.0, beta=8.0)
        optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

        # 4. Training.
        print("开始训练...")
        model = enhanced_train_model(
            model, train_loader, val_loader, criterion, optimizer, DEVICE,
            epochs=EPOCHS, checkpoint_path=CHECKPOINT_PATH, lr_init=LR,
        )
        print(f"✅ 训练完成! 最佳模型已保存到 {TRAINED_MODEL_PATH}")

    elif MODE == 'inference':
        # Load the trained weights.
        print(f"加载训练好的模型: {TRAINED_MODEL_PATH}")
        model = load_trained_model(TRAINED_MODEL_PATH, DEVICE)

        # Single-image run (crops go to detect_and_crop's default directory).
        test_image_path = 'test_images/test_1.jpg'
        print(f"处理测试图像: {test_image_path}")
        detect_and_crop(test_image_path, model, DEVICE)

        # Batch run over every .jpg/.png/.jpeg file of a directory.
        input_dir = 'test_images'
        output_dir = 'cropped_results'
        print(f"批量处理目录: {input_dir}")
        for img_file in os.listdir(input_dir):
            if not img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                continue
            img_path = os.path.join(input_dir, img_file)
            print(f"处理图像: {img_file}")
            detect_and_crop(img_path, model, DEVICE, output_dir)
# 最新发布  (web-page scrape residue, "Latest posts" -- not part of the program)