array_merge(): Argument #2 is not an array

本文详细解析了微擎中跳转函数二次包装后出现的array_merge()错误,并提供了有效的解决方案,通过强制类型转换避免了非数组参数引起的错误。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

背景:微擎跳转函数二次包装之后出现问题。

 /**
  * Echo a JSON "success" payload and end the script (exit).
  *
  * @param mixed $msg   Message text, returned as 'msg'.
  * @param mixed $url   Forwarded to webUrl() as its first argument ($do).
  * @param mixed $query Forwarded to webUrl() unchanged. webUrl() passes it to
  *                     array_merge() (its lines 267-268), so any non-array value
  *                     triggers "array_merge(): Argument #2 is not an array".
  * @param mixed $type  Returned as 'type'.
  */
 public function success($msg, $url='',$query=array(),$type=1)
    {
        echo json_encode(array('msg'=>$msg,'url'=>webUrl($url,$query),'type'=>$type,'state'=>'success'));
        exit;
    }

注释:weburl函数代码

/**
 * Build the URL that success() returns as its 'url' field (body elided in this article).
 *
 * @param string $do
 * @param array  $query Must be an array: lines 267-268 of the body call
 *                      array_merge(array('do' => 'web'), $query) and
 *                      array_merge(array('m' => MODEL_NAME), $query).
 * @param bool   $full
 */
function webUrl($do = '', $query = array(), $full = true)
    {

....

}

开始:

看着是没问题,但是前端报错,说这里的 array 有问题。

这是c(控制器的代码)

$this->success('修改成功','',[], 2);

是吧,看着没问题,但是却说

报错位置在 webUrl 函数里面的第 267、268 行。

array_merge(): Argument #2 is not an array。

然后我上网查了查,还是不明所以,但后来自己猜到了问题出在哪里。

267 $query = array_merge(array('do' => 'web'), $query);
268 $query = array_merge(array('m' => MODEL_NAME), $query);

(老实说,我当时并没有真正弄清楚问题的根源。)

然后去找了https://blog.youkuaiyun.com/wyodyia/article/details/5792864

然后在那个success函数将

/**
 * Echo a JSON "success" payload and end the script (exit).
 *
 * webUrl() passes $query straight into array_merge(), which rejects anything
 * that is not an array ("array_merge(): Argument #2 is not an array").
 * An (array) cast alone fixes null, but turns a string such as '' or 'id=1'
 * into array(0 => ...), i.e. a stray "0=" parameter in the generated URL,
 * so query strings are parsed into an array first.
 *
 * @param mixed        $msg   Message text, returned as 'msg'.
 * @param mixed        $url   Forwarded to webUrl() as its first argument.
 * @param array|string $query Query parameters (array or "a=1&b=2" string).
 * @param mixed        $type  Returned as 'type'.
 */
public function success($msg, $url = '', $query = array(), $type = 1)
    {
        if (is_string($query)) {
            // parse_str('') yields an empty array, so '' no longer produces "0=".
            parse_str($query, $parsed);
            $query = $parsed;
        }
        echo json_encode(array('msg' => $msg, 'url' => webUrl($url, (array)$query), 'type' => $type, 'state' => 'success'));
        exit;
    }

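把 $query 强制转换为数组((array)$query)以后就正常了。原因在于 PHP 是弱类型语言,函数参数不会自动保证是数组,而 array_merge() 要求每个参数都是数组。需要注意:(array)null 会得到空数组,但 (array)'' 或 (array)'id=1' 会得到只含一个元素、键为 0 的数组,生成的 URL 里会多出一个 0=... 参数;如果调用方可能传入查询字符串,应先用 parse_str() 解析成数组。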

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 0=INFO, 1=WARNING, 2=ERROR, 3=FATAL
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # disable oneDNN start-up logging

import sys
import glob
import time
import json
import torch
import pickle
import shutil
import argparse
import datetime
import torchvision
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from packaging import version
from functools import partial
import pytorch_lightning as pl
from omegaconf import OmegaConf, DictConfig
import torch.distributed as dist
from typing import List, Dict, Any, Optional, Union, Tuple
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from torch.cuda.amp import autocast, GradScaler

# Search paths for the vendored repositories. They are registered before the
# ``ldm`` imports below so those imports can also resolve from these
# directories; ``append`` keeps modules in the working directory first.
current_dir = os.path.dirname(os.path.abspath(__file__))
for path in ["download", "download/CLIP", "download/k-diffusion", "download/stable_diffusion", "download/taming-transformers"]:
    sys.path.append(os.path.join(current_dir, path))

from ldm.util import instantiate_from_config
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.models.diffusion.ddpm import LatentDiffusion

# State-dict key of the CLIP text encoder's position-id buffer. Some checkpoints
# lack it; whether the model registers it depends on the installed library
# version, so it is only filled in when the model actually has it.
_POSITION_IDS_KEY = "cond_stage_model.transformer.text_model.embeddings.position_ids"


class ConfigManager:
    """Load OmegaConf YAML files and merge them with command-line dot-list overrides."""

    def __init__(self, config_files: Union[str, List[str]], cli_args: Optional[List[str]] = None):
        """
        Args:
            config_files: One path or a list of paths; later files override earlier ones.
            cli_args: ``key.sub=value`` overrides applied last.

        Raises:
            FileNotFoundError: if any config file does not exist.
        """
        if isinstance(config_files, str):
            config_files = [config_files]

        self.configs = []
        for cfg in config_files:
            if not os.path.exists(cfg):
                raise FileNotFoundError(f"配置文件不存在: {cfg}")
            self.configs.append(OmegaConf.load(cfg))

        self.cli = OmegaConf.from_dotlist(cli_args) if cli_args else OmegaConf.create()
        self.config = OmegaConf.merge(*self.configs, self.cli)

    def get_model_config(self) -> DictConfig:
        """Return the ``model`` section; raise KeyError if it is absent."""
        if "model" not in self.config:
            raise KeyError("配置文件中缺少'model'部分")
        return self.config.model

    def get_data_config(self) -> DictConfig:
        """Return the ``data`` section; raise KeyError if it is absent."""
        if "data" not in self.config:
            raise KeyError("配置文件中缺少'data'部分")
        return self.config.data

    def get_training_config(self) -> DictConfig:
        """Return the ``training`` section with defaults filled in for missing keys."""
        training_config = self.config.get("training", OmegaConf.create())
        defaults = {
            "max_epochs": 200,
            "gpus": torch.cuda.device_count(),
            "accumulate_grad_batches": 1,
            "learning_rate": 1e-4,
            "precision": 32,
        }
        for key, value in defaults.items():
            if key not in training_config:
                training_config[key] = value
        return training_config

    def get_logging_config(self) -> DictConfig:
        """Return the ``logging`` section (default: ``logdir: logs``)."""
        return self.config.get("logging", OmegaConf.create({"logdir": "logs"}))

    def get_callbacks_config(self) -> DictConfig:
        """Return the ``callbacks`` section (empty if absent)."""
        return self.config.get("callbacks", OmegaConf.create())

    def save_config(self, save_path: str) -> None:
        """Write the merged config to ``save_path``, creating its directory if needed."""
        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs("") raises for a bare file name
            os.makedirs(save_dir, exist_ok=True)
        OmegaConf.save(self.config, save_path)
        print(f"配置已保存到: {save_path}")


class DataModuleFromConfig(pl.LightningDataModule):
    """LightningDataModule whose datasets are built from ``instantiate_from_config`` configs."""

    def __init__(self, batch_size, num_workers, train=None, validation=None, test=None):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs["train"] = train
        if validation is not None:
            self.dataset_configs["validation"] = validation
        if test is not None:
            self.dataset_configs["test"] = test

    def setup(self, stage=None):
        self.datasets = {k: instantiate_from_config(cfg) for k, cfg in self.dataset_configs.items()}

    def _get_dataloader(self, dataset_name, shuffle=False):
        """Build a DataLoader for ``dataset_name``; raise ValueError if it was not configured."""
        dataset = self.datasets.get(dataset_name)
        if dataset is None:
            raise ValueError(f"数据集 {dataset_name} 未配置")
        is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # DataLoader raises ValueError for shuffle=True on an IterableDataset.
            shuffle=shuffle and not is_iterable,
            pin_memory=True,
            # Shards iterable datasets across workers (otherwise every worker yields everything).
            worker_init_fn=worker_init_fn if is_iterable else None,
        )

    def train_dataloader(self):
        return self._get_dataloader("train", shuffle=True)

    def val_dataloader(self):
        return self._get_dataloader("validation")

    def test_dataloader(self):
        return self._get_dataloader("test")


def worker_init_fn(worker_id: int) -> None:
    """DataLoader worker init: shard Txt2ImgIterableBaseDataset and seed numpy/torch per worker."""
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]

    # np.random.seed only accepts [0, 2**32 - 1]: take the modulo after adding worker_id.
    seed = (torch.initial_seed() + worker_id) % 2 ** 32
    np.random.seed(seed)
    torch.manual_seed(seed)


class EnhancedImageLogger(Callback):
    """Periodically call ``pl_module.log_images`` and forward the images to every logger."""

    def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True, rescale: bool = True,
                 loggers: Optional[List] = None, log_first_step: bool = False,
                 log_images_kwargs: Optional[Dict] = None):
        super().__init__()
        self.batch_frequency = max(1, batch_frequency)
        self.max_images = max_images
        self.clamp = clamp
        self.rescale = rescale
        self.loggers = loggers or []
        self.log_first_step = log_first_step
        self.log_images_kwargs = log_images_kwargs or {}
        # Extra warm-up logging points at powers of two from 64 up to batch_frequency.
        self.log_steps = [2 ** n for n in range(6, int(np.log2(self.batch_frequency)) + 1)] if self.batch_frequency > 1 else []

    def check_frequency(self, step: int) -> bool:
        """Return True when images should be logged at ``step``.

        Step 0 is logged only when ``log_first_step`` is set (``0 % freq == 0`` used
        to make step 0 always log). Each warm-up step fires at most once.
        """
        if step == 0:
            return self.log_first_step
        if step % self.batch_frequency == 0:
            return True
        if step in self.log_steps:
            self.log_steps.remove(step)
            return True
        return False

    def log_images(self, pl_module: pl.LightningModule, batch: Any, step: int, split: str = "train") -> None:
        """Generate images with the module in eval mode and send them to all loggers."""
        if not self.check_frequency(step) or not hasattr(pl_module, "log_images"):
            return

        is_train = pl_module.training
        if is_train:
            pl_module.eval()
        try:
            with torch.no_grad():
                try:
                    images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
                except Exception as e:
                    print(f"记录图像时出错: {e}")
                    images = {}

            for k in list(images.keys()):
                if not isinstance(images[k], torch.Tensor):
                    continue
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                # Collect every rank's images in distributed runs.
                if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
                    images[k] = torch.cat(all_gather(images[k]))
                images[k] = images[k].detach().cpu()
                if self.clamp:
                    images[k] = torch.clamp(images[k], -1., 1.)
                if self.rescale:
                    images[k] = (images[k] + 1.0) / 2.0  # [-1, 1] -> [0, 1]

            for logger in self.loggers:
                if hasattr(logger, 'log_images'):
                    try:
                        logger.log_images(images, step, split)
                    except Exception as e:
                        print(f"日志记录器 {type(logger).__name__} 记录图像失败: {e}")
        finally:
            # Restore training mode even if generation or gathering failed.
            if is_train:
                pl_module.train()

    def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any,
                           batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Log training images (``dataloader_idx`` accepted for Lightning versions that pass it)."""
        if trainer.global_step % trainer.log_every_n_steps == 0:
            self.log_images(pl_module, batch, pl_module.global_step, "train")

    def on_validation_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any,
                                batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Log images for the first validation batch only."""
        if batch_idx == 0:
            self.log_images(pl_module, batch, pl_module.global_step, "val")


class TensorBoardLogger:
    """Minimal SummaryWriter-backed logger exposing the methods this script's Trainer uses."""

    def __init__(self, save_dir: str):
        from torch.utils.tensorboard import SummaryWriter
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.writer = SummaryWriter(save_dir)
        self._name = "TensorBoard"
        self._version = "1.0"
        self._experiment = self.writer
        print(f"TensorBoard日志保存在: {save_dir}")

    @property
    def name(self) -> str:
        return self._name

    @property
    def version(self) -> str:
        return self._version

    @property
    def experiment(self) -> Any:
        return self._experiment

    @property
    def log_dir(self) -> str:
        return self.save_dir

    def log_hyperparams(self, params: Any) -> None:
        """Flatten one level of nesting and record scalar/str hyper-parameters."""
        try:
            if not isinstance(params, dict):
                params = vars(params)  # e.g. argparse.Namespace
            flat_params = {}
            for key, value in params.items():
                if isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        flat_params[f"{key}/{sub_key}"] = sub_value
                else:
                    flat_params[key] = value
            self.writer.add_hparams(
                {k: v for k, v in flat_params.items() if isinstance(v, (int, float, str))},
                {},
                run_name=".",
            )
            print("已记录超参数到TensorBoard")
        except Exception as e:
            print(f"记录超参数失败: {e}")

    def log_graph(self, model: torch.nn.Module, input_array: Optional[torch.Tensor] = None) -> None:
        """Graph logging is intentionally skipped for the diffusion model."""
        print("跳过扩散模型的计算图记录")

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Write each metric as a scalar; failures are reported per metric."""
        if self.writer is None:
            return
        for name, value in metrics.items():
            try:
                self.writer.add_scalar(name, value, global_step=step)
            except Exception as e:
                print(f"添加标量失败: {name}, 错误: {e}")

    def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Entry point used by older PyTorch Lightning releases; no aggregation is done."""
        self.log_metrics(metrics, step)

    def log_images(self, images: Dict[str, torch.Tensor], step: int, split: str) -> None:
        """Write each image batch as a grid under ``<split>/<key>``; non-tensors are skipped."""
        for k, img in images.items():
            if not isinstance(img, torch.Tensor) or img.numel() == 0:
                continue
            try:
                grid = torchvision.utils.make_grid(img, nrow=min(8, img.shape[0]))
                self.writer.add_image(f"{split}/{k}", grid, global_step=step)
            except Exception as e:
                print(f"添加图像失败: {k}, 错误: {e}")

    def save(self) -> None:
        """No-op: SummaryWriter writes its own event files."""
        pass

    def finalize(self, status: str) -> None:
        self.close()

    def close(self) -> None:
        """Flush and close the writer; safe to call more than once."""
        if getattr(self, 'writer', None) is not None:
            self.writer.flush()
            self.writer.close()
            self.writer = None
            print(f"TensorBoard日志已关闭")


class TQDMProgressBar(Callback):
    """tqdm bars for optimizer steps and epochs."""

    def __init__(self):
        super().__init__()
        self.progress_bar = None
        self.epoch_bar = None

    def on_train_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        total_steps = self._get_total_steps(trainer)
        self.progress_bar = tqdm(total=total_steps, desc="Training Steps", position=0, leave=True, dynamic_ncols=True)
        self.epoch_bar = tqdm(total=trainer.max_epochs, desc="Epochs", position=1, leave=True, dynamic_ncols=True)

    def _get_total_steps(self, trainer: Trainer) -> int:
        """Total optimizer steps, trying newer then older Lightning attributes, else 10000."""
        if hasattr(trainer, 'estimated_stepping_batches'):
            return trainer.estimated_stepping_batches
        if hasattr(trainer, 'estimated_steps'):
            return trainer.estimated_steps
        try:
            if hasattr(trainer, 'num_training_batches'):
                num_batches = trainer.num_training_batches
            else:
                num_batches = len(trainer.train_dataloader)
            accumulate = trainer.accumulate_grad_batches if hasattr(trainer, 'accumulate_grad_batches') else 1
            steps_per_epoch = num_batches // accumulate
            total_steps = trainer.max_epochs * steps_per_epoch
            print(f"回退计算训练总步数: {total_steps} = {trainer.max_epochs} epochs × {steps_per_epoch} steps/epoch")
            return total_steps
        except Exception as e:
            print(f"无法确定训练总步数: {e}, 使用默认值10000")
            return 10000

    def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any,
                           batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Advance the step bar (never past its total) and show the loss when available."""
        if self.progress_bar:
            if self.progress_bar.n < self.progress_bar.total:
                self.progress_bar.update(1)
            try:
                loss = outputs.get('loss')
                if loss is not None:
                    if isinstance(loss, torch.Tensor):
                        loss = loss.item()
                    self.progress_bar.set_postfix({"loss": loss})
            except Exception:
                pass  # best effort: outputs may not be a dict

    def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        if self.epoch_bar:
            self.epoch_bar.update(1)
            self.epoch_bar.set_postfix({"epoch": trainer.current_epoch})

    def on_train_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        if self.progress_bar:
            self.progress_bar.close()
        if self.epoch_bar:
            self.epoch_bar.close()


class PerformanceMonitor(Callback):
    """Report per-epoch wall time, mean batch time and peak CUDA memory."""

    def __init__(self):
        self.epoch_start_time = 0
        self.batch_start_time = 0.0
        self.batch_times = []

    def on_train_epoch_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        self.epoch_start_time = time.time()
        self.batch_times = []
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

    # ``dataloader_idx`` is accepted for Lightning versions that pass it to batch hooks.
    def on_train_batch_start(self, trainer: Trainer, pl_module: pl.LightningModule, batch: Any,
                             batch_idx: int, dataloader_idx: int = 0) -> None:
        self.batch_start_time = time.time()

    def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any,
                           batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.batch_times.append(time.time() - self.batch_start_time)

    def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        epoch_time = time.time() - self.epoch_start_time
        if self.batch_times:
            avg_batch_time = sum(self.batch_times) / len(self.batch_times)
            batches_per_second = 1.0 / avg_batch_time
        else:
            avg_batch_time = 0
            batches_per_second = 0
        memory_info = ""
        if torch.cuda.is_available():
            max_memory = torch.cuda.max_memory_allocated() / 2 ** 20  # MiB
            memory_info = f", 峰值显存: {max_memory:.2f} MiB"
        rank_zero_info(
            f"Epoch {trainer.current_epoch} | "
            f"耗时: {epoch_time:.2f}s | "
            f"Batch耗时: {avg_batch_time:.4f}s ({batches_per_second:.2f} batches/s)"
            f"{memory_info}"
        )


def get_world_size() -> int:
    """Number of distributed processes (1 when torch.distributed is not initialized)."""
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def all_gather(data: torch.Tensor) -> List[torch.Tensor]:
    """Gather ``data`` from every rank; ranks may differ in the first dimension only.

    Tensors are flattened and zero-padded to the largest size for ``dist.all_gather``,
    then trimmed and reshaped to ``(-1, *data.shape[1:])``. Assumes all ranks share
    dtype and trailing dimensions.
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    trailing_shape = tuple(data.shape[1:])
    flat = data.reshape(-1)
    local_size = torch.tensor([flat.numel()], device=data.device)
    size_list = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    sizes = [int(size.item()) for size in size_list]
    max_size = max(sizes)

    if flat.numel() < max_size:
        padding = torch.zeros(max_size - flat.numel(), dtype=flat.dtype, device=flat.device)
        flat = torch.cat((flat, padding))

    tensor_list = [torch.empty((max_size,), dtype=data.dtype, device=data.device) for _ in sizes]
    dist.all_gather(tensor_list, flat)
    return [tensor[:size].reshape(-1, *trailing_shape) for tensor, size in zip(tensor_list, sizes)]


def create_experiment_directories(logging_config: DictConfig, experiment_name: str) -> Tuple[str, str, str]:
    """Create ``<logdir>/<name>_<YYYYmmdd_HHMMSS>/{checkpoints,configs}``; return the three paths."""
    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    logdir = os.path.join(logging_config.logdir, f"{experiment_name}_{now}")
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    os.makedirs(ckptdir, exist_ok=True)
    os.makedirs(cfgdir, exist_ok=True)
    print(f"实验目录: {logdir}")
    print(f"检查点目录: {ckptdir}")
    print(f"配置目录: {cfgdir}")
    return logdir, ckptdir, cfgdir


def setup_callbacks(config_manager: ConfigManager, ckptdir: str, tb_logger: TensorBoardLogger) -> List[Callback]:
    """Build checkpointing, LR monitoring, image logging, progress and performance callbacks."""
    callbacks = []

    # NOTE(review): 'val_loss' must be a metric the model logs; upstream ldm models log
    # under names like 'val/...' — confirm, otherwise checkpointing may error after validation.
    checkpoint_callback = ModelCheckpoint(
        dirpath=ckptdir,
        filename='{epoch}-{step}-{val_loss:.2f}',
        monitor='val_loss',
        save_top_k=3,
        mode='min',
        save_last=True,
        save_on_train_epoch_end=True,
        save_weights_only=False,  # full checkpoints (optimizer/loop state) so training can resume
        every_n_train_steps=1000,
    )
    callbacks.append(checkpoint_callback)

    callbacks.append(LearningRateMonitor(logging_interval="step"))

    image_logger_cfg = config_manager.get_callbacks_config().get("image_logger", {})
    callbacks.append(EnhancedImageLogger(
        batch_frequency=image_logger_cfg.get("batch_frequency", 500),
        max_images=image_logger_cfg.get("max_images", 4),
        loggers=[tb_logger],
    ))

    callbacks.append(TQDMProgressBar())
    callbacks.append(PerformanceMonitor())
    return callbacks


def _rgb_to_single_channel(tensor: torch.Tensor, target_shape: torch.Size) -> Optional[torch.Tensor]:
    """Average away a single size-3 axis that the target has as size 1; else return None."""
    if tensor.dim() != len(target_shape):
        return None
    diff = [d for d, (have, want) in enumerate(zip(tensor.shape, target_shape)) if have != want]
    if len(diff) != 1 or tensor.shape[diff[0]] != 3 or target_shape[diff[0]] != 1:
        return None
    # Mean, as in the original 3->1 conversion. (For an input conv whose single channel
    # stands in for R=G=B, summing would preserve the pretrained activation scale.)
    return tensor.mean(dim=diff[0], keepdim=True)


def adapt_state_dict_to_model(state_dict: Dict[str, Any], model_state: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    """Return a copy of ``state_dict`` whose tensors fit the shapes in ``model_state``.

    ``load_state_dict(strict=False)`` tolerates missing/unexpected keys but still raises
    on size mismatches, so every mismatching entry is converted or dropped. Entries
    that differ only by 3 vs 1 channels are averaged, e.g. an RGB checkpoint's VAE
    ``conv_in.weight`` [128, 3, 3, 3] -> [128, 1, 3, 3]. Any other mismatch is dropped
    with a warning, and the model keeps its own initialization for it. Keys unknown
    to the model pass through unchanged.
    """
    adapted = {}
    for key, value in state_dict.items():
        target = model_state.get(key)
        if target is None or not isinstance(value, torch.Tensor) or value.shape == target.shape:
            adapted[key] = value
            continue
        converted = _rgb_to_single_channel(value, target.shape)
        if converted is None:
            print(f"跳过形状不匹配的权重: {key} {tuple(value.shape)} -> {tuple(target.shape)}")
            continue
        print(f"转换权重: {key} {tuple(value.shape)} -> {tuple(converted.shape)}")
        adapted[key] = converted
    return adapted


def _fill_position_ids(state_dict: Dict[str, Any], model_state: Dict[str, torch.Tensor]) -> bool:
    """Copy the model's position-id buffer into ``state_dict`` if the model has it and the dict lacks it.

    Returns True when the key was added.
    """
    if _POSITION_IDS_KEY in state_dict or _POSITION_IDS_KEY not in model_state:
        return False
    print("警告: 检查点缺少 'cond_stage_model.transformer.text_model.embeddings.position_ids' 键")
    state_dict[_POSITION_IDS_KEY] = model_state[_POSITION_IDS_KEY].detach().clone().cpu()
    return True


def load_pretrained_weights(model: pl.LightningModule, ckpt_path: str) -> Tuple[List[str], List[str]]:
    """Load weights from ``ckpt_path`` into ``model`` non-strictly, after adapting shapes.

    Accepts a Lightning checkpoint (weights under ``"state_dict"``) or a bare state dict.
    If the VAE ``conv_in.weight`` is still missing afterwards, it is re-initialized
    from N(0, 0.02).

    Returns:
        ``(missing, unexpected)`` key lists reported by ``load_state_dict``.
    """
    print(f"加载预训练权重: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    state_dict = adapt_state_dict_to_model(state_dict, model.state_dict())

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"权重加载完成: 缺失层 {len(missing)}, 不匹配层 {len(unexpected)}")
    if missing:
        print("缺失层:", missing[:5])  # first five only, to keep the log short
    if unexpected:
        print("意外层:", unexpected[:5])

    if "first_stage_model.encoder.conv_in.weight" in missing:
        print("警告: conv_in.weight未加载,需要手动初始化")
        with torch.no_grad():
            model.first_stage_model.encoder.conv_in.weight.data.normal_(mean=0.0, std=0.02)
        print("已手动初始化conv_in.weight")
    return missing, unexpected


def find_last_checkpoint(log_root: str, experiment_name: str) -> Optional[str]:
    """Return the newest ``last.ckpt`` written by an earlier run of ``experiment_name``.

    Every run logs into a fresh ``<log_root>/<name>_<timestamp>/``, so the current
    run's checkpoint directory never holds a ``last.ckpt`` yet; earlier runs are
    searched instead. The ``<name>_*`` pattern also matches names that merely
    start with ``<name>_``.
    """
    pattern = os.path.join(glob.escape(str(log_root)), f"{glob.escape(experiment_name)}_*", "checkpoints", "last.ckpt")
    candidates = glob.glob(pattern)
    return max(candidates, key=os.path.getmtime) if candidates else None


def preprocess_checkpoint(checkpoint_path: str, model: pl.LightningModule,
                          reset_training_state: bool = True) -> Dict[str, Any]:
    """Load a checkpoint and make its state dict loadable into ``model``.

    Args:
        checkpoint_path: File passed to ``torch.load``.
        model: Target module; its ``state_dict()`` defines the expected keys and shapes.
        reset_training_state: True (default, as before) zeroes epoch/global_step and drops
            optimizer/scheduler states. False only fills in missing training-state keys.

    Returns:
        The modified checkpoint dict (not written to disk).

    Raises:
        Whatever ``torch.load`` raises for an unreadable file.
    """
    print(f"预处理检查点文件: {checkpoint_path}")
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
    except Exception as e:
        print(f"加载检查点失败: {e}")
        raise

    # A bare {name: tensor} file: wrap it so its weights are kept as the state dict.
    if "state_dict" not in checkpoint and checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
        checkpoint = {"state_dict": dict(checkpoint)}

    if reset_training_state:
        checkpoint['epoch'] = 0
        checkpoint['global_step'] = 0
        checkpoint['lr_schedulers'] = []
        checkpoint['optimizer_states'] = []
        print("已重置训练状态: epoch=0, global_step=0")

    required_keys = ['optimizer_states', 'lr_schedulers', 'epoch', 'global_step']
    missing_keys = [k for k in required_keys if k not in checkpoint]
    if missing_keys:
        print(f"警告: 检查点缺少训练状态字段 {missing_keys},将创建伪训练状态")
        checkpoint.setdefault('optimizer_states', [])
        checkpoint.setdefault('lr_schedulers', [])
        checkpoint.setdefault('epoch', 0)
        checkpoint.setdefault('global_step', 0)

    model_state = model.state_dict()
    state_dict = checkpoint.get("state_dict", {})
    if _fill_position_ids(state_dict, model_state):
        print("已添加 position_ids 到检查点")
    checkpoint["state_dict"] = adapt_state_dict_to_model(state_dict, model_state)
    return checkpoint


class CustomLatentDiffusion(LatentDiffusion):
    """LatentDiffusion that repairs checkpoint state dicts when a checkpoint is restored."""

    def on_load_checkpoint(self, checkpoint):
        """Fill in / adapt ``checkpoint["state_dict"]`` in place, then load it non-strictly.

        The position-id buffer is added only when this model registers it, so a strict
        load of the same dict does not report it as unexpected. Shape mismatches are
        adapted or dropped by ``adapt_state_dict_to_model``.
        """
        model_state = self.state_dict()
        state_dict = checkpoint["state_dict"]
        if _fill_position_ids(state_dict, model_state):
            print("已添加 position_ids 到 state_dict")
        checkpoint["state_dict"] = adapt_state_dict_to_model(state_dict, model_state)
        self.load_state_dict(checkpoint["state_dict"], strict=False)
        print("模型权重加载完成")


def filter_kwargs(cls, kwargs, log_prefix=""):
    """Keep whitelisted kwargs and any key containing 'config'; print what was dropped.

    ``cls`` is unused and kept for signature compatibility.
    """
    ESSENTIAL_PARAMS = {
        'unet_config', 'first_stage_config', 'cond_stage_config',
        'scheduler_config', 'ckpt_path', 'linear_start', 'linear_end',
    }
    filtered_kwargs = {}
    for k, v in kwargs.items():
        if k in ESSENTIAL_PARAMS or 'config' in k:
            filtered_kwargs[k] = v
        else:
            print(f"{log_prefix}过滤参数: {k}")
    print(f"{log_prefix}保留参数: {list(filtered_kwargs.keys())}")
    return filtered_kwargs


def check_checkpoint_content(checkpoint_path):
    """Print the checkpoint's top-level keys and whether training state is present."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print("检查点包含的键:", list(checkpoint.keys()))
    if "state_dict" in checkpoint:
        print("模型权重存在")
    if "optimizer_states" in checkpoint:
        print("优化器状态存在")
    if "epoch" in checkpoint:
        print(f"保存的epoch: {checkpoint['epoch']}")
    if "global_step" in checkpoint:
        print(f"保存的global_step: {checkpoint['global_step']}")


def _resolve_precision(cli_value: Optional[str]) -> Union[int, str]:
    """Map --precision to a Trainer value; 16 (the previously hard-coded value) when unset."""
    if cli_value is None:
        return 16
    return int(cli_value) if cli_value.isdigit() else cli_value


def main() -> None:
    """Parse the CLI, build data and model, configure the Lightning Trainer and train."""
    torch.set_float32_matmul_precision('high')  # allow TF32 tensor-core matmuls

    parser = argparse.ArgumentParser(description="扩散模型训练框架")
    parser.add_argument("--config", type=str, default="configs/train.yaml", help="配置文件路径")
    parser.add_argument("--name", type=str, default="experiment", help="实验名称")
    # --resume is on by default (store_true + default=True); --no_resume is the only way to disable it.
    parser.add_argument("--resume", action="store_true", default=True, help="恢复训练")
    parser.add_argument("--no_resume", dest="resume", action="store_false", help="不恢复训练")
    parser.add_argument("--debug", action="store_true", help="调试模式")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    parser.add_argument("--scale_lr", action="store_true", help="根据GPU数量缩放学习率")
    parser.add_argument("--precision", type=str, default=None, choices=["16", "32", "bf16"], help="训练精度")
    args, unknown = parser.parse_known_args()

    seed_everything(args.seed, workers=True)
    print(f"设置随机种子: {args.seed}")

    try:
        config_manager = ConfigManager(args.config, unknown)
    except Exception as e:
        print(f"加载配置失败: {e}")
        sys.exit(1)

    logging_config = config_manager.get_logging_config()
    logdir, ckptdir, cfgdir = create_experiment_directories(logging_config, args.name)
    config_manager.save_config(os.path.join(cfgdir, "config.yaml"))

    try:
        print("初始化数据模块...")
        data_config = config_manager.get_data_config()
        data_module = instantiate_from_config(data_config)
        data_module.setup()
        print("可用数据集:", list(data_module.datasets.keys()))
    except Exception as e:
        print(f"数据模块初始化失败: {str(e)}")
        return

    try:
        model_config = config_manager.get_model_config()
        model_params = dict(model_config.get("params", {}).items())
        # The reported "size mismatch for first_stage_model.encoder.conv_in.weight" was raised
        # while constructing the model, before the 3->1 conversion below could run. Upstream
        # ldm's LatentDiffusion loads ``ckpt_path`` inside __init__ (verify for your ldm version),
        # so it is popped here and loaded afterwards with shape adaptation.
        # NOTE(review): a ``ckpt_path`` inside first_stage_config would need the same treatment.
        ckpt_path = model_params.pop("ckpt_path", None)
        model = CustomLatentDiffusion(**model_params)
        print("模型初始化成功")
        if ckpt_path:
            if not os.path.exists(ckpt_path):
                raise FileNotFoundError(f"预训练权重不存在: {ckpt_path}")
            load_pretrained_weights(model, ckpt_path)
    except Exception as e:
        print(f"模型初始化失败: {str(e)}")
        return

    print("VAE输入层形状:", model.first_stage_model.encoder.conv_in.weight.shape)

    training_config = config_manager.get_training_config()
    bs = data_config.params.batch_size
    base_lr = model_config.base_learning_rate
    ngpu = training_config.get("gpus", 1)
    accumulate_grad_batches = training_config.get("accumulate_grad_batches", 1)
    if args.scale_lr:
        n_devices = max(1, ngpu)  # CPU-only (0 GPUs) would otherwise scale the LR to 0
        model.learning_rate = accumulate_grad_batches * n_devices * bs * base_lr
        print(f"学习率缩放至: {model.learning_rate:.2e} = {accumulate_grad_batches} × {n_devices} × {bs} × {base_lr:.2e}")
    else:
        model.learning_rate = base_lr
        print(f"使用基础学习率: {model.learning_rate:.2e}")

    resume_from_checkpoint = None
    if args.resume:
        last_ckpt = find_last_checkpoint(logging_config.logdir, args.name)
        if last_ckpt:
            print(f"恢复训练状态: {last_ckpt}")
            resume_from_checkpoint = last_ckpt
        else:
            fallback_ckpt = os.path.join(current_dir, "checkpoints", "M3FD.ckpt")
            if os.path.exists(fallback_ckpt):
                print(f"警告: 使用仅含权重的检查点,训练状态将重置: {fallback_ckpt}")
                # Training state is reset anyway, so load only the (shape-adapted) weights
                # instead of resuming: the resume path would hit the same size mismatch.
                try:
                    load_pretrained_weights(model, fallback_ckpt)
                except Exception as e:
                    print(f"加载检查点失败: {e}")
                    return
            else:
                print("未找到可用的检查点,从头开始训练")

    # Created once, after the early-return paths, so no SummaryWriter is leaked.
    tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard"))
    callbacks = setup_callbacks(config_manager, ckptdir, tb_logger)

    has_validation = hasattr(data_module, 'datasets') and 'validation' in data_module.datasets

    try:
        train_loader = data_module.train_dataloader()
        num_train_batches = len(train_loader)
        print(f"训练批次数: {num_train_batches}")
    except Exception as e:
        print(f"计算训练批次数失败: {e}")
        num_train_batches = 0

    trainer_config = {
        "default_root_dir": logdir,
        "max_epochs": training_config.max_epochs,
        "gpus": ngpu,
        # DDPPlugin selects DDP by itself; the deprecated ``distributed_backend`` flag is not passed.
        "plugins": [DDPPlugin(find_unused_parameters=False)] if ngpu > 1 else None,
        "precision": _resolve_precision(args.precision),
        "accumulate_grad_batches": accumulate_grad_batches,
        "callbacks": callbacks,
        "logger": tb_logger,
        "resume_from_checkpoint": resume_from_checkpoint,
        "fast_dev_run": args.debug,
        "limit_val_batches": 0 if not has_validation else 1.0,
        "num_sanity_val_steps": 0,  # skip the sanity check to speed up (re)starts
        "log_every_n_steps": 10,
    }

    if has_validation:
        if num_train_batches < 50:
            trainer_config["check_val_every_n_epoch"] = 1  # small dataset: validate per epoch
        elif num_train_batches < 100:
            trainer_config["val_check_interval"] = max(1, num_train_batches // 4)
        else:
            trainer_config["val_check_interval"] = min(2000, num_train_batches)

    try:
        print("最终训练器配置:")
        for k, v in trainer_config.items():
            print(f"  {k}: {v}")
        trainer = Trainer(**trainer_config)
    except Exception as e:
        print(f"创建训练器失败: {e}")
        tb_logger.close()
        sys.exit(1)

    try:
        print("开始训练...")
        trainer.fit(model, data_module)
        print("训练完成!")
    except KeyboardInterrupt:
        print("训练被用户中断")
        if trainer.global_rank == 0 and trainer.model is not None:
            trainer.save_checkpoint(os.path.join(ckptdir, "interrupted.ckpt"))
    except Exception as e:
        print(f"训练出错: {e}")
        if trainer.global_rank == 0 and hasattr(trainer, 'model') and trainer.model is not None:
            trainer.save_checkpoint(os.path.join(ckptdir, "error.ckpt"))
        raise
    finally:
        tb_logger.close()
        if trainer.global_rank == 0 and hasattr(trainer, 'profiler'):
            print("训练摘要:")
            print(trainer.profiler.summary())


if __name__ == "__main__":
    main()

# Runtime error reported for the previous version (model initialization failed):
#   Error(s) in loading state_dict for CustomLatentDiffusion: size mismatch for
#   first_stage_model.encoder.conv_in.weight: copying a param with shape
#   torch.Size([128, 3, 3, 3]) from checkpoint, the shape in current model is
#   torch.Size([128, 1, 3, 3]).
最新发布
07-04
import json
import os
import sys
import time
import traceback

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from albumentations import (
    Compose, Resize, Normalize, HorizontalFlip, VerticalFlip, Rotate, OneOf,
    RandomBrightnessContrast, GaussNoise, ElasticTransform, RandomGamma,
    HueSaturationValue, CoarseDropout, Perspective, KeypointParams, CLAHE,
    MotionBlur, ISONoise, Lambda
)
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, StepLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm  # progress-bar library


class EnhancedTrainingLogger:
    """Record per-batch training losses / learning rate and plot them live.

    ``on_batch_end`` appends one sample per batch, ``on_epoch_end`` prints a
    summary, writes ``training_log_epoch_<n>.csv`` and trims the histories to
    their last entry, and ``on_train_end`` saves the figures to PNG files.

    ``val_metrics`` is only read by this class; the caller is expected to
    append precision/recall/F1 values to it after each validation run.
    """

    def __init__(self):
        self.total_losses = []
        self.bin_losses = []
        self.thresh_losses = []
        self.db_losses = []
        self.timestamps = []
        self.start_time = time.time()
        self.lr_history = []
        self.val_metrics = {'precision': [], 'recall': [], 'f1': []}

        # Interactive mode so the figure refreshes while training runs.
        plt.ion()
        self.fig, self.axs = plt.subplots(2, 2, figsize=(15, 10))
        self.fig.suptitle('Training Progress', fontsize=16)

        # Top-left panel: loss components per batch.
        self.loss_line, = self.axs[0, 0].plot([], [], 'r-', label='Total Loss')
        self.bin_line, = self.axs[0, 0].plot([], [], 'g-', label='Binary Loss')
        self.thresh_line, = self.axs[0, 0].plot([], [], 'b-', label='Threshold Loss')
        self.db_line, = self.axs[0, 0].plot([], [], 'm-', label='DB Loss')
        self.axs[0, 0].set_title('Training Loss Components')
        self.axs[0, 0].set_xlabel('Batch')
        self.axs[0, 0].set_ylabel('Loss')
        self.axs[0, 0].legend()
        self.axs[0, 0].grid(True)

        # Top-right panel: learning rate per batch.
        self.lr_line, = self.axs[0, 1].plot([], [], 'c-')
        self.axs[0, 1].set_title('Learning Rate Schedule')
        self.axs[0, 1].set_xlabel('Batch')
        self.axs[0, 1].set_ylabel('Learning Rate')
        self.axs[0, 1].grid(True)

        # Bottom-left panel: validation metrics per epoch.
        self.precision_line, = self.axs[1, 0].plot([], [], 'r-', label='Precision')
        self.recall_line, = self.axs[1, 0].plot([], [], 'g-', label='Recall')
        self.f1_line, = self.axs[1, 0].plot([], [], 'b-', label='F1 Score')
        self.axs[1, 0].set_title('Validation Metrics')
        self.axs[1, 0].set_xlabel('Epoch')
        self.axs[1, 0].set_ylabel('Score')
        self.axs[1, 0].legend()
        self.axs[1, 0].grid(True)

        # Bottom-right panel: text read-out of the current metrics.
        self.metrics_text = self.axs[1, 1].text(
            0.5, 0.5, "",
            horizontalalignment='center', verticalalignment='center',
            transform=self.axs[1, 1].transAxes, fontsize=12)
        self.axs[1, 1].axis('off')

        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.draw()
        plt.pause(0.1)

    def on_batch_end(self, batch_idx, total_loss, bin_loss, thresh_loss, db_loss, lr=None):
        """Record one batch and refresh the plots.

        Args:
            batch_idx: index of the batch; every 10th batch also refreshes the
                text panel.
            total_loss, bin_loss, thresh_loss, db_loss: scalar loss values.
            lr: current learning rate, or None if unknown. It is not recorded
                and is shown as "N/A" in that case.
        """
        elapsed = time.time() - self.start_time
        self.total_losses.append(total_loss)
        self.bin_losses.append(bin_loss)
        self.thresh_losses.append(thresh_loss)
        self.db_losses.append(db_loss)
        self.timestamps.append(elapsed)
        if lr is not None:
            self.lr_history.append(lr)

        self.update_plots(batch_idx)

        if batch_idx % 10 == 0:
            # Average of the last 10 values once 10 are available.
            avg_total = np.mean(self.total_losses[-10:]) if len(self.total_losses) >= 10 else total_loss
            avg_bin = np.mean(self.bin_losses[-10:]) if len(self.bin_losses) >= 10 else bin_loss
            avg_thresh = np.mean(self.thresh_losses[-10:]) if len(self.thresh_losses) >= 10 else thresh_loss
            avg_db = np.mean(self.db_losses[-10:]) if len(self.db_losses) >= 10 else db_loss

            # lr defaults to None; formatting None with ':.2e' would raise TypeError.
            lr_text = f"{lr:.2e}" if lr is not None else "N/A"
            metrics_text = (
                f"Batch: {batch_idx}\n"
                f"Total Loss: {total_loss:.4f} (Avg10: {avg_total:.4f})\n"
                f"Binary Loss: {bin_loss:.4f} (Avg10: {avg_bin:.4f})\n"
                f"Threshold Loss: {thresh_loss:.4f} (Avg10: {avg_thresh:.4f})\n"
                f"DB Loss: {db_loss:.4f} (Avg10: {avg_db:.4f})\n"
                f"Learning Rate: {lr_text}\n"
                f"Time: {int(elapsed // 3600):02d}:{int((elapsed % 3600) // 60):02d}:{int(elapsed % 60):02d}"
            )
            self.metrics_text.set_text(metrics_text)

            plt.draw()
            plt.pause(0.01)

    def update_plots(self, batch_idx):
        """Push the recorded histories into the line artists and rescale the axes."""
        x_data = np.arange(len(self.total_losses))
        self.loss_line.set_data(x_data, self.total_losses)
        self.bin_line.set_data(x_data, self.bin_losses)
        self.thresh_line.set_data(x_data, self.thresh_losses)
        self.db_line.set_data(x_data, self.db_losses)

        # Y range of the loss panel: 10% margin around all loss components.
        all_losses = self.total_losses + self.bin_losses + self.thresh_losses + self.db_losses
        if all_losses:
            min_loss = min(all_losses) * 0.9
            max_loss = max(all_losses) * 1.1
            self.axs[0, 0].set_ylim(min_loss, max_loss)

        if self.lr_history:
            self.lr_line.set_data(np.arange(len(self.lr_history)), self.lr_history)
            self.axs[0, 1].set_ylim(min(self.lr_history) * 0.9, max(self.lr_history) * 1.1)

        if self.val_metrics['precision']:
            x_epochs = np.arange(len(self.val_metrics['precision']))
            self.precision_line.set_data(x_epochs, self.val_metrics['precision'])
            self.recall_line.set_data(x_epochs, self.val_metrics['recall'])
            self.f1_line.set_data(x_epochs, self.val_metrics['f1'])

            all_metrics = self.val_metrics['precision'] + self.val_metrics['recall'] + self.val_metrics['f1']
            if all_metrics:
                min_metric = min(all_metrics) * 0.9
                max_metric = max(all_metrics) * 1.1
                self.axs[1, 0].set_ylim(min_metric, max_metric)

        # X ranges (at least [0, 1] so empty histories don't produce a degenerate axis).
        self.axs[0, 0].set_xlim(0, max(1, len(self.total_losses)))
        self.axs[0, 1].set_xlim(0, max(1, len(self.lr_history)))
        if self.val_metrics['precision']:
            self.axs[1, 0].set_xlim(0, max(1, len(self.val_metrics['precision'])))

    def on_epoch_end(self, epoch, optimizer=None):
        """Print an epoch summary, write a CSV log and trim the histories.

        Args:
            epoch: zero-based epoch index (printed and used in the CSV name as epoch + 1).
            optimizer: if given, the lr of its first param group is reported.
        """
        # Empty histories fall back to 0.0 instead of raising.
        total_min = min(self.total_losses) if self.total_losses else 0.0
        total_max = max(self.total_losses) if self.total_losses else 0.0
        total_avg = np.mean(self.total_losses) if self.total_losses else 0.0
        bin_min = min(self.bin_losses) if self.bin_losses else 0.0
        bin_avg = np.mean(self.bin_losses) if self.bin_losses else 0.0
        thresh_min = min(self.thresh_losses) if self.thresh_losses else 0.0
        thresh_avg = np.mean(self.thresh_losses) if self.thresh_losses else 0.0
        db_min = min(self.db_losses) if self.db_losses else 0.0
        db_avg = np.mean(self.db_losses) if self.db_losses else 0.0

        report = (
            f"\n{'=' * 70}\n"
            f"EPOCH {epoch + 1} SUMMARY:\n"
            f" - Total Loss: Min={total_min:.6f}, Max={total_max:.6f}, Avg={total_avg:.6f}\n"
            f" - Binary Loss: Min={bin_min:.6f}, Avg={bin_avg:.6f}\n"
            f" - Threshold Loss: Min={thresh_min:.6f}, Avg={thresh_avg:.6f}\n"
            f" - DB Loss: Min={db_min:.6f}, Avg={db_avg:.6f}\n"
        )
        if self.val_metrics['precision']:
            report += (
                f" - Val Metrics: Precision={self.val_metrics['precision'][-1]:.4f}, "
                f"Recall={self.val_metrics['recall'][-1]:.4f}, F1={self.val_metrics['f1'][-1]:.4f}\n"
            )
        if optimizer:
            report += f" - Learning Rate: {optimizer.param_groups[0]['lr']:.6e}\n"
        report += f"{'=' * 70}"
        print(report)

        with open(f'training_log_epoch_{epoch + 1}.csv', 'w') as f:
            f.write("Timestamp,Total_Loss,Bin_Loss,Thresh_Loss,DB_Loss,Learning_Rate\n")
            for i, t in enumerate(self.timestamps):
                # lr_history can be shorter than timestamps when lr was None for some batches.
                lr_val = self.lr_history[i] if i < len(self.lr_history) else 0
                f.write(
                    f"{t:.2f},{self.total_losses[i]:.6f},{self.bin_losses[i]:.6f},{self.thresh_losses[i]:.6f},{self.db_losses[i]:.6f},{lr_val:.6e}\n")

        # Keep only the last batch value so the curves stay continuous across epochs.
        self.total_losses = [self.total_losses[-1]] if self.total_losses else []
        self.bin_losses = [self.bin_losses[-1]] if self.bin_losses else []
        self.thresh_losses = [self.thresh_losses[-1]] if self.thresh_losses else []
        self.db_losses = [self.db_losses[-1]] if self.db_losses else []
        self.timestamps = [self.timestamps[-1]] if self.timestamps else []
        self.lr_history = [self.lr_history[-1]] if self.lr_history else []

        self.update_plots(0)
        plt.draw()
        plt.pause(0.1)

    def on_train_end(self):
        """Save the live figure to training_summary.png and write the detailed report."""
        plt.ioff()
        plt.savefig('training_summary.png')
        plt.close()
        self.generate_detailed_report()

    def generate_detailed_report(self):
        """Write training_detailed_report.png (losses, lr, validation metrics).

        The loss/lr panels show only what remains after the last on_epoch_end
        trim; the validation panel marks the best F1 value.
        """
        fig, axs = plt.subplots(3, 1, figsize=(12, 15))

        axs[0].plot(self.total_losses, label='Total Loss')
        axs[0].plot(self.bin_losses, label='Binary Loss')
        axs[0].plot(self.thresh_losses, label='Threshold Loss')
        axs[0].plot(self.db_losses, label='DB Loss')
        axs[0].set_title('Training Loss Components')
        axs[0].set_xlabel('Batch')
        axs[0].set_ylabel('Loss')
        axs[0].legend()
        axs[0].grid(True)

        axs[1].plot(self.lr_history)
        axs[1].set_title('Learning Rate Schedule')
        axs[1].set_xlabel('Batch')
        axs[1].set_ylabel('Learning Rate')
        axs[1].grid(True)

        if self.val_metrics['precision']:
            axs[2].plot(self.val_metrics['precision'], 'o-', label='Precision')
            axs[2].plot(self.val_metrics['recall'], 'o-', label='Recall')
            axs[2].plot(self.val_metrics['f1'], 'o-', label='F1 Score')
            axs[2].set_title('Validation Metrics')
            axs[2].set_xlabel('Epoch')
            axs[2].set_ylabel('Score')
            axs[2].legend()
            axs[2].grid(True)

            best_f1_idx = np.argmax(self.val_metrics['f1'])
            best_f1 = self.val_metrics['f1'][best_f1_idx]
            axs[2].plot(best_f1_idx, best_f1, 'ro', markersize=8)
            axs[2].annotate(f'Best F1: {best_f1:.4f}',
                            xy=(best_f1_idx, best_f1),
                            xytext=(best_f1_idx + 0.5, best_f1 - 0.05),
                            arrowprops=dict(facecolor='black', shrink=0.05))

        plt.tight_layout()
        plt.savefig('training_detailed_report.png')
        plt.close()


# Module-level function (outside any class) so it can be used by albumentations Lambda.
def suppress_water_meter_glare(img, **kwargs):
    """Reduce glare on an RGB water-meter image; extra keyword arguments are ignored.

    CLAHE is applied to the L channel (LAB space) with a clip limit that grows
    with mean brightness. The result is blended (0.7 original / 0.3 CLAHE) only
    into pixels whose original L value is <= 100; other pixels are unchanged.
    """
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)

    l_mean = np.mean(l)
    clip_limit = 2.0 + (l_mean / 40)  # brighter image -> larger clip limit
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8))
    l_clahe = clahe.apply(l)

    # THRESH_BINARY_INV: mask is 255 where l <= 100 (the darker regions).
    _, mask = cv2.threshold(l, 100, 255, cv2.THRESH_BINARY_INV)
    blended = cv2.addWeighted(l, 0.7, l_clahe, 0.3, 0)
    l_final = np.where(mask > 0, blended, l)

    lab = cv2.merge((l_final, a, b))
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


# ----------------------------
# 1. Dataset loading and preprocessing (float polygon coordinates)
# ----------------------------
class WaterMeterDataset(Dataset):
    """Water-meter digit-region detection dataset.

    Each image in ``image_dir`` is paired with ``<stem>.json`` in ``label_dir``.
    The JSON keys read (``shapes`` / ``shape_type`` / ``points`` and optional
    ``imageHeight`` / ``imageWidth``) match the LabelMe layout.

    ``__getitem__`` returns ``(image_tensor, binary_target, threshold_target)``.
    Both targets are ``(1, input_size[0] // 8, input_size[1] // 8)`` float tensors.
    """

    def __init__(self, image_dir, label_dir, input_size=(640, 640), augment=True):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.input_size = input_size
        self.augment = augment
        self.image_files = [f for f in os.listdir(image_dir)
                            if os.path.isfile(os.path.join(image_dir, f))]

        # Always applied: resize to input_size, ImageNet normalization, to tensor.
        self.base_transform = Compose([
            Resize(height=input_size[0], width=input_size[1]),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

        if augment:
            # NOTE(review): ``water_meter_specific_aug`` is not defined in this
            # snippet; it must exist at module level before augment=True is used.
            self.augmentation = Compose([
                OneOf([
                    Perspective(scale=(0.05, 0.1), p=0.3),   # viewing-angle changes
                    RandomGamma(gamma_limit=(80, 120), p=0.2),  # glass reflections
                    CoarseDropout(max_holes=5, max_height=20, max_width=20,
                                  fill_value=0, p=0.2)      # dirt / stains
                ], p=0.8),
                Lambda(name='glare_reduction', image=suppress_water_meter_glare),
                Lambda(name='water_meter_aug', image=water_meter_specific_aug, p=0.7),
                OneOf([
                    HorizontalFlip(p=0.3),
                    VerticalFlip(p=0.2),
                    Rotate(limit=15, p=0.5)
                ], p=0.7),
                OneOf([
                    RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                    CLAHE(clip_limit=2.0, p=0.3),
                    GaussNoise(std_range=(0.15, 0.4),
                               mean_range=(0, 0),
                               per_channel=True,
                               p=0.3),
                    ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5))
                ], p=0.7)
            ], p=0.8,
                # remove_invisible=False: keypoints pushed out of frame must be kept,
                # otherwise the flat keypoint list shrinks and the per-polygon
                # regrouping in __getitem__ assigns points to the wrong polygons.
                # Out-of-range points are clipped afterwards.
                keypoint_params=KeypointParams(format='xyas', remove_invisible=False))
        else:
            self.augmentation = None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)

        image = cv2.imread(img_path)
        if image is None:
            print(f"错误: 无法读取图像 {img_path}")
            return self[(idx + 1) % len(self)]  # fall through to the next image
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Random CLAHE glare suppression. Applied only when augmenting so that
        # validation samples (augment=False) are deterministic.
        if self.augment and np.random.rand() > 0.5:
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l, a, b = cv2.split(lab)
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            l = clahe.apply(l)
            lab = cv2.merge([l, a, b])
            image = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        base_name = os.path.splitext(img_name)[0]
        label_path = os.path.join(self.label_dir, base_name + '.json')

        try:
            with open(label_path) as f:
                label_data = json.load(f)

            polygons = []
            orig_h, orig_w = image.shape[:2]

            # Rescale annotations made on a differently sized image.
            json_h = label_data.get('imageHeight', orig_h)
            json_w = label_data.get('imageWidth', orig_w)
            scale_x = orig_w / json_w
            scale_y = orig_h / json_h

            for shape in label_data['shapes']:
                if shape['shape_type'] == 'polygon':
                    poly = np.array(shape['points'], dtype=np.float32)  # keep float precision
                    poly[:, 0] = poly[:, 0] * scale_x
                    poly[:, 1] = poly[:, 1] * scale_y
                    poly[:, 0] = np.clip(poly[:, 0], 0, orig_w - 1)
                    poly[:, 1] = np.clip(poly[:, 1], 0, orig_h - 1)
                    polygons.append(poly)

            if len(polygons) == 0:
                print(f"警告: {img_name} 无有效标注,使用随机样本替代")
                return self[np.random.randint(0, len(self))]

            # Debug visualization for the first few indices (writes files to CWD).
            if idx < 5:
                debug_img = image.copy()
                for poly in polygons:
                    int_poly = poly.astype(np.int32).reshape(-1, 1, 2)
                    cv2.polylines(debug_img, [int_poly], True, (0, 255, 0), 3)
                debug_info = f"Size: {orig_w}x{orig_h} | Polys: {len(polygons)}"
                cv2.putText(debug_img, debug_info, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
                debug_path = f"debug_{base_name}.jpg"
                cv2.imwrite(debug_path, cv2.cvtColor(debug_img, cv2.COLOR_RGB2BGR))
                print(f"保存调试图像: {debug_path}")

            # Flatten all polygon vertices into one keypoint list ('xyas': x, y, angle, scale).
            keypoints = []
            for poly in polygons:
                for point in poly:
                    keypoints.append((point[0], point[1], 0, 0))

            if self.augment and self.augmentation:
                poly_lengths = [len(poly) for poly in polygons]

                augmented = self.augmentation(image=image, keypoints=keypoints)
                image = augmented['image']
                keypoints = augmented['keypoints']

                # Regroup the flat keypoint list back into polygons.
                polygons = []
                start_idx = 0
                for poly_len in poly_lengths:
                    end_idx = start_idx + poly_len
                    if end_idx <= len(keypoints):
                        poly_points = keypoints[start_idx:end_idx]
                        new_poly = np.array([[p[0], p[1]] for p in poly_points], dtype=np.float32)
                        polygons.append(new_poly)
                    start_idx = end_idx

            # Clip every polygon to the (possibly augmented) image bounds.
            for poly in polygons:
                poly[:, 0] = np.clip(poly[:, 0], 0, image.shape[1] - 1)
                poly[:, 1] = np.clip(poly[:, 1], 0, image.shape[0] - 1)

        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"警告: 无法加载标注文件 {label_path} - {str(e)}")
            polygons = []

        aug_h, aug_w = image.shape[:2]

        processed = self.base_transform(image=image)  # includes Resize to input_size
        image_tensor = processed['image']

        # Map polygon coordinates into input_size space.
        scale_x = self.input_size[1] / aug_w
        scale_y = self.input_size[0] / aug_h
        scaled_polygons = []
        for poly in polygons:
            scaled_poly = poly.copy()
            scaled_poly[:, 0] = scaled_poly[:, 0] * scale_x
            scaled_poly[:, 1] = scaled_poly[:, 1] * scale_y
            scaled_poly[:, 0] = np.clip(scaled_poly[:, 0], 0, self.input_size[1] - 1)
            scaled_poly[:, 1] = np.clip(scaled_poly[:, 1], 0, self.input_size[0] - 1)
            scaled_polygons.append(scaled_poly)

        target_shape = (self.input_size[0], self.input_size[1])
        binary_target = self.generate_binary_target(scaled_polygons, target_shape)
        threshold_target = self.generate_threshold_target(scaled_polygons, target_shape)

        return image_tensor, binary_target, threshold_target

    def generate_threshold_target(self, polygons, img_shape, ratio=0.4):
        """Build the threshold target at 1/8 of ``input_size``.

        For each polygon (>= 3 points, area >= 10, perimeter >= 1e-3), the
        in-polygon distance transform is normalized by
        ``area * (1 - ratio**2) / perimeter`` and clipped to [0, 1]. Polygons
        are combined by pixel-wise max, then the map is downsampled.

        Returns:
            float tensor of shape ``(1, input_size[0] // 8, input_size[1] // 8)``.
        """
        output_size = (self.input_size[0] // 8, self.input_size[1] // 8)  # (H, W)
        full_size_map = np.zeros(img_shape[:2], dtype=np.float32)

        for poly in polygons:
            if len(poly) < 3:
                continue

            poly[:, 0] = np.clip(poly[:, 0], 0, img_shape[1] - 1)
            poly[:, 1] = np.clip(poly[:, 1], 0, img_shape[0] - 1)

            area = cv2.contourArea(poly)
            perimeter = cv2.arcLength(poly, True)
            if perimeter < 1e-3 or area < 10:
                continue
            max_dist = area * (1 - ratio ** 2) / max(perimeter, 1e-3)

            mask = np.zeros(img_shape[:2], dtype=np.uint8)
            int_poly = poly.reshape((-1, 1, 2)).astype(np.int32)
            cv2.fillPoly(mask, [int_poly], 255)

            dist = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
            normalized = np.clip(dist / max(max_dist, 1e-6), 0, 1)
            full_size_map = np.maximum(full_size_map, normalized)

        # cv2.resize takes dsize as (width, height); passing (H, W) directly would
        # transpose the map for non-square input sizes.
        dist_map = cv2.resize(full_size_map, (output_size[1], output_size[0]),
                              interpolation=cv2.INTER_LINEAR)

        if np.max(dist_map) < 1e-6:
            return torch.zeros((1, *output_size), dtype=torch.float32)

        return torch.from_numpy(dist_map).unsqueeze(0).float()

    def generate_binary_target(self, polygons, img_shape):
        """Rasterize polygons (> 2 points) into a 0/1 map at 1/8 of ``input_size``.

        Returns:
            float tensor of shape ``(1, input_size[0] // 8, input_size[1] // 8)``.
        """
        output_size = (self.input_size[0] // 8, self.input_size[1] // 8)  # (H, W)
        binary_map = np.zeros(output_size, dtype=np.float32)

        # Scale from img_shape (H, W) to the feature-map size.
        scale_x = output_size[1] / img_shape[1]
        scale_y = output_size[0] / img_shape[0]

        for poly in polygons:
            if len(poly) > 2:
                scaled_poly = poly.copy()
                scaled_poly[:, 0] = scaled_poly[:, 0] * scale_x
                scaled_poly[:, 1] = scaled_poly[:, 1] * scale_y

                int_poly = scaled_poly.reshape((-1, 1, 2)).astype(np.float32)
                temp_canvas = np.zeros(output_size, dtype=np.uint8)
                cv2.fillPoly(temp_canvas, [int_poly.astype(np.int32)], 1)
                binary_map = np.maximum(binary_map, temp_canvas.astype(np.float32))

        return torch.from_numpy(binary_map).unsqueeze(0).float()

# ----------------------------
# 2.
# DBNet model definition (enhanced)
# ----------------------------
class DBNet(nn.Module):
    """DBNet text-detection model on a torchvision ResNet18 backbone.

    Only the last backbone stage (``layer4``, 512 channels, 1/32 of the input
    resolution) is used. ``fusion_conv`` reduces it to 64 channels and
    :class:`DBHead` upsamples 4x, so both output maps are at 1/8 resolution.

    Args:
        pretrained: if True (default) load ``ResNet18_Weights.DEFAULT``; if False
            the backbone is randomly initialized.
    """

    def __init__(self, pretrained=True):
        super(DBNet, self).__init__()
        # Previously ``pretrained`` was ignored and DEFAULT weights were always loaded.
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        base_model = resnet18(weights=weights)

        # Backbone stages (names kept identical for state_dict compatibility).
        self.conv1 = base_model.conv1
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4

        # 512 -> 256 -> 128 -> 64 channel reduction.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.db_head = DBHead(64)

    def forward(self, x):
        """Return ``(binary_map, thresh_map)``, each with 1 channel and sigmoid-activated."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        fused = self.fusion_conv(x)

        binary_map, thresh_map = self.db_head(fused)
        return binary_map, thresh_map


class DBHead(nn.Module):
    """DBNet head: residual block + spatial attention, then two 4x-upsampling branches.

    Input: ``in_channels`` feature map. Output: ``(binary_map, thresh_map)``, each
    1-channel, sigmoid-activated, with spatial size 4x the input (two stride-2
    transposed convolutions per branch).
    """

    def __init__(self, in_channels):
        super(DBHead, self).__init__()
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.LeakyReLU(0.2, inplace=True),  # LeakyReLU to limit dead activations
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels)
        )

        # Spatial attention: one sigmoid gate per pixel.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

        # Channel attention. NOTE(review): defined but never used in forward();
        # kept so existing checkpoints/state_dicts still load.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, 1),
            nn.Sigmoid()
        )

        # Probability (binary) map branch.
        self.binarize = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 2, in_channels // 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

        # Threshold map branch.
        self.thresh = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 2, in_channels // 4, 3, padding=1),
            nn.BatchNorm2d(in_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 8, 4, stride=2, padding=1),
            nn.BatchNorm2d(in_channels // 8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 8, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = x
        x = self.res_block(x) + residual

        attn_map = self.spatial_attn(x)
        x = x * attn_map

        binary_map = self.binarize(x)
        thresh_map = self.thresh(x)
        return binary_map, thresh_map


# ----------------------------
# 3. Loss function
# ----------------------------
class DBLoss(nn.Module):
    """DBNet loss.

    total = prob_loss + alpha * thresh_loss + beta * bin_loss, where

    * prob_loss   - Dice loss with OHEM on the predicted probability map,
    * thresh_loss - L1 loss between predicted and target threshold maps,
    * bin_loss    - Dice loss on the differentiable binarization
      ``sigmoid(k * (P - T))``.

    NOTE(review): the DB paper writes L = L_s + alpha*L_b + beta*L_t (alpha=1,
    beta=10, large weight on the threshold term). Here ``beta`` weights the
    binary term instead; the weighting is kept as-is to avoid changing training.
    """

    def __init__(self, alpha=1.0, beta=10.0, k=50, ohem_ratio=3.0):
        super(DBLoss, self).__init__()
        self.alpha = alpha  # weight of the threshold-map L1 loss
        self.beta = beta  # weight of the approximate-binary Dice loss
        self.k = k  # amplification factor of the differentiable binarization
        self.ohem_ratio = ohem_ratio  # max hard negatives kept per positive pixel

    def forward(self, preds, targets):
        """Return ``(total_loss, prob_loss, thresh_loss, bin_loss)``.

        Args:
            preds: ``(binary_pred, thresh_pred)`` from the model.
            targets: ``(binary_target, thresh_target)`` with matching shapes.
        """
        binary_pred, thresh_pred = preds
        binary_target, thresh_target = targets

        prob_loss = self.dice_loss_with_ohem(binary_pred, binary_target)
        thresh_loss = F.l1_loss(thresh_pred, thresh_target, reduction='mean')

        # B = 1 / (1 + exp(-k (P - T))). This must stay in the autograd graph:
        # computing it under torch.no_grad() made the beta-weighted term
        # contribute nothing to the gradients.
        binary_map = torch.sigmoid(self.k * (binary_pred - thresh_pred))
        bin_loss = self.dice_loss(binary_map, binary_target)

        total_loss = prob_loss + self.alpha * thresh_loss + self.beta * bin_loss
        return total_loss, prob_loss, thresh_loss, bin_loss

    def dice_loss(self, pred, target):
        """Global (whole-tensor) Dice loss with +1 smoothing; 0 for two all-zero maps."""
        smooth = 1.0
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        return 1 - (2. * intersection + smooth) / (union + smooth)

    def dice_loss_with_ohem(self, pred, target):
        """Per-pixel Dice-style loss averaged over positives plus the hardest negatives.

        Positives are pixels with target > 0.5. At most ``ohem_ratio * n_pos``
        negatives with the largest loss are kept. Falls back to
        :meth:`dice_loss` when no negatives would be selected.
        """
        loss_map = 1 - (2 * pred * target + 1) / (pred + target + 1)

        pos_mask = (target > 0.5).float()
        neg_mask = 1 - pos_mask

        # Both counts as ints: torch.topk requires an int k, and
        # min(int, float) could previously hand it a float.
        n_pos = int(pos_mask.sum().item())
        n_neg = min(int(n_pos * self.ohem_ratio), int(neg_mask.sum().item()))

        if n_neg == 0:
            return self.dice_loss(pred, target)

        neg_loss = (loss_map * neg_mask).reshape(-1)
        topk_neg_loss, _ = torch.topk(neg_loss, n_neg)

        pos_loss = (loss_map * pos_mask).sum()
        total_loss = (pos_loss + topk_neg_loss.sum()) / (n_pos + n_neg + 1e-6)
        return total_loss


# ----------------------------
# Helper functions
# ----------------------------
def calculate_metrics(pred, target, threshold=0.5):
    """Pixel-level precision, recall and F1 of ``pred > threshold`` vs ``target > 0.5``.

    Returns ``(0.0, 0.0, 0.0)`` when the target has no positive pixels.
    """
    pred_bin = (pred > threshold).float()
    target_bin = (target > 0.5).float()

    pred_flat = pred_bin.view(-1).cpu().numpy()
    target_flat = target_bin.view(-1).cpu().numpy()

    if np.sum(target_flat) == 0:
        return 0.0, 0.0, 0.0

    precision = precision_score(target_flat, pred_flat, zero_division=0)
    recall = recall_score(target_flat, pred_flat, zero_division=0)
    f1 = f1_score(target_flat, pred_flat, zero_division=0)
    return precision, recall, f1

# ... (unchanged) ...
def validate_model(model, dataloader, device):
    """Average pixel precision / recall / F1 of ``model`` over ``dataloader``.

    The model runs in eval mode without gradients; its previous train/eval
    mode is restored afterwards (also if an exception is raised).

    Args:
        model: Module returning ``(binary_preds, thresh_preds)``.
        dataloader: Iterable of ``(images, binary_targets, _)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(avg_precision, avg_recall, avg_f1)`` averaged per batch, or
        ``(0.0, 0.0, 0.0)`` when the dataloader yields no batch.
    """
    was_training = model.training
    model.eval()
    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for images, binary_targets, _ in dataloader:
                images = images.to(device)
                binary_targets = binary_targets.to(device)
                binary_preds, _ = model(images)
                precision, recall, f1 = calculate_metrics(binary_preds, binary_targets)
                total_precision += precision
                total_recall += recall
                total_f1 += f1
                num_batches += 1
    finally:
        # Without this, a caller that validates mid-training silently keeps
        # training with BatchNorm/Dropout in eval mode.
        model.train(was_training)

    if num_batches == 0:
        return 0.0, 0.0, 0.0
    avg_precision = total_precision / num_batches
    avg_recall = total_recall / num_batches
    avg_f1 = total_f1 / num_batches
    return avg_precision, avg_recall, avg_f1


# 2. Dynamic loss-weight calibration - DBLoss subclass
class AdaptiveDBLoss(DBLoss):
    """DBLoss whose ``beta`` weight is periodically re-estimated.

    Every ``adapt_step`` steps (once more than 10 values have been recorded),
    ``beta`` is set to 0.8 x the median of the last 10 recorded binarization
    losses, clamped to ``[1.0, 10.0]``.

    Args:
        alpha: Initial threshold-map loss weight (see DBLoss).
        beta: Initial binarization-map loss weight (see DBLoss).
        gamma: Passed positionally to DBLoss, i.e. as its ``k`` parameter.
        adapt_step: Adaptation period, in steps.
    """

    def __init__(self, alpha=1.0, beta=5.0, gamma=2.0, adapt_step=100):
        # NOTE(review): gamma lands in DBLoss's third positional parameter
        # (k, default 50). k=2.0 gives a very soft binarization; confirm
        # this is intended.
        super().__init__(alpha, beta, gamma)
        self.adapt_step = adapt_step
        self.beta_history = []
        self._auto_step = 0  # step counter used when forward() gets no step

    def forward(self, preds, targets, step=None):
        """Same contract as DBLoss.forward plus an optional global ``step``.

        When ``step`` is omitted, an internal counter is used instead. This
        lets the loss be used wherever ``criterion(preds, targets)`` is
        called, which previously raised a TypeError.
        """
        if step is None:
            step = self._auto_step
            self._auto_step += 1

        # Periodically re-estimate beta from recent binarization losses.
        if step % self.adapt_step == 0 and len(self.beta_history) > 10:
            db_median = np.median(self.beta_history[-10:])
            self.beta = max(1.0, min(db_median * 0.8, 10.0))

        total_loss, bin_loss, thresh_loss, db_loss = super().forward(preds, targets)

        # Record the binarization loss observed under the current beta.
        self.beta_history.append(db_loss.item())
        return total_loss, bin_loss, thresh_loss, db_loss
# 3. Model architecture enhancement - replaces the original DBHead
class EnhancedDBHead(DBHead):
    """DBHead with a wider residual block, a depthwise-separable branch and a
    channel gate applied before the base DBHead forward pass.

    Args:
        in_channels: Number of input feature channels. The GroupNorm(8, ...)
            layers require ``in_channels`` to be divisible by 8.
    """

    def __init__(self, in_channels):
        super().__init__(in_channels)
        # Wider residual block; overrides the res_block built by DBHead.
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, 3, padding=1),
            nn.GroupNorm(8, in_channels * 2),
            nn.GELU(),
            nn.Conv2d(in_channels * 2, in_channels, 3, padding=1),
            nn.GroupNorm(8, in_channels)
        )
        # Depthwise-separable convolution. The pointwise conv maps back to
        # in_channels (it used to produce in_channels * 4). Its output feeds
        # gate_attn (which expects in_channels), is added to x (in_channels)
        # and then goes to DBHead.forward. The 4x width made forward() fail
        # with a channel mismatch.
        self.depthwise = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.GELU()
        )
        # Channel gate: global pooling -> bottleneck -> sigmoid weights.
        self.gate_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 4, 1),
            nn.GELU(),
            nn.Conv2d(in_channels // 4, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = x
        x = self.res_block(x) + residual
        # Depthwise feature extraction.
        depth_feat = self.depthwise(x)
        # Gated fusion.
        gate = self.gate_attn(depth_feat)
        x = x * gate + depth_feat
        # NOTE(review): DBHead.forward begins with `self.res_block(x) + x`
        # again, so the (overridden) residual block runs twice per call.
        # Confirm this is intended.
        return super().forward(x)


# ----------------------------
# 4. Training loop (enhanced, with progress bar)
# ----------------------------
def _load_train_checkpoint(checkpoint_path, device):
    """Load a resumable training checkpoint dict onto ``device``.

    The checkpoint pickles the training logger object, so it cannot be read
    with ``weights_only=True`` (the torch.load default since PyTorch 2.6).
    SECURITY: unpickling can execute code - only load checkpoints you wrote.
    """
    try:
        return torch.load(checkpoint_path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no weights_only argument.
        return torch.load(checkpoint_path, map_location=device)


def _build_train_checkpoint(epoch, next_epoch, model, optimizer, scheduler, scaler,
                            best_loss, best_f1, logger):
    """Assemble the dict saved as a resumable training checkpoint.

    ``epoch`` keeps the value earlier versions stored under that key;
    ``next_epoch`` is the epoch index training resumes at.
    """
    return {
        'epoch': epoch,
        'next_epoch': next_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'best_f1': best_f1,
        'logger': logger,
        'scheduler_state': scheduler.state_dict(),
        'scaler_state': scaler.state_dict(),
    }


def enhanced_train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=200,
                         checkpoint_path='dbnet_checkpoint.pth', lr_init=5e-5):
    """Train ``model`` with AMP, gradient clipping, checkpointing and per-epoch validation.

    Resumes from ``checkpoint_path`` if it exists. Files written:

    * ``checkpoint_path`` every 100 batches (resuming from it restarts that
      epoch) and after every epoch;
    * ``dbnet_best.pth`` whenever validation F1 improves (ties broken by a
      lower average training loss);
    * ``dbnet_final.pth`` (plain state_dict) on exit, including after an
      interruption or an error.

    Args:
        model: Module returning ``(binary_preds, thresh_preds)``.
        train_loader: Yields ``(images, binary_targets, thresh_targets)``.
        val_loader: Passed to ``validate_model``.
        criterion: Called as ``criterion(preds, targets)``; returns
            ``(total, bin, thresh, db)`` loss tensors.
        optimizer: Optimizer over ``model.parameters()``.
        device: Target device, e.g. ``'cuda'`` or ``'cpu'``.
        epochs: Total epoch count (already-completed epochs count when resuming).
        checkpoint_path: Path of the resumable checkpoint.
        lr_init: Learning rate written into every parameter group before
            training starts (also when resuming).

    Returns:
        The trained ``model``.
    """
    import traceback

    start_epoch = 0
    best_loss = float('inf')
    best_f1 = 0.0
    logger = EnhancedTrainingLogger()

    # ReduceLROnPlateau monitors a metric. It is stepped once per epoch with
    # the mean training loss. Stepping it per batch with an increasing epoch
    # counter never counted as an "improvement" and halved the LR every
    # patience + 1 = 4 batches.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    # Mixed precision only on CUDA; on CPU autocast/GradScaler are disabled.
    use_amp = str(device).startswith('cuda') and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Checkpoint recovery.
    if os.path.exists(checkpoint_path):
        print(f"发现检查点文件 {checkpoint_path}, 尝试恢复训练...")
        checkpoint = _load_train_checkpoint(checkpoint_path, device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_loss = checkpoint['best_loss']
        best_f1 = checkpoint.get('best_f1', best_f1)
        logger = checkpoint['logger']
        if 'next_epoch' in checkpoint:
            start_epoch = checkpoint['next_epoch']
            # Only checkpoints written by this version hold a scheduler that
            # was stepped with a loss value, so only those are restored.
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if checkpoint.get('scaler_state'):  # {} when saved with AMP disabled
                scaler.load_state_dict(checkpoint['scaler_state'])
        else:
            start_epoch = checkpoint['epoch'] + 1
        print(f"成功恢复训练状态: 从第 {start_epoch} 轮开始, 最佳损失: {best_loss:.6f}")
        if not logger.total_losses:  # empty log -> start a fresh logger
            logger = EnhancedTrainingLogger()

    for group in optimizer.param_groups:
        group['lr'] = lr_init

    try:
        for epoch in range(start_epoch, epochs):
            # Ensure train mode every epoch: validation may leave the model
            # in eval mode.
            model.train()
            epoch_total_loss = 0.0
            epoch_bin_loss = 0.0
            epoch_thresh_loss = 0.0
            epoch_db_loss = 0.0
            epoch_start = time.time()
            num_batches = len(train_loader)

            pbar = tqdm(enumerate(train_loader), total=num_batches,
                        desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")
            for batch_idx, (images, binary_targets, thresh_targets) in pbar:
                images = images.to(device)
                binary_targets = binary_targets.to(device)
                thresh_targets = thresh_targets.to(device)

                with torch.cuda.amp.autocast(enabled=use_amp):
                    binary_preds, thresh_preds = model(images)
                    total_loss, bin_loss, thresh_loss, db_loss = criterion(
                        (binary_preds, thresh_preds),
                        (binary_targets, thresh_targets)
                    )

                # Accumulate epoch losses.
                epoch_total_loss += total_loss.item()
                epoch_bin_loss += bin_loss.item()
                epoch_thresh_loss += thresh_loss.item()
                epoch_db_loss += db_loss.item()

                current_lr = optimizer.param_groups[0]['lr']
                logger.on_batch_end(
                    batch_idx,
                    total_loss.item(),
                    bin_loss.item(),
                    thresh_loss.item(),
                    db_loss.item(),
                    current_lr
                )
                pbar.set_postfix({
                    'Loss': f"{total_loss.item():.4f}",
                    'Bin': f"{bin_loss.item():.4f}",
                    'Thresh': f"{thresh_loss.item():.4f}",
                    'DB': f"{db_loss.item():.4f}",
                    'LR': f"{current_lr:.2e}"
                })

                optimizer.zero_grad()
                scaler.scale(total_loss).backward()
                # Gradients are still multiplied by the loss scale here; unscale
                # first so max_norm=1.0 applies to the true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                # Emergency checkpoint every 100 batches (resume restarts this epoch).
                if batch_idx % 100 == 0:
                    torch.save(
                        _build_train_checkpoint(epoch, epoch, model, optimizer, scheduler,
                                                scaler, best_loss, best_f1, logger),
                        checkpoint_path
                    )

            # Average losses (guard against an empty loader).
            denom = max(num_batches, 1)
            avg_total_loss = epoch_total_loss / denom
            avg_bin_loss = epoch_bin_loss / denom
            avg_thresh_loss = epoch_thresh_loss / denom
            avg_db_loss = epoch_db_loss / denom

            # Validation.
            precision, recall, f1 = validate_model(model, val_loader, device)
            logger.val_metrics['precision'].append(precision)
            logger.val_metrics['recall'].append(recall)
            logger.val_metrics['f1'].append(f1)

            epoch_time = time.time() - epoch_start
            print(f"Epoch [{epoch + 1}/{epochs}] completed in {epoch_time:.2f}s")
            print(
                f" - Avg Loss: {avg_total_loss:.6f} (Bin:{avg_bin_loss:.6f}, Thresh:{avg_thresh_loss:.6f}, DB:{avg_db_loss:.6f})")
            print(f" - Val Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")
            logger.on_epoch_end(epoch, optimizer)

            # One plateau-scheduler step per epoch, after the epoch's LR was logged.
            scheduler.step(avg_total_loss)

            # Best model by validation F1 (ties: lower training loss).
            if f1 > best_f1 or (f1 == best_f1 and avg_total_loss < best_loss):
                best_f1 = f1
                best_loss = avg_total_loss
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_total_loss,
                    'f1': best_f1
                }, 'dbnet_best.pth')
                print(f"🔥 发现新的最佳模型! F1: {best_f1:.4f}, 损失: {best_loss:.6f}")

            # Regular end-of-epoch checkpoint: resume at the next epoch.
            torch.save(
                _build_train_checkpoint(epoch + 1, epoch + 1, model, optimizer, scheduler,
                                        scaler, best_loss, best_f1, logger),
                checkpoint_path
            )

    except KeyboardInterrupt:
        print("\n训练被用户中断!")
    except Exception as e:
        print(f"\n❌ 训练中断! 原因: {str(e)}")
        traceback.print_exc()
    finally:
        print("训练完成! 保存最终模型...")
        torch.save(model.state_dict(), 'dbnet_final.pth')
        logger.on_train_end()

    return model
# ----------------------------
# 5. Inference and text-region cropping (enhanced)
# ----------------------------
# Side length of the square detector input. Contours returned by
# enhanced_detect_text_regions live in this space, and crop_text_regions
# maps them back to the original image.
_DETECT_INPUT_SIZE = 640


def enhanced_detect_text_regions(image, model, device, threshold=0.3):
    """Run the detector on a BGR image and return text-region contours.

    Args:
        image: BGR image of shape (H, W, 3), as returned by ``cv2.imread``.
        model: Module returning ``(binary_map, thresh_map)``; the caller is
            expected to have put it in eval mode.
        device: Device to run inference on.
        threshold: Probability threshold applied to the binary map.

    Returns:
        ``(contours, orig_h, orig_w)``. Contours are in the
        ``_DETECT_INPUT_SIZE x _DETECT_INPUT_SIZE`` coordinate space.
    """
    orig_h, orig_w = image.shape[:2]
    size = _DETECT_INPUT_SIZE

    # Preprocess: RGB, fixed size, mean/std normalization.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (size, size))
    normalized = (rgb.astype(np.float32) / 255.0 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
    # Cast before the device transfer (the arithmetic above yields float64).
    input_tensor = torch.from_numpy(normalized.astype(np.float32)).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        binary_map, _ = model(input_tensor)

    binary_map = binary_map.squeeze().float().cpu().numpy()
    # crop_text_regions rescales from the size x size space. If the model
    # outputs another resolution, resize so the contour coordinates match.
    if binary_map.shape != (size, size):
        binary_map = cv2.resize(binary_map, (size, size), interpolation=cv2.INTER_LINEAR)
    binary_output = (binary_map > threshold).astype(np.uint8) * 255

    # Morphological closing joins fragmented strokes.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    binary_output = cv2.morphologyEx(binary_output, cv2.MORPH_CLOSE, kernel)

    # [-2] selects the contour list under both OpenCV 3 (3 return values)
    # and OpenCV 4 (2 return values).
    contours = cv2.findContours(binary_output, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    return contours, orig_h, orig_w


def perspective_transform(image, contour):
    """Warp the quadrilateral around ``contour`` into an axis-aligned rectangle.

    The contour is approximated by a polygon; if that is not a quadrilateral,
    its minimum-area rotated rectangle is used instead.

    Args:
        image: Image whose pixel coordinates the contour refers to.
        contour: OpenCV contour of shape (N, 1, 2).

    Returns:
        The rectified crop, at least 1x1 pixel.
    """
    epsilon = 0.02 * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)
    if len(approx) != 4:
        box = cv2.boxPoints(cv2.minAreaRect(contour))
        # np.int0 (an alias of np.intp) was removed in NumPy 2.0.
        approx = box.astype(np.intp)

    pts = approx.reshape(4, 2)
    rect_pts = np.zeros((4, 2), dtype="float32")
    # Order corners TL, TR, BR, BL: TL has the smallest x+y and BR the
    # largest; TR has the smallest y-x and BL the largest.
    s = pts.sum(axis=1)
    rect_pts[0] = pts[np.argmin(s)]
    rect_pts[2] = pts[np.argmax(s)]
    diff = np.diff(pts, axis=1)
    rect_pts[1] = pts[np.argmin(diff)]
    rect_pts[3] = pts[np.argmax(diff)]

    (tl, tr, br, bl) = rect_pts
    width_a = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    width_b = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    height_a = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    height_b = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    # Degenerate (e.g. collinear) corners would yield a 0-sized warp target.
    max_width = max(int(width_a), int(width_b), 1)
    max_height = max(int(height_a), int(height_b), 1)

    dst = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect_pts, dst)
    return cv2.warpPerspective(image, M, (max_width, max_height))


def crop_text_regions(image, contours, orig_h, orig_w):
    """Crop and perspective-correct detected regions from the original image.

    Contours are expected in the ``_DETECT_INPUT_SIZE`` square space. Regions
    whose contour area is below 100 (in that space) are skipped. Each kept
    bounding box is padded by 10 % per side, clipped to the image and
    rectified. If rectification raises, the padded ROI is kept as-is.
    Rectified crops of 10 px or less in either dimension are dropped.
    """
    cropped_regions = []
    # Scale factors from the detector space back to the original image (x, y).
    scale = np.array([orig_w / _DETECT_INPUT_SIZE, orig_h / _DETECT_INPUT_SIZE])

    for contour in contours:
        if cv2.contourArea(contour) < 100:
            continue

        # Round rather than truncate when mapping to original pixels.
        scaled_contour = np.round(contour * scale).astype(np.int32)

        x, y, w, h = cv2.boundingRect(scaled_contour)
        margin_x = int(w * 0.1)
        margin_y = int(h * 0.1)
        x = max(0, x - margin_x)
        y = max(0, y - margin_y)
        w = min(orig_w - x, w + 2 * margin_x)
        h = min(orig_h - y, h + 2 * margin_y)

        roi = image[y:y + h, x:x + w]
        if roi.size == 0:
            continue

        try:
            # Contour in ROI-local coordinates.
            roi_contour = scaled_contour - np.array([x, y], dtype=np.int32)
            warped_roi = perspective_transform(roi, roi_contour)
            if warped_roi.shape[0] > 10 and warped_roi.shape[1] > 10:
                cropped_regions.append(warped_roi)
        except Exception as e:
            # Best effort: fall back to the unrectified ROI.
            print(f"透视变换失败: {str(e)},使用原始ROI")
            cropped_regions.append(roi)

    return cropped_regions


# ----------------------------
# 7. Model loading and inference interface
# ----------------------------
def load_trained_model(model_path, device='cuda'):
    """Build a DBNet, load weights from ``model_path`` and return it in eval mode.

    Accepts both checkpoint dicts holding a ``'model_state_dict'`` entry (the
    format written to ``dbnet_best.pth``) and plain state_dicts (the format
    written to ``dbnet_final.pth``).
    """
    model = DBNet(pretrained=False).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model
# 5. Water-meter specific image enhancement
def water_meter_specific_aug(image, **kwargs):
    """Water-meter specific enhancement chain for an RGB uint8 image.

    Steps: Gaussian blur (kernel about 1 % of the shorter side, forced odd)
    to damp high-frequency glare; CLAHE on the LAB lightness channel; and
    min-max stretching of the a/b channels. ``**kwargs`` is accepted and
    ignored (presumably to fit an augmentation-callback signature; confirm).
    """
    # Suppress high-frequency reflections.
    kernel_size = int(min(image.shape[:2]) * 0.01)
    if kernel_size % 2 == 0:
        kernel_size += 1
    blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)

    # Adaptive histogram equalization on lightness.
    lab = cv2.cvtColor(blurred, cv2.COLOR_RGB2LAB)
    l_chan, a_chan, b_chan = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_eq = clahe.apply(l_chan)

    # Color-cast correction by stretching the chroma channels.
    a_balanced = cv2.normalize(a_chan, None, 0, 255, cv2.NORM_MINMAX)
    b_balanced = cv2.normalize(b_chan, None, 0, 255, cv2.NORM_MINMAX)
    return cv2.cvtColor(cv2.merge([l_eq, a_balanced, b_balanced]), cv2.COLOR_LAB2RGB)


def suppress_glare(image):
    """Reduce the impact of glare on an RGB image via CLAHE on LAB lightness."""
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    l_chan, a_chan, b_chan = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_clahe = clahe.apply(l_chan)
    lab_clahe = cv2.merge((l_clahe, a_chan, b_chan))
    return cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)


def detect_and_crop(image_path, model, device='cuda', output_dir='cropped_regions'):
    """Detect water-meter digit regions in one image and save the crops.

    Args:
        image_path: Path to the input image.
        model: Trained detector (see ``load_trained_model``).
        device: Inference device.
        output_dir: Directory the crops are written to as
            ``<basename>_region_<i>.jpg`` (created if missing).

    Returns:
        The list of cropped region images (BGR); ``[]`` if the image cannot be read.
    """
    os.makedirs(output_dir, exist_ok=True)

    image = cv2.imread(image_path)
    if image is None:
        print(f"错误: 无法读取图像 {image_path}")
        return []

    # cv2.imread returns BGR, but suppress_glare works in RGB
    # (COLOR_RGB2LAB). Convert around it so CLAHE runs on the correct
    # lightness channel, and keep the image BGR for the detector and imwrite.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.cvtColor(suppress_glare(rgb), cv2.COLOR_RGB2BGR)

    contours, orig_h, orig_w = enhanced_detect_text_regions(image, model, device)
    cropped_regions = crop_text_regions(image, contours, orig_h, orig_w)

    base_name = os.path.splitext(os.path.basename(image_path))[0]
    for i, region in enumerate(cropped_regions):
        output_path = os.path.join(output_dir, f'{base_name}_region_{i}.jpg')
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(output_path, region):
            print(f"警告: 无法保存区域图像 {output_path}")

    print(f"成功裁剪 {len(cropped_regions)} 个文本区域到 {output_dir}")
    return cropped_regions
# ----------------------------
# 8. Main program (optimized)
# ----------------------------
if __name__ == "__main__":
    # Smaller input size suited to water-meter images.
    INPUT_SIZE = (512, 512)

    # Configuration.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    DATA_DIR = 'images_train'
    LABEL_DIR = 'labels_train'
    VAL_DATA_DIR = 'images_val'
    VAL_LABEL_DIR = 'labels_val'
    BATCH_SIZE = 16
    EPOCHS = 100
    # Learning rate actually used: enhanced_train_model writes lr_init into
    # the optimizer before the first step.
    LR = 1e-4
    CHECKPOINT_PATH = 'dbnet_checkpoint.pth'
    TRAINED_MODEL_PATH = 'dbnet_best.pth'

    # Mode: 'train' or 'inference'.
    MODE = 'train'

    if MODE == 'train':
        # 1. Datasets.
        print("准备训练数据集...")
        train_dataset = WaterMeterDataset(
            image_dir=DATA_DIR,
            label_dir=LABEL_DIR,
            input_size=INPUT_SIZE,
            augment=True
        )
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=4, pin_memory=True)

        print("准备验证数据集...")
        val_dataset = WaterMeterDataset(
            image_dir=VAL_DATA_DIR,
            label_dir=VAL_LABEL_DIR,
            input_size=INPUT_SIZE,
            augment=False
        )
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

        # 2. Model.
        print("初始化模型...")
        model = DBNet(pretrained=True).to(DEVICE)

        # 3. Loss and optimizer. Plain DBLoss (AdaptiveDBLoss is not used here).
        criterion = DBLoss(alpha=1.0, beta=8.0)
        # AdamW built with LR directly: a different lr here would be
        # overwritten by lr_init=LR inside enhanced_train_model anyway.
        optimizer = optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=1e-4
        )

        # 4. Training.
        print("开始训练...")
        model = enhanced_train_model(
            model, train_loader, val_loader, criterion, optimizer, DEVICE,
            epochs=EPOCHS, checkpoint_path=CHECKPOINT_PATH, lr_init=LR
        )
        print(f"✅ 训练完成! 最佳模型已保存到 {TRAINED_MODEL_PATH}")

    elif MODE == 'inference':
        print(f"加载训练好的模型: {TRAINED_MODEL_PATH}")
        model = load_trained_model(TRAINED_MODEL_PATH, DEVICE)

        # Single image.
        test_image_path = 'test_images/test_1.jpg'
        print(f"处理测试图像: {test_image_path}")
        detect_and_crop(test_image_path, model, DEVICE)

        # Whole directory, in a deterministic (sorted) order.
        input_dir = 'test_images'
        output_dir = 'cropped_results'
        print(f"批量处理目录: {input_dir}")
        for img_file in sorted(os.listdir(input_dir)):
            if img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                img_path = os.path.join(input_dir, img_file)
                print(f"处理图像: {img_file}")
                detect_and_crop(img_path, model, DEVICE, output_dir)
06-07
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值