cls._defaults 和 self.__dict__

cls._defaultsself.__dict__

cls._defaultsself.__dict__ 在 Python 编程中都是用于存储对象状态的方式,但它们在使用场景和目的上有一些不同。让我们逐一分析它们:

  1. cls._defaults

cls._defaults 通常是类变量,用于存储默认属性值。这里的 cls 通常指的是类本身,而不是类的实例。_defaults 是一种约定俗成的命名方式,用于表示这些属性是默认值。这些默认值可以在创建类的实例时用作属性值的默认值。

例如,假设有一个类 Person,它有一些属性如 nameagecity。你可以为这些属性设置默认值,并在创建 Person 的实例时使用这些默认值。

class Person:
    _defaults = {
   
   'name': 'Unknown'
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=INFO, 1=WARNING, 2=ERROR, 3=FATAL os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # 禁用oneDNN日志 import sys import glob import time import json import torch import pickle import shutil import argparse import datetime import torchvision import numpy as np from tqdm import tqdm from PIL import Image import torch.nn as nn from packaging import version from functools import partial import pytorch_lightning as pl from omegaconf import OmegaConf, DictConfig import torch.distributed as dist from typing import List, Dict, Any, Optional, Union, Tuple from ldm.util import instantiate_from_config from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from torch.utils.data import DataLoader, Dataset from ldm.data.base import Txt2ImgIterableBaseDataset from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from torch.cuda.amp import autocast, GradScaler # 模型路径 current_dir = os.path.dirname(os.path.abspath(__file__)) for path in ["download", "download/CLIP", "download/k-diffusion", "download/stable_diffusion", "download/taming-transformers"]: sys.path.append(os.path.join(current_dir, path)) class ConfigManager: """配置管理类,统一处理配置加载解析""" def __init__(self, config_files: Union[str, List[str]], cli_args: Optional[List[str]] = None): # 将单个字符串路径转换为列表 if isinstance(config_files, str): config_files = [config_files] # 验证配置文件存在 self.configs = [] for cfg in config_files: if not os.path.exists(cfg): raise FileNotFoundError(f"配置文件不存在: {cfg}") self.configs.append(OmegaConf.load(cfg)) # 解析命令行参数 self.cli = OmegaConf.from_dotlist(cli_args) if cli_args else OmegaConf.create() # 合并所有配置 self.config = OmegaConf.merge(*self.configs, self.cli) def get_model_config(self) -> DictConfig: """获取模型配置""" if "model" not in self.config: raise KeyError("配置文件中缺少'model'部分") return self.config.model def get_data_config(self) -> DictConfig: """获取数据配置""" if "data" not in self.config: raise KeyError("配置文件中缺少'data'部分") return self.config.data def get_training_config(self) -> DictConfig: """获取训练配置,提供默认值""" training_config = self.config.get("training", OmegaConf.create()) # 设置默认值 defaults = { "max_epochs": 200, "gpus": torch.cuda.device_count(), "accumulate_grad_batches": 1, "learning_rate": 1e-4, "precision": 32 } for key, value in defaults.items(): if key not in training_config: training_config[key] = value return training_config def get_logging_config(self) -> DictConfig: """获取日志配置""" return self.config.get("logging", OmegaConf.create({"logdir": "logs"})) def get_callbacks_config(self) -> DictConfig: """获取回调函数配置""" return self.config.get("callbacks", OmegaConf.create()) def save_config(self, save_path: str) -> None: """保存配置到文件""" os.makedirs(os.path.dirname(save_path), exist_ok=True) OmegaConf.save(self.config, save_path) print(f"配置已保存到: {save_path}") class DataModuleFromConfig(pl.LightningDataModule): def __init__(self, batch_size, num_workers, train=None, validation=None, test=None): super().__init__() self.batch_size = batch_size self.num_workers = num_workers self.dataset_configs = dict() if train is not None: self.dataset_configs["train"] = train if validation is not None: self.dataset_configs["validation"] = validation if test is not None: self.dataset_configs["test"] = test def setup(self, stage=None): self.datasets = { k: instantiate_from_config(cfg) for k, cfg in self.dataset_configs.items() } def _get_dataloader(self, dataset_name, shuffle=False): dataset = self.datasets.get(dataset_name) if dataset is None: raise ValueError(f"数据集 {dataset_name} 未配置") return DataLoader( dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, pin_memory=True ) def train_dataloader(self): return self._get_dataloader("train", shuffle=True) def val_dataloader(self): return self._get_dataloader("validation") def test_dataloader(self): return self._get_dataloader("test") def worker_init_fn(worker_id: int) -> None: """数据加载器工作进程初始化函数""" worker_info = torch.utils.data.get_worker_info() if worker_info is None: return dataset = worker_info.dataset worker_id = worker_info.id if isinstance(dataset, Txt2ImgIterableBaseDataset): # 对可迭代数据集进行分片 split_size = dataset.num_records // worker_info.num_workers dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] # 设置随机种子 seed = torch.initial_seed() % 2**32 + worker_id np.random.seed(seed) torch.manual_seed(seed) class EnhancedImageLogger(Callback): """增强的图像日志记录器,支持多平台日志输出""" def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True, rescale: bool = True, loggers: Optional[List] = None, log_first_step: bool = False, log_images_kwargs: Optional[Dict] = None): super().__init__() self.batch_frequency = max(1, batch_frequency) self.max_images = max_images self.clamp = clamp self.rescale = rescale self.loggers = loggers or [] self.log_first_step = log_first_step self.log_images_kwargs = log_images_kwargs or {} self.log_steps = [2 ** n for n in range(6, int(np.log2(self.batch_frequency)) + 1)] if self.batch_frequency > 1 else [] def check_frequency(self, step: int) -> bool: """检查是否达到记录频率""" if step == 0 and self.log_first_step: return True if step % self.batch_frequency == 0: return True if step in self.log_steps: if len(self.log_steps) > 0: self.log_steps.pop(0) return True return False def log_images(self, pl_module: pl.LightningModule, batch: Any, step: int, split: str = "train") -> None: """记录图像并发送到所有日志记录器""" if not self.check_frequency(step) or not hasattr(pl_module, "log_images"): return is_train = pl_module.training if is_train: pl_module.eval() # 切换到评估模式 with torch.no_grad(): try: images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) except Exception as e: print(f"记录图像时出错: {e}") images = {} # 处理图像数据 for k in list(images.keys()): if not isinstance(images[k], torch.Tensor): continue N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] # 分布式环境下收集所有图像 if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: images[k] = torch.cat(all_gather(images[k])) images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1., 1.) if self.rescale: images[k] = (images[k] + 1.0) / 2.0 # 缩放到[0,1] # 发送到所有日志记录器 for logger in self.loggers: if hasattr(logger, 'log_images'): try: logger.log_images(images, step, split) except Exception as e: print(f"日志记录器 {type(logger).__name__} 记录图像失败: {e}") if is_train: pl_module.train() # 恢复训练模式 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """训练批次结束时记录图像""" if trainer.global_step % trainer.log_every_n_steps == 0: self.log_images(pl_module, batch, pl_module.global_step, "train") def on_validation_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """验证批次结束时记录图像""" if batch_idx == 0: # 只记录第一个验证批次 self.log_images(pl_module, batch, pl_module.global_step, "val") class TensorBoardLogger: """TensorBoard日志记录器,完整实现PyTorch Lightning日志记录器接口""" def __init__(self, save_dir: str): from torch.utils.tensorboard import SummaryWriter os.makedirs(save_dir, exist_ok=True) self.save_dir = save_dir self.writer = SummaryWriter(save_dir) self._name = "TensorBoard" # 日志记录器名称 self._version = "1.0" # 版本信息 self._experiment = self.writer # 实验对象 print(f"TensorBoard日志保存在: {save_dir}") @property def name(self) -> str: return self._name @property def version(self) -> str: return self._version @property def experiment(self) -> Any: return self._experiment def log_hyperparams(self, params: Dict) -> None: """记录超参数到TensorBoard""" try: # 将嵌套字典展平 flat_params = {} for key, value in params.items(): if isinstance(value, dict): for sub_key, sub_value in value.items(): flat_params[f"{key}/{sub_key}"] = sub_value else: flat_params[key] = value # 记录超参数 self.writer.add_hparams( {k: v for k, v in flat_params.items() if isinstance(v, (int, float, str))}, {}, run_name="." ) print("已记录超参数到TensorBoard") except Exception as e: print(f"记录超参数失败: {e}") def log_graph(self, model: torch.nn.Module, input_array: Optional[torch.Tensor] = None) -> None: """记录模型计算图到TensorBoard""" try: # 扩散模型通常有复杂的前向传播,跳过图记录 print("跳过扩散模型的计算图记录") return except Exception as e: print(f"记录模型计算图失败: {e}") def log_metrics(self, metrics: Dict[str, float], step: int) -> None: """记录指标到TensorBoard""" for name, value in metrics.items(): try: self.writer.add_scalar(name, value, global_step=step) except Exception as e: print(f"添加标量失败: {name}, 错误: {e}") def log_images(self, images: Dict[str, torch.Tensor], step: int, split: str) -> None: """记录图像到TensorBoard""" for k, img in images.items(): if img.numel() == 0: continue try: grid = torchvision.utils.make_grid(img, nrow=min(8, img.shape[0])) self.writer.add_image(f"{split}/{k}", grid, global_step=step) except Exception as e: print(f"添加图像失败: {k}, 错误: {e}") def save(self) -> None: """保存日志(TensorBoard自动保存,这里无需额外操作)""" pass def finalize(self, status: str) -> None: """完成日志记录并关闭写入器""" self.close() def close(self) -> None: """关闭日志写入器""" if hasattr(self, 'writer') and self.writer is not None: self.writer.flush() self.writer.close() self.writer = None print(f"TensorBoard日志已关闭") class TQDMProgressBar(Callback): """使用tqdm显示训练进度,兼容不同版本的PyTorch Lightning""" def __init__(self): self.progress_bar = None self.epoch_bar = None def on_train_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """训练开始时初始化进度条""" # 兼容不同版本的步数估计 total_steps = self._get_total_steps(trainer) self.progress_bar = tqdm( total=total_steps, desc="Training Steps", position=0, leave=True, dynamic_ncols=True ) self.epoch_bar = tqdm( total=trainer.max_epochs, desc="Epochs", position=1, leave=True, dynamic_ncols=True ) def _get_total_steps(self, trainer: Trainer) -> int: """获取训练总步数,兼容不同版本的PyTorch Lightning""" # 尝试使用新版本属性 if hasattr(trainer, 'estimated_stepping_batches'): return trainer.estimated_stepping_batches # 尝试使用旧版本属性 if hasattr(trainer, 'estimated_steps'): return trainer.estimated_steps # 回退到手动计算 try: if hasattr(trainer, 'num_training_batches'): num_batches = trainer.num_training_batches else: num_batches = len(trainer.train_dataloader) if hasattr(trainer, 'accumulate_grad_batches'): accumulate = trainer.accumulate_grad_batches else: accumulate = 1 steps_per_epoch = num_batches // accumulate total_steps = trainer.max_epochs * steps_per_epoch print(f"回退计算训练总步数: {total_steps} = {trainer.max_epochs} epochs × {steps_per_epoch} steps/epoch") return total_steps except Exception as e: print(f"无法确定训练总步数: {e}, 使用默认值10000") return 10000 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """每个训练批次结束时更新进度条""" if self.progress_bar: # 防止进度条超过总步数 if self.progress_bar.n < self.progress_bar.total: self.progress_bar.update(1) try: # 尝试从输出中获取损失 loss = outputs.get('loss') if loss is not None: if isinstance(loss, torch.Tensor): loss = loss.item() self.progress_bar.set_postfix({"loss": loss}) except Exception: pass def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次结束时更新轮次进度条""" if self.epoch_bar: self.epoch_bar.update(1) self.epoch_bar.set_postfix({"epoch": trainer.current_epoch}) def on_train_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """训练结束时关闭进度条""" if self.progress_bar: self.progress_bar.close() if self.epoch_bar: self.epoch_bar.close() class PerformanceMonitor(Callback): """性能监控回调,记录内存使用训练速度""" def __init__(self): self.epoch_start_time = 0 self.batch_times = [] def on_train_epoch_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次开始时记录时间重置内存统计""" self.epoch_start_time = time.time() self.batch_times = [] if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() # 修改1:添加dataloader_idx参数 def on_train_batch_start(self, trainer: Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """每个训练批次开始时记录时间""" self.batch_start_time = time.time() # 修改2:添加dataloader_idx参数 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """每个训练批次结束时记录时间""" self.batch_times.append(time.time() - self.batch_start_time) def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次结束时计算并记录性能指标""" epoch_time = time.time() - self.epoch_start_time if self.batch_times: avg_batch_time = sum(self.batch_times) / len(self.batch_times) batches_per_second = 1.0 / avg_batch_time else: avg_batch_time = 0 batches_per_second = 0 memory_info = "" if torch.cuda.is_available(): max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 # MiB memory_info = f", 峰值显存: {max_memory:.2f} MiB" rank_zero_info( f"Epoch {trainer.current_epoch} | " f"耗时: {epoch_time:.2f}s | " f"Batch耗时: {avg_batch_time:.4f}s ({batches_per_second:.2f} batches/s)" f"{memory_info}" ) def get_world_size() -> int: """获取分布式训练中的总进程数""" if dist.is_initialized(): return dist.get_world_size() return 1 def all_gather(data: torch.Tensor) -> List[torch.Tensor]: """在分布式环境中收集所有进程的数据""" world_size = get_world_size() if world_size == 1: return [data] # 获取各进程的Tensor大小 local_size = torch.tensor([data.numel()], device=data.device) size_list = [torch.zeros_like(local_size) for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # 收集数据 tensor_list = [] for size in size_list: tensor_list.append(torch.empty((max_size,), dtype=data.dtype, device=data.device)) if local_size < max_size: padding = torch.zeros(max_size - local_size, dtype=data.dtype, device=data.device) data = torch.cat((data.view(-1), padding)) dist.all_gather(tensor_list, data.view(-1)) # 截断到实际大小 results = [] for tensor, size in zip(tensor_list, size_list): results.append(tensor[:size].reshape(data.shape)) return results def create_experiment_directories(logging_config: DictConfig, experiment_name: str) -> Tuple[str, str, str]: """创建实验目录结构""" now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") logdir = os.path.join(logging_config.logdir, f"{experiment_name}_{now}") ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") os.makedirs(ckptdir, exist_ok=True) os.makedirs(cfgdir, exist_ok=True) print(f"实验目录: {logdir}") print(f"检查点目录: {ckptdir}") print(f"配置目录: {cfgdir}") return logdir, ckptdir, cfgdir def setup_callbacks(config_manager: ConfigManager, ckptdir: str, tb_logger: TensorBoardLogger) -> List[Callback]: """设置训练回调函数""" callbacks = [] # 模型检查点 checkpoint_callback = ModelCheckpoint( dirpath=ckptdir, filename='{epoch}-{step}-{val_loss:.2f}', monitor='val_loss', save_top_k=3, mode='min', save_last=True, save_on_train_epoch_end=True, # 确保在epoch结束时保存完整状态 save_weights_only=False, # 明确设置为False,保存完整检查点 every_n_train_steps=1000 # 每1000步保存一次 ) callbacks.append(checkpoint_callback) # 学习率监控 lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) # 图像日志记录 image_logger_cfg = config_manager.get_callbacks_config().get("image_logger", {}) image_logger = EnhancedImageLogger( batch_frequency=image_logger_cfg.get("batch_frequency", 500), max_images=image_logger_cfg.get("max_images", 4), loggers=[tb_logger] ) callbacks.append(image_logger) # 进度条 progress_bar = TQDMProgressBar() callbacks.append(progress_bar) # 性能监控 perf_monitor = PerformanceMonitor() callbacks.append(perf_monitor) return callbacks def preprocess_checkpoint(checkpoint_path: str, model: pl.LightningModule) -> Dict[str, Any]: """预处理检查点文件,确保包含所有必要的键,并添加缺失的训练状态""" print(f"预处理检查点文件: {checkpoint_path}") # 加载检查点 try: checkpoint = torch.load(checkpoint_path, map_location="cpu") except Exception as e: print(f"加载检查点失败: {e}") raise # 强制重置训练状态 checkpoint['epoch'] = 0 checkpoint['global_step'] = 0 checkpoint['lr_schedulers'] = [] checkpoint['optimizer_states'] = [] print("已重置训练状态: epoch=0, global_step=0") # 检查是否缺少关键训练状态 required_keys = ['optimizer_states', 'lr_schedulers', 'epoch', 'global_step'] missing_keys = [k for k in required_keys if k not in checkpoint] if missing_keys: print(f"警告: 检查点缺少训练状态字段 {missing_keys},将创建伪训练状态") # 创建伪训练状态 checkpoint.setdefault('optimizer_states', []) checkpoint.setdefault('lr_schedulers', []) checkpoint.setdefault('epoch', 0) checkpoint.setdefault('global_step', 0) # 检查是否缺少 position_ids state_dict = checkpoint.get("state_dict", {}) if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict: print("警告: 检查点缺少 'cond_stage_model.transformer.text_model.embeddings.position_ids' 键") # 获取模型中的 position_ids 形状 if hasattr(model, "cond_stage_model") and hasattr(model.cond_stage_model, "transformer"): try: max_position_embeddings = model.cond_stage_model.transformer.text_model.config.max_position_embeddings position_ids = torch.arange(max_position_embeddings).expand((1, -1)) state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids print("已添加 position_ids 到检查点") except Exception as e: print(f"无法添加 position_ids: {e}") # 确保有 state_dict if "state_dict" not in checkpoint: checkpoint["state_dict"] = state_dict return checkpoint # 正确继承原始模型类 from ldm.models.diffusion.ddpm import LatentDiffusion class CustomLatentDiffusion(LatentDiffusion): """自定义 LatentDiffusion 类,处理检查点加载问题""" def on_load_checkpoint(self, checkpoint): """在加载检查点时自动处理缺失的键""" state_dict = checkpoint["state_dict"] # 检查是否缺少 position_ids if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict: print("警告: 检查点缺少 'cond_stage_model.transformer.text_model.embeddings.position_ids' 键") # 获取模型中的 position_ids 形状 max_position_embeddings = self.cond_stage_model.transformer.text_model.config.max_position_embeddings position_ids = torch.arange(max_position_embeddings).expand((1, -1)) state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids print("已添加 position_ids 到 state_dict") # 使用非严格模式加载 self.load_state_dict(state_dict, strict=False) print("模型权重加载完成") def filter_kwargs(cls, kwargs, log_prefix=""): # 关键参数白名单 - 这些参数必须保留 ESSENTIAL_PARAMS = { 'unet_config', 'first_stage_config', 'cond_stage_config', 'scheduler_config', 'ckpt_path', 'linear_start', 'linear_end' } # 特殊处理:允许所有包含"config"的参数 filtered_kwargs = {} for k, v in kwargs.items(): if k in ESSENTIAL_PARAMS or 'config' in k: filtered_kwargs[k] = v else: print(f"{log_prefix}过滤参数: {k}") print(f"{log_prefix}保留参数: {list(filtered_kwargs.keys())}") return filtered_kwargs def check_checkpoint_content(checkpoint_path): """打印检查点包含的键,确认是否有训练状态""" checkpoint = torch.load(checkpoint_path, map_location="cpu") print("检查点包含的键:", list(checkpoint.keys())) if "state_dict" in checkpoint: print("模型权重存在") if "optimizer_states" in checkpoint: print("优化器状态存在") if "epoch" in checkpoint: print(f"保存的epoch: {checkpoint['epoch']}") if "global_step" in checkpoint: print(f"保存的global_step: {checkpoint['global_step']}") def main() -> None: """主函数,训练推理流程的入口点""" # 启用Tensor Core加速 torch.set_float32_matmul_precision('high') # 解析命令行参数 parser = argparse.ArgumentParser(description="扩散模型训练框架") parser.add_argument("--config", type=str, default="configs/train.yaml", help="配置文件路径") parser.add_argument("--name", type=str, default="experiment", help="实验名称") parser.add_argument("--resume", action="store_true", default=True, help="恢复训练") parser.add_argument("--debug", action="store_true", help="调试模式") parser.add_argument("--seed", type=int, default=42, help="随机种子") parser.add_argument("--scale_lr", action="store_true", help="根据GPU数量缩放学习率") parser.add_argument("--precision", type=str, default="32", choices=["16", "32", "bf16"], help="训练精度") args, unknown = parser.parse_known_args() # 设置随机种子 seed_everything(args.seed, workers=True) print(f"设置随机种子: {args.seed}") # 初始化配置管理器 try: config_manager = ConfigManager(args.config, unknown) config = config_manager.config except Exception as e: print(f"加载配置失败: {e}") sys.exit(1) # 创建日志目录 logging_config = config_manager.get_logging_config() logdir, ckptdir, cfgdir = create_experiment_directories(logging_config, args.name) # 保存配置 config_manager.save_config(os.path.join(cfgdir, "config.yaml")) # 配置日志记录器 tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard")) # 配置回调函数 callbacks = setup_callbacks(config_manager, ckptdir, tb_logger) # 初始化数据模块 try: print("初始化数据模块...") data_config = config_manager.get_data_config() data_module = instantiate_from_config(data_config) data_module.setup() print("可用数据集:", list(data_module.datasets.keys())) except Exception as e: print(f"数据模块初始化失败: {str(e)}") return # 创建模型 try: model_config = config_manager.get_model_config() model_params = model_config.get("params", {}) # 创建模型实例 model = CustomLatentDiffusion(**model_config.get("params", {})) print("模型初始化成功") # 检查并转换预训练权重 ckpt_path = model_config.params.get("ckpt_path", "") if ckpt_path and os.path.exists(ckpt_path): print(f"加载预训练权重: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location="cpu") state_dict = checkpoint.get("state_dict", checkpoint) # 查找所有与conv_in.weight相关的键 conv_in_keys = [] for key in state_dict.keys(): if "conv_in.weight" in key and "first_stage_model" in key: conv_in_keys.append(key) # 转换找到的权重 for conv_in_key in conv_in_keys: if state_dict[conv_in_key].shape[1] == 3: # 原始是3通道 print(f"转换权重: {conv_in_key} 从3通道到1通道") # 取RGB三通道的平均值作为单通道权重 rgb_weights = state_dict[conv_in_key] ir_weights = rgb_weights.mean(dim=1, keepdim=True) state_dict[conv_in_key] = ir_weights print(f"转换前形状: {rgb_weights.shape}") print(f"转换后形状: {ir_weights.shape}") print(f"模型层形状: {model.first_stage_model.encoder.conv_in.weight.shape}") # 非严格模式加载(允许其他层不匹配) missing, unexpected = model.load_state_dict(state_dict, strict=False) print(f"权重加载完成: 缺失层 {len(missing)}, 不匹配层 {len(unexpected)}") if missing: print("缺失层:", missing) if unexpected: print("意外层:", unexpected) except Exception as e: print(f"模型初始化失败: {str(e)}") return print("VAE输入层形状:", model.first_stage_model.encoder.conv_in.weight.shape) # 权重转换 if ckpt_path and os.path.exists(ckpt_path): print(f"加载预训练权重: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location="cpu") state_dict = checkpoint.get("state_dict", checkpoint) # 增强:查找所有需要转换的层(包括可能的变体) conversion_keys = [] for key in state_dict.keys(): if "conv_in" in key or "conv_out" in key or "nin_shortcut" in key: if state_dict[key].ndim == 4 and state_dict[key].shape[1] == 3: conversion_keys.append(key) print(f"找到需要转换的层: {conversion_keys}") # 转换权重 for key in conversion_keys: print(f"转换权重: {key}") print(f"原始形状: {state_dict[key].shape}") # RGB权重 [out_c, in_c=3, kH, kW] rgb_weights = state_dict[key] # 转换为单通道权重 [out_c, 1, kH, kW] if rgb_weights.shape[1] == 3: ir_weights = rgb_weights.mean(dim=1, keepdim=True) state_dict[key] = ir_weights print(f"转换后形状: {state_dict[key].shape}") # 加载转换后的权重 try: # 使用非严格模式加载 missing, unexpected = model.load_state_dict(state_dict, strict=False) print(f"权重加载完成: 缺失层 {len(missing)}, 不匹配层 {len(unexpected)}") # 打印重要信息 if missing: print("缺失层:", missing[:5]) # 只显示前5个避免过多输出 if unexpected: print("意外层:", unexpected[:5]) # 特别检查conv_in层 if "first_stage_model.encoder.conv_in.weight" in missing: print("警告: conv_in.weight未加载,需要手动初始化") # 手动初始化单通道卷积层 with torch.no_grad(): model.first_stage_model.encoder.conv_in.weight.data.normal_(mean=0.0, std=0.02) print("已手动初始化conv_in.weight") except RuntimeError as e: print(f"加载权重时出错: {e}") print("尝试仅加载兼容的权重...") # 创建新的状态字典只包含兼容的键 model_state = model.state_dict() compatible_dict = {} for k, v in state_dict.items(): if k in model_state and v.shape == model_state[k].shape: compatible_dict[k] = v # 加载兼容的权重 model.load_state_dict(compatible_dict, strict=False) print(f"部分权重加载完成: {len(compatible_dict)}/{len(state_dict)}") # 配置学习率 training_config = config_manager.get_training_config() bs = data_config.params.batch_size base_lr = model_config.base_learning_rate ngpu = training_config.get("gpus", 1) accumulate_grad_batches = training_config.get("accumulate_grad_batches", 1) if args.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print(f"学习率缩放至: {model.learning_rate:.2e} = {accumulate_grad_batches} × {ngpu} × {bs} × {base_lr:.2e}") else: model.learning_rate = base_lr print(f"使用基础学习率: {model.learning_rate:.2e}") # 检查是否恢复训练 resume_from_checkpoint = None if args.resume: # 优先使用自动保存的last.ckpt last_ckpt = os.path.join(ckptdir, "last.ckpt") if os.path.exists(last_ckpt): print(f"恢复训练状态: {last_ckpt}") resume_from_checkpoint = last_ckpt else: # 回退到指定检查点 fallback_ckpt = os.path.join(current_dir, "checkpoints", "M3FD.ckpt") if os.path.exists(fallback_ckpt): print(f"警告: 使用仅含权重的检查点,训练状态将重置: {fallback_ckpt}") resume_from_checkpoint = fallback_ckpt else: print("未找到可用的检查点,从头开始训练") # 如果需要恢复训练,预处理检查点 if resume_from_checkpoint and os.path.exists(resume_from_checkpoint): try: # 预处理检查点 - 添加缺失的状态 checkpoint = preprocess_checkpoint(resume_from_checkpoint, model) # 创建新的完整检查点文件 fixed_ckpt_path = os.path.join(ckptdir, "fixed_checkpoint.ckpt") torch.save(checkpoint, fixed_ckpt_path) print(f"修复后的完整检查点已保存到: {fixed_ckpt_path}") # 使用修复后的检查点 resume_from_checkpoint = fixed_ckpt_path except Exception as e: print(f"预处理检查点失败: {e}") print("将尝试使用默认方式加载检查点") # 配置日志记录器 tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard")) # 配置回调函数 callbacks = setup_callbacks(config_manager, ckptdir, tb_logger) # 检查是否有验证集 has_validation = hasattr(data_module, 'datasets') and 'validation' in data_module.datasets # 计算训练批次数 try: train_loader = data_module.train_dataloader() num_train_batches = len(train_loader) print(f"训练批次数: {num_train_batches}") except Exception as e: print(f"计算训练批次数失败: {e}") num_train_batches = 0 # 设置训练器参数(先设置基础参数) trainer_config = { "default_root_dir": logdir, "max_epochs": training_config.max_epochs, "gpus": ngpu, "distributed_backend": "ddp" if ngpu > 1 else None, "plugins": [DDPPlugin(find_unused_parameters=False)] if ngpu > 1 else None, "precision": 16, "accumulate_grad_batches": accumulate_grad_batches, "callbacks": callbacks, "logger": tb_logger, # 添加日志记录器 "resume_from_checkpoint": resume_from_checkpoint, "fast_dev_run": args.debug, "limit_val_batches": 0 if not has_validation else 1.0, "num_sanity_val_steps": 0, # 跳过初始验证加速恢复 "log_every_n_steps": 10 # 更频繁的日志记录 } # 动态调整验证配置 if has_validation: if num_train_batches < 50: # 小数据集:使用epoch验证 trainer_config["check_val_every_n_epoch"] = 1 # 确保移除步数验证参数 if "val_check_interval" in trainer_config: del trainer_config["val_check_interval"] else: # 大数据集:使用步数验证 val_check_interval = min(2000, num_train_batches) if num_train_batches < 100: val_check_interval = max(1, num_train_batches // 4) trainer_config["val_check_interval"] = val_check_interval # 创建训练器 try: print("最终训练器配置:") for k, v in trainer_config.items(): print(f" {k}: {v}") trainer = Trainer(**trainer_config) except Exception as e: print(f"创建训练器失败: {e}") tb_logger.close() sys.exit(1) # 执行训练 try: print("开始训练...") trainer.fit(model, data_module) print("训练完成!") except KeyboardInterrupt: print("训练被用户中断") if trainer.global_rank == 0 and trainer.model is not None: trainer.save_checkpoint(os.path.join(ckptdir, "interrupted.ckpt")) except Exception as e: print(f"训练出错: {e}") if trainer.global_rank == 0 and hasattr(trainer, 'model') and trainer.model is not None: trainer.save_checkpoint(os.path.join(ckptdir, "error.ckpt")) raise finally: # 关闭日志记录器 tb_logger.close() # 打印性能分析报告 if trainer.global_rank == 0 and hasattr(trainer, 'profiler'): print("训练摘要:") print(trainer.profiler.summary()) if __name__ == "__main__": main()运行报错:模型初始化失败: Error(s) in loading state_dict for CustomLatentDiffusion: size mismatch for first_stage_model.encoder.conv_in.weight: copying a param with shape torch.Size([128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
最新发布
07-04
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

万物琴弦光锥之外

给个0.1,恭喜老板发财

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值