import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=INFO, 1=WARNING, 2=ERROR, 3=FATAL
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # silence the oneDNN startup notice
import sys
import glob
import time
import json
import torch
import pickle
import shutil
import argparse
import datetime
import torchvision
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from packaging import version
from functools import partial
import pytorch_lightning as pl
from omegaconf import OmegaConf, DictConfig
import torch.distributed as dist
from typing import List, Dict, Any, Optional, Union, Tuple
from ldm.util import instantiate_from_config
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from ldm.data.base import Txt2ImgIterableBaseDataset
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from torch.cuda.amp import autocast, GradScaler
# Add the vendored dependencies (CLIP, k-diffusion, etc.) to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
for path in ["download", "download/CLIP", "download/k-diffusion",
"download/stable_diffusion", "download/taming-transformers"]:
sys.path.append(os.path.join(current_dir, path))
class ConfigManager:
    """Configuration manager: centralizes config loading, merging and parsing."""
    def __init__(self, config_files: Union[str, List[str]], cli_args: Optional[List[str]] = None):
        # Accept a single path as well as a list of paths
        if isinstance(config_files, str):
            config_files = [config_files]
        # Validate that every config file exists
        self.configs = []
        for cfg in config_files:
            if not os.path.exists(cfg):
                raise FileNotFoundError(f"Config file not found: {cfg}")
            self.configs.append(OmegaConf.load(cfg))
        # Parse command-line overrides (dotlist syntax)
        self.cli = OmegaConf.from_dotlist(cli_args) if cli_args else OmegaConf.create()
        # Merge all configs; later sources (the CLI) take precedence
        self.config = OmegaConf.merge(*self.configs, self.cli)
    def get_model_config(self) -> DictConfig:
        """Return the model config."""
        if "model" not in self.config:
            raise KeyError("Config is missing the 'model' section")
        return self.config.model
    def get_data_config(self) -> DictConfig:
        """Return the data config."""
        if "data" not in self.config:
            raise KeyError("Config is missing the 'data' section")
        return self.config.data
    def get_training_config(self) -> DictConfig:
        """Return the training config, filling in defaults."""
        training_config = self.config.get("training", OmegaConf.create())
        # Default values
        defaults = {
            "max_epochs": 200,
            "gpus": torch.cuda.device_count(),
            "accumulate_grad_batches": 1,
            "learning_rate": 1e-4,
            "precision": 32
        }
        for key, value in defaults.items():
            if key not in training_config:
                training_config[key] = value
        return training_config
    def get_logging_config(self) -> DictConfig:
        """Return the logging config."""
        return self.config.get("logging", OmegaConf.create({"logdir": "logs"}))
    def get_callbacks_config(self) -> DictConfig:
        """Return the callbacks config."""
        return self.config.get("callbacks", OmegaConf.create())
    def save_config(self, save_path: str) -> None:
        """Save the merged config to a file."""
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        OmegaConf.save(self.config, save_path)
        print(f"Config saved to: {save_path}")
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(self, batch_size, num_workers, train=None, validation=None, test=None):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset_configs = dict()
if train is not None:
self.dataset_configs["train"] = train
if validation is not None:
self.dataset_configs["validation"] = validation
if test is not None:
self.dataset_configs["test"] = test
def setup(self, stage=None):
self.datasets = {
k: instantiate_from_config(cfg)
for k, cfg in self.dataset_configs.items()
}
    def _get_dataloader(self, dataset_name, shuffle=False):
        dataset = self.datasets.get(dataset_name)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_name}' is not configured")
        is_iterable = isinstance(dataset, Txt2ImgIterableBaseDataset)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle and not is_iterable,  # iterable datasets define their own order
            worker_init_fn=worker_init_fn,  # shard iterable datasets and reseed each worker
            pin_memory=True
        )
def train_dataloader(self):
return self._get_dataloader("train", shuffle=True)
def val_dataloader(self):
return self._get_dataloader("validation")
def test_dataloader(self):
return self._get_dataloader("test")
def worker_init_fn(worker_id: int) -> None:
    """Per-worker init for DataLoader workers: shard iterable datasets and reseed RNGs."""
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return
    dataset = worker_info.dataset
    worker_id = worker_info.id
    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Give each worker a disjoint shard of the iterable dataset
        split_size = dataset.num_records // worker_info.num_workers
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
    # Derive a distinct seed per worker, kept inside numpy's 32-bit range
    seed = (torch.initial_seed() + worker_id) % 2**32
    np.random.seed(seed)
    torch.manual_seed(seed)
class EnhancedImageLogger(Callback):
    """Enhanced image logger that can fan out to multiple logging backends."""
    def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True, rescale: bool = True,
                 loggers: Optional[List] = None, log_first_step: bool = False,
                 log_images_kwargs: Optional[Dict] = None):
        super().__init__()
        self.batch_frequency = max(1, batch_frequency)
        self.max_images = max_images
        self.clamp = clamp
        self.rescale = rescale
        self.loggers = loggers or []
        self.log_first_step = log_first_step
        self.log_images_kwargs = log_images_kwargs or {}
        self.log_steps = [2 ** n for n in range(6, int(np.log2(self.batch_frequency)) + 1)] if self.batch_frequency > 1 else []
    def check_frequency(self, step: int) -> bool:
        """Return True when this step should be logged."""
        if step == 0 and self.log_first_step:
            return True
        if step % self.batch_frequency == 0:
            return True
        if step in self.log_steps:
            self.log_steps.remove(step)
            return True
        return False
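    # Note: besides every multiple of batch_frequency, check_frequency also
    # fires on a power-of-two warm-up schedule starting at 2**6 = 64; e.g.
    # with batch_frequency=500 the extra early logs land at steps 64, 128 and 256.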
    def log_images(self, pl_module: pl.LightningModule, batch: Any, step: int, split: str = "train") -> None:
        """Render images via the module's log_images() and fan them out to every logger."""
        if not self.check_frequency(step) or not hasattr(pl_module, "log_images"):
            return
        is_train = pl_module.training
        if is_train:
            pl_module.eval()  # switch to eval mode for sampling
        with torch.no_grad():
            try:
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
            except Exception as e:
                print(f"Error while rendering images: {e}")
                images = {}
        # Post-process the image tensors
        for k in list(images.keys()):
            if not isinstance(images[k], torch.Tensor):
                continue
            N = min(images[k].shape[0], self.max_images)
            images[k] = images[k][:N]
            # Gather images from all ranks when running distributed
            if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
                images[k] = torch.cat(all_gather(images[k]))
            images[k] = images[k].detach().cpu()
            if self.clamp:
                images[k] = torch.clamp(images[k], -1., 1.)
            if self.rescale:
                images[k] = (images[k] + 1.0) / 2.0  # rescale to [0, 1]
        # Send to every registered logger
        for logger in self.loggers:
            if hasattr(logger, 'log_images'):
                try:
                    logger.log_images(images, step, split)
                except Exception as e:
                    print(f"Logger {type(logger).__name__} failed to log images: {e}")
        if is_train:
            pl_module.train()  # restore training mode
    def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule,
                           outputs: Any, batch: Any, batch_idx: int) -> None:
        """Log images at the end of a training batch."""
        if trainer.global_step % trainer.log_every_n_steps == 0:
            self.log_images(pl_module, batch, pl_module.global_step, "train")
    def on_validation_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule,
                                outputs: Any, batch: Any, batch_idx: int) -> None:
        """Log images at the end of a validation batch."""
        if batch_idx == 0:  # only the first validation batch
            self.log_images(pl_module, batch, pl_module.global_step, "val")
class TensorBoardLogger:
    """TensorBoard logger implementing the PyTorch Lightning logger interface."""
    def __init__(self, save_dir: str):
        from torch.utils.tensorboard import SummaryWriter
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.writer = SummaryWriter(save_dir)
        self._name = "TensorBoard"  # logger name
        self._version = "1.0"       # version string
        self._experiment = self.writer  # experiment handle
        print(f"TensorBoard logs will be written to: {save_dir}")
    @property
    def name(self) -> str:
        return self._name
    @property
    def version(self) -> str:
        return self._version
    @property
    def experiment(self) -> Any:
        return self._experiment
    def log_hyperparams(self, params: Dict) -> None:
        """Log hyperparameters to TensorBoard."""
        try:
            # Flatten one level of nesting
            flat_params = {}
            for key, value in params.items():
                if isinstance(value, dict):
                    for sub_key, sub_value in value.items():
                        flat_params[f"{key}/{sub_key}"] = sub_value
                else:
                    flat_params[key] = value
            # add_hparams only accepts scalar-ish values
            self.writer.add_hparams(
                {k: v for k, v in flat_params.items() if isinstance(v, (int, float, str))},
                {},
                run_name="."
            )
            print("Hyperparameters logged to TensorBoard")
        except Exception as e:
            print(f"Failed to log hyperparameters: {e}")
    def log_graph(self, model: torch.nn.Module, input_array: Optional[torch.Tensor] = None) -> None:
        """Log the model graph to TensorBoard."""
        # Diffusion models have complex forward signatures; skip graph logging
        print("Skipping computation-graph logging for the diffusion model")
        return
    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        """Log scalar metrics to TensorBoard."""
        for name, value in metrics.items():
            try:
                self.writer.add_scalar(name, value, global_step=step)
            except Exception as e:
                print(f"Failed to add scalar {name}: {e}")
    def log_images(self, images: Dict[str, torch.Tensor], step: int, split: str) -> None:
        """Log image grids to TensorBoard."""
        for k, img in images.items():
            if img.numel() == 0:
                continue
            try:
                grid = torchvision.utils.make_grid(img, nrow=min(8, img.shape[0]))
                self.writer.add_image(f"{split}/{k}", grid, global_step=step)
            except Exception as e:
                print(f"Failed to add image {k}: {e}")
    def save(self) -> None:
        """No-op: TensorBoard flushes automatically."""
        pass
    def finalize(self, status: str) -> None:
        """Finish logging and close the writer."""
        self.close()
    def close(self) -> None:
        """Close the underlying writer."""
        if hasattr(self, 'writer') and self.writer is not None:
            self.writer.flush()
            self.writer.close()
            self.writer = None
            print("TensorBoard logs closed")
class TQDMProgressBar(Callback):
    """tqdm-based progress bars, compatible with several PyTorch Lightning versions."""
    def __init__(self):
        self.progress_bar = None
        self.epoch_bar = None
    def on_train_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        """Create the progress bars when training starts."""
        # Step-count estimation differs across PL versions
        total_steps = self._get_total_steps(trainer)
        self.progress_bar = tqdm(
            total=total_steps,
            desc="Training Steps",
            position=0,
            leave=True,
            dynamic_ncols=True
        )
        self.epoch_bar = tqdm(
            total=trainer.max_epochs,
            desc="Epochs",
            position=1,
            leave=True,
            dynamic_ncols=True
        )
    def _get_total_steps(self, trainer: Trainer) -> int:
        """Estimate the total number of training steps across PL versions."""
        # Newer versions
        if hasattr(trainer, 'estimated_stepping_batches'):
            return trainer.estimated_stepping_batches
        # Older versions
        if hasattr(trainer, 'estimated_steps'):
            return trainer.estimated_steps
        # Fall back to computing it by hand
        try:
            if hasattr(trainer, 'num_training_batches'):
                num_batches = trainer.num_training_batches
            else:
                num_batches = len(trainer.train_dataloader)
            if hasattr(trainer, 'accumulate_grad_batches'):
                accumulate = trainer.accumulate_grad_batches
            else:
                accumulate = 1
            steps_per_epoch = num_batches // accumulate
            total_steps = trainer.max_epochs * steps_per_epoch
            print(f"Fallback total-step estimate: {total_steps} = {trainer.max_epochs} epochs × {steps_per_epoch} steps/epoch")
            return total_steps
        except Exception as e:
            print(f"Could not determine total steps ({e}); defaulting to 10000")
            return 10000
    def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule,
                           outputs: Any, batch: Any, batch_idx: int) -> None:
        """Advance the step bar after every training batch."""
        if self.progress_bar:
            # Do not let the bar run past its total
            if self.progress_bar.n < self.progress_bar.total:
                self.progress_bar.update(1)
            try:
                # Try to surface the loss from the step outputs
                loss = outputs.get('loss')
                if loss is not None:
                    if isinstance(loss, torch.Tensor):
                        loss = loss.item()
                    self.progress_bar.set_postfix({"loss": loss})
            except Exception:
                pass
    def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        """Advance the epoch bar after every epoch."""
        if self.epoch_bar:
            self.epoch_bar.update(1)
            self.epoch_bar.set_postfix({"epoch": trainer.current_epoch})
    def on_train_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        """Close both bars when training ends."""
        if self.progress_bar:
            self.progress_bar.close()
        if self.epoch_bar:
            self.epoch_bar.close()
class PerformanceMonitor(Callback):
    """Performance monitor: tracks memory usage and training throughput."""
    def __init__(self):
        self.epoch_start_time = 0
        self.batch_times = []
    def on_train_epoch_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        """Record the start time and reset CUDA memory stats at each epoch."""
        self.epoch_start_time = time.time()
        self.batch_times = []
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
    # dataloader_idx is optional to stay compatible with older PL hook signatures
    def on_train_batch_start(self, trainer: Trainer, pl_module: pl.LightningModule,
                             batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Record the start time of each batch."""
        self.batch_start_time = time.time()
    def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule,
                           outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Record the duration of each batch."""
        self.batch_times.append(time.time() - self.batch_start_time)
    def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None:
        """Compute and report per-epoch performance metrics."""
        epoch_time = time.time() - self.epoch_start_time
        if self.batch_times:
            avg_batch_time = sum(self.batch_times) / len(self.batch_times)
            batches_per_second = 1.0 / avg_batch_time
        else:
            avg_batch_time = 0
            batches_per_second = 0
        memory_info = ""
        if torch.cuda.is_available():
            max_memory = torch.cuda.max_memory_allocated() / 2 ** 20  # MiB
            memory_info = f", peak VRAM: {max_memory:.2f} MiB"
        rank_zero_info(
            f"Epoch {trainer.current_epoch} | "
            f"duration: {epoch_time:.2f}s | "
            f"batch time: {avg_batch_time:.4f}s ({batches_per_second:.2f} batches/s)"
            f"{memory_info}"
        )
def get_world_size() -> int:
    """Return the number of processes in the distributed group."""
    if dist.is_initialized():
        return dist.get_world_size()
    return 1
def all_gather(data: torch.Tensor) -> List[torch.Tensor]:
    """Gather a (possibly differently sized) tensor from every process."""
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    # Exchange the element counts first
    local_size = torch.tensor([data.numel()], device=data.device)
    size_list = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    local_numel = data.numel()
    # Gather into max-size buffers, padding the local tensor if needed
    tensor_list = [torch.empty((max_size,), dtype=data.dtype, device=data.device)
                   for _ in size_list]
    flat = data.reshape(-1)
    if local_numel < max_size:
        padding = torch.zeros(max_size - local_numel, dtype=data.dtype, device=data.device)
        flat = torch.cat((flat, padding))
    dist.all_gather(tensor_list, flat)
    # Trim each buffer back to its true size; restore the original shape when
    # the element count matches (the usual case of identically shaped batches)
    results = []
    for tensor, size in zip(tensor_list, size_list):
        out = tensor[:size]
        if size == local_numel:
            out = out.reshape(data.shape)
        results.append(out)
    return results
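# Cheap single-process sanity check for all_gather: with world_size == 1 it
# must return the input wrapped in a list, so this runs without DDP:
def _example_all_gather_single_process():
    x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    gathered = all_gather(x)
    assert len(gathered) == 1 and torch.equal(gathered[0], x)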
def create_experiment_directories(logging_config: DictConfig, experiment_name: str) -> Tuple[str, str, str]:
    """Create the experiment directory layout."""
    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    logdir = os.path.join(logging_config.logdir, f"{experiment_name}_{now}")
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    os.makedirs(ckptdir, exist_ok=True)
    os.makedirs(cfgdir, exist_ok=True)
    print(f"Experiment directory: {logdir}")
    print(f"Checkpoint directory: {ckptdir}")
    print(f"Config directory: {cfgdir}")
    return logdir, ckptdir, cfgdir
def setup_callbacks(config_manager: ConfigManager, ckptdir: str, tb_logger: TensorBoardLogger) -> List[Callback]:
    """Assemble the training callbacks."""
    callbacks = []
    # Model checkpointing
    checkpoint_callback = ModelCheckpoint(
        dirpath=ckptdir,
        filename='{epoch}-{step}-{val_loss:.2f}',
        monitor='val_loss',
        save_top_k=3,
        mode='min',
        save_last=True,
        save_on_train_epoch_end=True,  # save the full state at epoch end
        save_weights_only=False,       # keep optimizer/scheduler state in the checkpoint
        every_n_train_steps=1000       # also save every 1000 steps
    )
    callbacks.append(checkpoint_callback)
    # Learning-rate monitoring
    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks.append(lr_monitor)
    # Image logging
    image_logger_cfg = config_manager.get_callbacks_config().get("image_logger", {})
    image_logger = EnhancedImageLogger(
        batch_frequency=image_logger_cfg.get("batch_frequency", 500),
        max_images=image_logger_cfg.get("max_images", 4),
        loggers=[tb_logger]
    )
    callbacks.append(image_logger)
    # Progress bars
    progress_bar = TQDMProgressBar()
    callbacks.append(progress_bar)
    # Performance monitoring
    perf_monitor = PerformanceMonitor()
    callbacks.append(perf_monitor)
    return callbacks
def preprocess_checkpoint(checkpoint_path: str, model: pl.LightningModule) -> Dict[str, Any]:
    """Preprocess a checkpoint so it contains every key Lightning needs to resume."""
    print(f"Preprocessing checkpoint: {checkpoint_path}")
    # Load the checkpoint
    try:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
    except Exception as e:
        print(f"Failed to load checkpoint: {e}")
        raise
    # Force-reset the training state; this also guarantees that the keys
    # Lightning expects (epoch, global_step, optimizer/scheduler states) exist
    checkpoint['epoch'] = 0
    checkpoint['global_step'] = 0
    checkpoint['lr_schedulers'] = []
    checkpoint['optimizer_states'] = []
    print("Training state reset: epoch=0, global_step=0")
    # Check for the CLIP position_ids key that some exports drop
    state_dict = checkpoint.get("state_dict", {})
    if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict:
        print("Warning: checkpoint is missing 'cond_stage_model.transformer.text_model.embeddings.position_ids'")
        # Rebuild position_ids with the shape the model's text encoder expects
        if hasattr(model, "cond_stage_model") and hasattr(model.cond_stage_model, "transformer"):
            try:
                max_position_embeddings = model.cond_stage_model.transformer.text_model.config.max_position_embeddings
                position_ids = torch.arange(max_position_embeddings).expand((1, -1))
                state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids
                print("Added position_ids to the checkpoint")
            except Exception as e:
                print(f"Could not add position_ids: {e}")
    # Make sure a state_dict key is present
    if "state_dict" not in checkpoint:
        checkpoint["state_dict"] = state_dict
    return checkpoint
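# Typical call pattern for preprocess_checkpoint (paths are illustrative):
# load a weights-only checkpoint, graft in the reset training state, and
# re-save it so Trainer(resume_from_checkpoint=...) sees a complete file:
def _example_fix_checkpoint(model):
    fixed = preprocess_checkpoint("checkpoints/M3FD.ckpt", model)
    torch.save(fixed, "checkpoints/M3FD_fixed.ckpt")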
# Subclass the original model class
from ldm.models.diffusion.ddpm import LatentDiffusion
class CustomLatentDiffusion(LatentDiffusion):
    """Custom LatentDiffusion that papers over checkpoint-loading issues."""
    def on_load_checkpoint(self, checkpoint):
        """Patch missing keys automatically when a checkpoint is loaded."""
        state_dict = checkpoint["state_dict"]
        # Restore position_ids if the checkpoint lacks it
        if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict:
            print("Warning: checkpoint is missing 'cond_stage_model.transformer.text_model.embeddings.position_ids'")
            # Rebuild position_ids with the shape the text model expects
            max_position_embeddings = self.cond_stage_model.transformer.text_model.config.max_position_embeddings
            position_ids = torch.arange(max_position_embeddings).expand((1, -1))
            state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids
            print("Added position_ids to the state_dict")
        # Load non-strictly
        self.load_state_dict(state_dict, strict=False)
        print("Model weights loaded")
def filter_kwargs(cls, kwargs, log_prefix=""):
    # Whitelist of essential parameters that must always be kept
    ESSENTIAL_PARAMS = {
        'unet_config', 'first_stage_config', 'cond_stage_config',
        'scheduler_config', 'ckpt_path', 'linear_start', 'linear_end'
    }
    # Additionally keep any parameter whose name contains "config"
    filtered_kwargs = {}
    for k, v in kwargs.items():
        if k in ESSENTIAL_PARAMS or 'config' in k:
            filtered_kwargs[k] = v
        else:
            print(f"{log_prefix}Filtered out parameter: {k}")
    print(f"{log_prefix}Kept parameters: {list(filtered_kwargs.keys())}")
    return filtered_kwargs
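# What filter_kwargs keeps and drops (keys invented for illustration):
# anything in ESSENTIAL_PARAMS or containing "config" survives the filter.
def _example_filter_kwargs():
    kwargs = {"unet_config": {}, "cond_stage_config": {}, "monitor": "val_loss"}
    kept = filter_kwargs(CustomLatentDiffusion, kwargs, log_prefix="[demo] ")
    assert "monitor" not in kept  # filtered: not essential and no "config" in the name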
def check_checkpoint_content(checkpoint_path):
    """Print the keys a checkpoint contains to confirm whether training state is present."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print("Checkpoint keys:", list(checkpoint.keys()))
    if "state_dict" in checkpoint:
        print("Model weights present")
    if "optimizer_states" in checkpoint:
        print("Optimizer state present")
    if "epoch" in checkpoint:
        print(f"Saved epoch: {checkpoint['epoch']}")
    if "global_step" in checkpoint:
        print(f"Saved global_step: {checkpoint['global_step']}")
def main() -> None:
    """Entry point for the training / inference pipeline."""
    # Enable Tensor Core acceleration
    torch.set_float32_matmul_precision('high')
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Diffusion model training framework")
    parser.add_argument("--config", type=str, default="configs/train.yaml", help="Path to the config file")
    parser.add_argument("--name", type=str, default="experiment", help="Experiment name")
    parser.add_argument("--resume", action="store_true", default=True, help="Resume training")
    parser.add_argument("--debug", action="store_true", help="Debug mode (fast_dev_run)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--scale_lr", action="store_true", help="Scale the learning rate by the number of GPUs")
    parser.add_argument("--precision", type=str, default="32", choices=["16", "32", "bf16"], help="Training precision")
    args, unknown = parser.parse_known_args()
    # Seed everything
    seed_everything(args.seed, workers=True)
    print(f"Random seed set to: {args.seed}")
    # Initialize the config manager
    try:
        config_manager = ConfigManager(args.config, unknown)
        config = config_manager.config
    except Exception as e:
        print(f"Failed to load config: {e}")
        sys.exit(1)
    # Create the log directories
    logging_config = config_manager.get_logging_config()
    logdir, ckptdir, cfgdir = create_experiment_directories(logging_config, args.name)
    # Save the merged config
    config_manager.save_config(os.path.join(cfgdir, "config.yaml"))
    # Set up the logger
    tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard"))
    # Set up the callbacks
    callbacks = setup_callbacks(config_manager, ckptdir, tb_logger)
    # Initialize the data module
    try:
        print("Initializing the data module...")
        data_config = config_manager.get_data_config()
        data_module = instantiate_from_config(data_config)
        data_module.setup()
        print("Available datasets:", list(data_module.datasets.keys()))
    except Exception as e:
        print(f"Data module initialization failed: {str(e)}")
        return
    # Create the model
    try:
        model_config = config_manager.get_model_config()
        model_params = dict(model_config.get("params", {}))
        # Pop ckpt_path before instantiation: LatentDiffusion.__init__ would
        # otherwise call init_from_ckpt and hit a size mismatch on the
        # 3-channel conv_in weights before we get a chance to convert them
        ckpt_path = model_params.pop("ckpt_path", "")
        model = CustomLatentDiffusion(**model_params)
        print("Model initialized successfully")
    except Exception as e:
        print(f"Model initialization failed: {str(e)}")
        return
    print("VAE input-layer shape:", model.first_stage_model.encoder.conv_in.weight.shape)
    # Load the pretrained weights, converting RGB (3-channel) convolutions to
    # single-channel (IR) ones before calling load_state_dict
    if ckpt_path and os.path.exists(ckpt_path):
        print(f"Loading pretrained weights: {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        # Find every 4-D conv weight that still expects 3 input channels
        # (conv_in, conv_out and nin_shortcut variants included)
        conversion_keys = []
        for key in state_dict.keys():
            if "conv_in" in key or "conv_out" in key or "nin_shortcut" in key:
                if state_dict[key].ndim == 4 and state_dict[key].shape[1] == 3:
                    conversion_keys.append(key)
        print(f"Layers to convert: {conversion_keys}")
        for key in conversion_keys:
            rgb_weights = state_dict[key]  # [out_c, 3, kH, kW]
            # Average the RGB channels into a single-channel kernel [out_c, 1, kH, kW]
            state_dict[key] = rgb_weights.mean(dim=1, keepdim=True)
            print(f"Converted {key}: {tuple(rgb_weights.shape)} -> {tuple(state_dict[key].shape)}")
        try:
            # Load non-strictly so unrelated key mismatches do not abort
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            print(f"Weights loaded: {len(missing)} missing, {len(unexpected)} unexpected")
            if missing:
                print("Missing keys:", missing[:5])  # show at most 5 to limit output
            if unexpected:
                print("Unexpected keys:", unexpected[:5])
            # Special-case the conv_in layer
            if "first_stage_model.encoder.conv_in.weight" in missing:
                print("Warning: conv_in.weight was not loaded; initializing it manually")
                with torch.no_grad():
                    model.first_stage_model.encoder.conv_in.weight.data.normal_(mean=0.0, std=0.02)
                print("conv_in.weight initialized manually")
        except RuntimeError as e:
            print(f"Error while loading weights: {e}")
            print("Falling back to loading only shape-compatible weights...")
            # Build a state dict containing only keys with matching shapes
            model_state = model.state_dict()
            compatible_dict = {k: v for k, v in state_dict.items()
                               if k in model_state and v.shape == model_state[k].shape}
            model.load_state_dict(compatible_dict, strict=False)
            print(f"Partially loaded weights: {len(compatible_dict)}/{len(state_dict)}")
    # Configure the learning rate
    training_config = config_manager.get_training_config()
    bs = data_config.params.batch_size
    base_lr = model_config.base_learning_rate
    ngpu = training_config.get("gpus", 1)
    accumulate_grad_batches = training_config.get("accumulate_grad_batches", 1)
    if args.scale_lr:
        model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
        print(f"Learning rate scaled to: {model.learning_rate:.2e} = {accumulate_grad_batches} × {ngpu} × {bs} × {base_lr:.2e}")
    else:
        model.learning_rate = base_lr
        print(f"Using base learning rate: {model.learning_rate:.2e}")
    # Decide whether to resume training
    resume_from_checkpoint = None
    if args.resume:
        # Prefer the automatically saved last.ckpt
        last_ckpt = os.path.join(ckptdir, "last.ckpt")
        if os.path.exists(last_ckpt):
            print(f"Resuming training state from: {last_ckpt}")
            resume_from_checkpoint = last_ckpt
        else:
            # Fall back to a fixed checkpoint
            fallback_ckpt = os.path.join(current_dir, "checkpoints", "M3FD.ckpt")
            if os.path.exists(fallback_ckpt):
                print(f"Warning: using a weights-only checkpoint; training state will be reset: {fallback_ckpt}")
                resume_from_checkpoint = fallback_ckpt
            else:
                print("No usable checkpoint found; training from scratch")
    # Preprocess the checkpoint before resuming from it
    if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):
        try:
            # Graft the missing training state onto the checkpoint
            checkpoint = preprocess_checkpoint(resume_from_checkpoint, model)
            # Write out a complete checkpoint file
            fixed_ckpt_path = os.path.join(ckptdir, "fixed_checkpoint.ckpt")
            torch.save(checkpoint, fixed_ckpt_path)
            print(f"Repaired checkpoint saved to: {fixed_ckpt_path}")
            # Resume from the repaired checkpoint
            resume_from_checkpoint = fixed_ckpt_path
        except Exception as e:
            print(f"Checkpoint preprocessing failed: {e}")
            print("Falling back to loading the checkpoint as-is")
    # Check whether a validation set is configured
    has_validation = hasattr(data_module, 'datasets') and 'validation' in data_module.datasets
    # Count the training batches
    try:
        train_loader = data_module.train_dataloader()
        num_train_batches = len(train_loader)
        print(f"Number of training batches: {num_train_batches}")
    except Exception as e:
        print(f"Failed to count training batches: {e}")
        num_train_batches = 0
    # Base trainer configuration (PL 1.x style: gpus / distributed_backend / plugins)
    precision_map = {"16": 16, "32": 32, "bf16": "bf16"}
    trainer_config = {
        "default_root_dir": logdir,
        "max_epochs": training_config.max_epochs,
        "gpus": ngpu,
        "distributed_backend": "ddp" if ngpu > 1 else None,
        "plugins": [DDPPlugin(find_unused_parameters=False)] if ngpu > 1 else None,
        "precision": precision_map[args.precision],  # honor the --precision flag
        "accumulate_grad_batches": accumulate_grad_batches,
        "callbacks": callbacks,
        "logger": tb_logger,
        "resume_from_checkpoint": resume_from_checkpoint,
        "fast_dev_run": args.debug,
        "limit_val_batches": 0 if not has_validation else 1.0,
        "num_sanity_val_steps": 0,  # skip the sanity check to speed up resuming
        "log_every_n_steps": 10     # log frequently
    }
    # Pick a validation schedule that fits the dataset size
    if has_validation:
        if num_train_batches < 50:
            # Small dataset: validate once per epoch
            trainer_config["check_val_every_n_epoch"] = 1
            # Make sure the step-based setting is absent
            trainer_config.pop("val_check_interval", None)
        else:
            # Large dataset: validate every N steps
            val_check_interval = min(2000, num_train_batches)
            if num_train_batches < 100:
                val_check_interval = max(1, num_train_batches // 4)
            trainer_config["val_check_interval"] = val_check_interval
    # Create the trainer
    try:
        print("Final trainer configuration:")
        for k, v in trainer_config.items():
            print(f"  {k}: {v}")
        trainer = Trainer(**trainer_config)
    except Exception as e:
        print(f"Failed to create trainer: {e}")
        tb_logger.close()
        sys.exit(1)
    # Run training
    try:
        print("Starting training...")
        trainer.fit(model, data_module)
        print("Training finished!")
    except KeyboardInterrupt:
        print("Training interrupted by user")
        if trainer.global_rank == 0 and trainer.model is not None:
            trainer.save_checkpoint(os.path.join(ckptdir, "interrupted.ckpt"))
    except Exception as e:
        print(f"Training failed: {e}")
        if trainer.global_rank == 0 and hasattr(trainer, 'model') and trainer.model is not None:
            trainer.save_checkpoint(os.path.join(ckptdir, "error.ckpt"))
        raise
    finally:
        # Close the logger
        tb_logger.close()
        # Print the profiler summary
        if trainer.global_rank == 0 and hasattr(trainer, 'profiler'):
            print("Training summary:")
            print(trainer.profiler.summary())
if __name__ == "__main__":
    main()

Running the original script fails with:

Model initialization failed: Error(s) in loading state_dict for CustomLatentDiffusion:
size mismatch for first_stage_model.encoder.conv_in.weight: copying a param with shape torch.Size([128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
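Why this happens: load_state_dict(strict=False) only tolerates missing or unexpected keys; it still raises on *shape* mismatches. And because ckpt_path sat inside model.params, the stock CompVis ldm LatentDiffusion.__init__ already calls init_from_ckpt and tries to copy the 3-channel conv_in weights into the 1-channel model before the RGB-to-IR conversion code further down ever runs. A minimal sketch of the workaround (this is what the revised model-creation block above does; the key-matching rule is an assumption that mirrors this script's conversion logic):

    params = dict(model_config.get("params", {}))
    ckpt_path = params.pop("ckpt_path", "")   # keep __init__ from loading the checkpoint
    model = CustomLatentDiffusion(**params)
    if ckpt_path and os.path.exists(ckpt_path):
        sd = torch.load(ckpt_path, map_location="cpu")
        sd = sd.get("state_dict", sd)
        for k, w in list(sd.items()):
            # assumption: every 4-D kernel with 3 input channels should be averaged down to 1
            if w.ndim == 4 and w.shape[1] == 3 and ("conv_in" in k or "conv_out" in k or "nin_shortcut" in k):
                sd[k] = w.mean(dim=1, keepdim=True)  # RGB kernel -> single channel
        model.load_state_dict(sd, strict=False)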