tf.accumulate_n()

本文介绍了 TensorFlow 中的 tf.accumulate_n() 函数,详细解释了如何使用该函数对多个张量进行累加操作。通过具体示例展示了如何对实数和复数类型的张量进行累加,并给出了会话运行的具体输出结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.accumulate_n() 对值求和
传参 :

tf.accumulate_n(
    inputs,#输入数据 数据需要封装成list 数据的秩要一致
    shape=None,#输出的矩阵模式
    tensor_dtype=None,#数据类型
    name=None
)

使用案例:

import tensorflow as tf

x = tf.constant(-1)
y = tf.constant([-1 + 2j, -2])
sess = tf.Session()
print(sess.run(tf.accumulate_n([x, x, x])))#-3
print(sess.run(tf.accumulate_n([y, y, y])))#[-3.+6.j,-6.+0.j]
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=INFO, 1=WARNING, 2=ERROR, 3=FATAL os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # 禁用oneDNN日志 import sys import glob import time import json import torch import pickle import shutil import argparse import datetime import torchvision import numpy as np from tqdm import tqdm from PIL import Image import torch.nn as nn from packaging import version from functools import partial import pytorch_lightning as pl from omegaconf import OmegaConf, DictConfig import torch.distributed as dist from typing import List, Dict, Any, Optional, Union, Tuple from ldm.util import instantiate_from_config from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from torch.utils.data import DataLoader, Dataset from ldm.data.base import Txt2ImgIterableBaseDataset from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from torch.cuda.amp import autocast, GradScaler # 模型路径 current_dir = os.path.dirname(os.path.abspath(__file__)) for path in ["download", "download/CLIP", "download/k-diffusion", "download/stable_diffusion", "download/taming-transformers"]: sys.path.append(os.path.join(current_dir, path)) class ConfigManager: """配置管理类,统一处理配置加载和解析""" def __init__(self, config_files: Union[str, List[str]], cli_args: Optional[List[str]] = None): # 将单个字符串路径转换为列表 if isinstance(config_files, str): config_files = [config_files] # 验证配置文件存在 self.configs = [] for cfg in config_files: if not os.path.exists(cfg): raise FileNotFoundError(f"配置文件不存在: {cfg}") self.configs.append(OmegaConf.load(cfg)) # 解析命令行参数 self.cli = OmegaConf.from_dotlist(cli_args) if cli_args else OmegaConf.create() # 合并所有配置 self.config = OmegaConf.merge(*self.configs, self.cli) def get_model_config(self) -> DictConfig: """获取模型配置""" if "model" not in self.config: raise KeyError("配置文件中缺少'model'部分") return self.config.model def get_data_config(self) -> DictConfig: """获取数据配置""" if "data" not in self.config: raise KeyError("配置文件中缺少'data'部分") return self.config.data def get_training_config(self) -> DictConfig: """获取训练配置,提供默认值""" training_config = self.config.get("training", OmegaConf.create()) # 设置默认值 defaults = { "max_epochs": 200, "gpus": torch.cuda.device_count(), "accumulate_grad_batches": 1, "learning_rate": 1e-4, "precision": 32 } for key, value in defaults.items(): if key not in training_config: training_config[key] = value return training_config def get_logging_config(self) -> DictConfig: """获取日志配置""" return self.config.get("logging", OmegaConf.create({"logdir": "logs"})) def get_callbacks_config(self) -> DictConfig: """获取回调函数配置""" return self.config.get("callbacks", OmegaConf.create()) def save_config(self, save_path: str) -> None: """保存配置到文件""" os.makedirs(os.path.dirname(save_path), exist_ok=True) OmegaConf.save(self.config, save_path) print(f"配置已保存到: {save_path}") class DataModuleFromConfig(pl.LightningDataModule): def __init__(self, batch_size, num_workers, train=None, validation=None, test=None): super().__init__() self.batch_size = batch_size self.num_workers = num_workers self.dataset_configs = dict() if train is not None: self.dataset_configs["train"] = train if validation is not None: self.dataset_configs["validation"] = validation if test is not None: self.dataset_configs["test"] = test def setup(self, stage=None): self.datasets = { k: instantiate_from_config(cfg) for k, cfg in self.dataset_configs.items() } def _get_dataloader(self, dataset_name, shuffle=False): dataset = self.datasets.get(dataset_name) if dataset is None: raise ValueError(f"数据集 {dataset_name} 未配置") return DataLoader( dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, pin_memory=True ) def train_dataloader(self): return self._get_dataloader("train", shuffle=True) def val_dataloader(self): return self._get_dataloader("validation") def test_dataloader(self): return self._get_dataloader("test") def worker_init_fn(worker_id: int) -> None: """数据加载器工作进程初始化函数""" worker_info = torch.utils.data.get_worker_info() if worker_info is None: return dataset = worker_info.dataset worker_id = worker_info.id if isinstance(dataset, Txt2ImgIterableBaseDataset): # 对可迭代数据集进行分片 split_size = dataset.num_records // worker_info.num_workers dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] # 设置随机种子 seed = torch.initial_seed() % 2**32 + worker_id np.random.seed(seed) torch.manual_seed(seed) class EnhancedImageLogger(Callback): """增强的图像日志记录器,支持多平台日志输出""" def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True, rescale: bool = True, loggers: Optional[List] = None, log_first_step: bool = False, log_images_kwargs: Optional[Dict] = None): super().__init__() self.batch_frequency = max(1, batch_frequency) self.max_images = max_images self.clamp = clamp self.rescale = rescale self.loggers = loggers or [] self.log_first_step = log_first_step self.log_images_kwargs = log_images_kwargs or {} self.log_steps = [2 ** n for n in range(6, int(np.log2(self.batch_frequency)) + 1)] if self.batch_frequency > 1 else [] def check_frequency(self, step: int) -> bool: """检查是否达到记录频率""" if step == 0 and self.log_first_step: return True if step % self.batch_frequency == 0: return True if step in self.log_steps: if len(self.log_steps) > 0: self.log_steps.pop(0) return True return False def log_images(self, pl_module: pl.LightningModule, batch: Any, step: int, split: str = "train") -> None: """记录图像并发送到所有日志记录器""" if not self.check_frequency(step) or not hasattr(pl_module, "log_images"): return is_train = pl_module.training if is_train: pl_module.eval() # 切换到评估模式 with torch.no_grad(): try: images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) except Exception as e: print(f"记录图像时出错: {e}") images = {} # 处理图像数据 for k in list(images.keys()): if not isinstance(images[k], torch.Tensor): continue N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] # 分布式环境下收集所有图像 if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: images[k] = torch.cat(all_gather(images[k])) images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1., 1.) if self.rescale: images[k] = (images[k] + 1.0) / 2.0 # 缩放到[0,1] # 发送到所有日志记录器 for logger in self.loggers: if hasattr(logger, 'log_images'): try: logger.log_images(images, step, split) except Exception as e: print(f"日志记录器 {type(logger).__name__} 记录图像失败: {e}") if is_train: pl_module.train() # 恢复训练模式 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """训练批次结束时记录图像""" if trainer.global_step % trainer.log_every_n_steps == 0: self.log_images(pl_module, batch, pl_module.global_step, "train") def on_validation_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """验证批次结束时记录图像""" if batch_idx == 0: # 只记录第一个验证批次 self.log_images(pl_module, batch, pl_module.global_step, "val") class TensorBoardLogger: """TensorBoard日志记录器,完整实现PyTorch Lightning日志记录器接口""" def __init__(self, save_dir: str): from torch.utils.tensorboard import SummaryWriter os.makedirs(save_dir, exist_ok=True) self.save_dir = save_dir self.writer = SummaryWriter(save_dir) self._name = "TensorBoard" # 日志记录器名称 self._version = "1.0" # 版本信息 self._experiment = self.writer # 实验对象 print(f"TensorBoard日志保存在: {save_dir}") @property def name(self) -> str: return self._name @property def version(self) -> str: return self._version @property def experiment(self) -> Any: return self._experiment def log_hyperparams(self, params: Dict) -> None: """记录超参数到TensorBoard""" try: # 将嵌套字典展平 flat_params = {} for key, value in params.items(): if isinstance(value, dict): for sub_key, sub_value in value.items(): flat_params[f"{key}/{sub_key}"] = sub_value else: flat_params[key] = value # 记录超参数 self.writer.add_hparams( {k: v for k, v in flat_params.items() if isinstance(v, (int, float, str))}, {}, run_name="." ) print("已记录超参数到TensorBoard") except Exception as e: print(f"记录超参数失败: {e}") def log_graph(self, model: torch.nn.Module, input_array: Optional[torch.Tensor] = None) -> None: """记录模型计算图到TensorBoard""" try: # 扩散模型通常有复杂的前向传播,跳过图记录 print("跳过扩散模型的计算图记录") return except Exception as e: print(f"记录模型计算图失败: {e}") def log_metrics(self, metrics: Dict[str, float], step: int) -> None: """记录指标到TensorBoard""" for name, value in metrics.items(): try: self.writer.add_scalar(name, value, global_step=step) except Exception as e: print(f"添加标量失败: {name}, 错误: {e}") def log_images(self, images: Dict[str, torch.Tensor], step: int, split: str) -> None: """记录图像到TensorBoard""" for k, img in images.items(): if img.numel() == 0: continue try: grid = torchvision.utils.make_grid(img, nrow=min(8, img.shape[0])) self.writer.add_image(f"{split}/{k}", grid, global_step=step) except Exception as e: print(f"添加图像失败: {k}, 错误: {e}") def save(self) -> None: """保存日志(TensorBoard自动保存,这里无需额外操作)""" pass def finalize(self, status: str) -> None: """完成日志记录并关闭写入器""" self.close() def close(self) -> None: """关闭日志写入器""" if hasattr(self, 'writer') and self.writer is not None: self.writer.flush() self.writer.close() self.writer = None print(f"TensorBoard日志已关闭") class TQDMProgressBar(Callback): """使用tqdm显示训练进度,兼容不同版本的PyTorch Lightning""" def __init__(self): self.progress_bar = None self.epoch_bar = None def on_train_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """训练开始时初始化进度条""" # 兼容不同版本的步数估计 total_steps = self._get_total_steps(trainer) self.progress_bar = tqdm( total=total_steps, desc="Training Steps", position=0, leave=True, dynamic_ncols=True ) self.epoch_bar = tqdm( total=trainer.max_epochs, desc="Epochs", position=1, leave=True, dynamic_ncols=True ) def _get_total_steps(self, trainer: Trainer) -> int: """获取训练总步数,兼容不同版本的PyTorch Lightning""" # 尝试使用新版本属性 if hasattr(trainer, 'estimated_stepping_batches'): return trainer.estimated_stepping_batches # 尝试使用旧版本属性 if hasattr(trainer, 'estimated_steps'): return trainer.estimated_steps # 回退到手动计算 try: if hasattr(trainer, 'num_training_batches'): num_batches = trainer.num_training_batches else: num_batches = len(trainer.train_dataloader) if hasattr(trainer, 'accumulate_grad_batches'): accumulate = trainer.accumulate_grad_batches else: accumulate = 1 steps_per_epoch = num_batches // accumulate total_steps = trainer.max_epochs * steps_per_epoch print(f"回退计算训练总步数: {total_steps} = {trainer.max_epochs} epochs × {steps_per_epoch} steps/epoch") return total_steps except Exception as e: print(f"无法确定训练总步数: {e}, 使用默认值10000") return 10000 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """每个训练批次结束时更新进度条""" if self.progress_bar: # 防止进度条超过总步数 if self.progress_bar.n < self.progress_bar.total: self.progress_bar.update(1) try: # 尝试从输出中获取损失 loss = outputs.get('loss') if loss is not None: if isinstance(loss, torch.Tensor): loss = loss.item() self.progress_bar.set_postfix({"loss": loss}) except Exception: pass def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次结束时更新轮次进度条""" if self.epoch_bar: self.epoch_bar.update(1) self.epoch_bar.set_postfix({"epoch": trainer.current_epoch}) def on_train_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """训练结束时关闭进度条""" if self.progress_bar: self.progress_bar.close() if self.epoch_bar: self.epoch_bar.close() class PerformanceMonitor(Callback): """性能监控回调,记录内存使用和训练速度""" def __init__(self): self.epoch_start_time = 0 self.batch_times = [] def on_train_epoch_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次开始时记录时间和重置内存统计""" self.epoch_start_time = time.time() self.batch_times = [] if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() # 修改1:添加dataloader_idx参数 def on_train_batch_start(self, trainer: Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """每个训练批次开始时记录时间""" self.batch_start_time = time.time() # 修改2:添加dataloader_idx参数 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """每个训练批次结束时记录时间""" self.batch_times.append(time.time() - self.batch_start_time) def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次结束时计算并记录性能指标""" epoch_time = time.time() - self.epoch_start_time if self.batch_times: avg_batch_time = sum(self.batch_times) / len(self.batch_times) batches_per_second = 1.0 / avg_batch_time else: avg_batch_time = 0 batches_per_second = 0 memory_info = "" if torch.cuda.is_available(): max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 # MiB memory_info = f", 峰值显存: {max_memory:.2f} MiB" rank_zero_info( f"Epoch {trainer.current_epoch} | " f"耗时: {epoch_time:.2f}s | " f"Batch耗时: {avg_batch_time:.4f}s ({batches_per_second:.2f} batches/s)" f"{memory_info}" ) def get_world_size() -> int: """获取分布式训练中的总进程数""" if dist.is_initialized(): return dist.get_world_size() return 1 def all_gather(data: torch.Tensor) -> List[torch.Tensor]: """在分布式环境中收集所有进程的数据""" world_size = get_world_size() if world_size == 1: return [data] # 获取各进程的Tensor大小 local_size = torch.tensor([data.numel()], device=data.device) size_list = [torch.zeros_like(local_size) for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # 收集数据 tensor_list = [] for size in size_list: tensor_list.append(torch.empty((max_size,), dtype=data.dtype, device=data.device)) if local_size < max_size: padding = torch.zeros(max_size - local_size, dtype=data.dtype, device=data.device) data = torch.cat((data.view(-1), padding)) dist.all_gather(tensor_list, data.view(-1)) # 截断到实际大小 results = [] for tensor, size in zip(tensor_list, size_list): results.append(tensor[:size].reshape(data.shape)) return results def create_experiment_directories(logging_config: DictConfig, experiment_name: str) -> Tuple[str, str, str]: """创建实验目录结构""" now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") logdir = os.path.join(logging_config.logdir, f"{experiment_name}_{now}") ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") os.makedirs(ckptdir, exist_ok=True) os.makedirs(cfgdir, exist_ok=True) print(f"实验目录: {logdir}") print(f"检查点目录: {ckptdir}") print(f"配置目录: {cfgdir}") return logdir, ckptdir, cfgdir def setup_callbacks(config_manager: ConfigManager, ckptdir: str, tb_logger: TensorBoardLogger) -> List[Callback]: """设置训练回调函数""" callbacks = [] # 模型检查点 checkpoint_callback = ModelCheckpoint( dirpath=ckptdir, filename='{epoch}-{step}-{val_loss:.2f}', monitor='val_loss', save_top_k=3, mode='min', save_last=True, save_on_train_epoch_end=True, # 确保在epoch结束时保存完整状态 save_weights_only=False, # 明确设置为False,保存完整检查点 every_n_train_steps=1000 # 每1000步保存一次 ) callbacks.append(checkpoint_callback) # 学习率监控 lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) # 图像日志记录 image_logger_cfg = config_manager.get_callbacks_config().get("image_logger", {}) image_logger = EnhancedImageLogger( batch_frequency=image_logger_cfg.get("batch_frequency", 500), max_images=image_logger_cfg.get("max_images", 4), loggers=[tb_logger] ) callbacks.append(image_logger) # 进度条 progress_bar = TQDMProgressBar() callbacks.append(progress_bar) # 性能监控 perf_monitor = PerformanceMonitor() callbacks.append(perf_monitor) return callbacks def preprocess_checkpoint(checkpoint_path: str, model: pl.LightningModule) -> Dict[str, Any]: """预处理检查点文件,确保包含所有必要的键,并添加缺失的训练状态""" print(f"预处理检查点文件: {checkpoint_path}") # 加载检查点 try: checkpoint = torch.load(checkpoint_path, map_location="cpu") except Exception as e: print(f"加载检查点失败: {e}") raise # 强制重置训练状态 checkpoint['epoch'] = 0 checkpoint['global_step'] = 0 checkpoint['lr_schedulers'] = [] checkpoint['optimizer_states'] = [] print("已重置训练状态: epoch=0, global_step=0") # 检查是否缺少关键训练状态 required_keys = ['optimizer_states', 'lr_schedulers', 'epoch', 'global_step'] missing_keys = [k for k in required_keys if k not in checkpoint] if missing_keys: print(f"警告: 检查点缺少训练状态字段 {missing_keys},将创建伪训练状态") # 创建伪训练状态 checkpoint.setdefault('optimizer_states', []) checkpoint.setdefault('lr_schedulers', []) checkpoint.setdefault('epoch', 0) checkpoint.setdefault('global_step', 0) # 检查是否缺少 position_ids state_dict = checkpoint.get("state_dict", {}) if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict: print("警告: 检查点缺少 'cond_stage_model.transformer.text_model.embeddings.position_ids' 键") # 获取模型中的 position_ids 形状 if hasattr(model, "cond_stage_model") and hasattr(model.cond_stage_model, "transformer"): try: max_position_embeddings = model.cond_stage_model.transformer.text_model.config.max_position_embeddings position_ids = torch.arange(max_position_embeddings).expand((1, -1)) state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids print("已添加 position_ids 到检查点") except Exception as e: print(f"无法添加 position_ids: {e}") # 确保有 state_dict if "state_dict" not in checkpoint: checkpoint["state_dict"] = state_dict return checkpoint # 正确继承原始模型类 from ldm.models.diffusion.ddpm import LatentDiffusion class CustomLatentDiffusion(LatentDiffusion): """自定义 LatentDiffusion 类,处理检查点加载问题""" def on_load_checkpoint(self, checkpoint): """在加载检查点时自动处理缺失的键""" state_dict = checkpoint["state_dict"] # 检查是否缺少 position_ids if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict: print("警告: 检查点缺少 'cond_stage_model.transformer.text_model.embeddings.position_ids' 键") # 获取模型中的 position_ids 形状 max_position_embeddings = self.cond_stage_model.transformer.text_model.config.max_position_embeddings position_ids = torch.arange(max_position_embeddings).expand((1, -1)) state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids print("已添加 position_ids 到 state_dict") # 使用非严格模式加载 self.load_state_dict(state_dict, strict=False) print("模型权重加载完成") def filter_kwargs(cls, kwargs, log_prefix=""): # 关键参数白名单 - 这些参数必须保留 ESSENTIAL_PARAMS = { 'unet_config', 'first_stage_config', 'cond_stage_config', 'scheduler_config', 'ckpt_path', 'linear_start', 'linear_end' } # 特殊处理:允许所有包含"config"的参数 filtered_kwargs = {} for k, v in kwargs.items(): if k in ESSENTIAL_PARAMS or 'config' in k: filtered_kwargs[k] = v else: print(f"{log_prefix}过滤参数: {k}") print(f"{log_prefix}保留参数: {list(filtered_kwargs.keys())}") return filtered_kwargs def check_checkpoint_content(checkpoint_path): """打印检查点包含的键,确认是否有训练状态""" checkpoint = torch.load(checkpoint_path, map_location="cpu") print("检查点包含的键:", list(checkpoint.keys())) if "state_dict" in checkpoint: print("模型权重存在") if "optimizer_states" in checkpoint: print("优化器状态存在") if "epoch" in checkpoint: print(f"保存的epoch: {checkpoint['epoch']}") if "global_step" in checkpoint: print(f"保存的global_step: {checkpoint['global_step']}") def main() -> None: """主函数,训练和推理流程的入口点""" # 启用Tensor Core加速 torch.set_float32_matmul_precision('high') # 解析命令行参数 parser = argparse.ArgumentParser(description="扩散模型训练框架") parser.add_argument("--config", type=str, default="configs/train.yaml", help="配置文件路径") parser.add_argument("--name", type=str, default="experiment", help="实验名称") parser.add_argument("--resume", action="store_true", default=True, help="恢复训练") parser.add_argument("--debug", action="store_true", help="调试模式") parser.add_argument("--seed", type=int, default=42, help="随机种子") parser.add_argument("--scale_lr", action="store_true", help="根据GPU数量缩放学习率") parser.add_argument("--precision", type=str, default="32", choices=["16", "32", "bf16"], help="训练精度") args, unknown = parser.parse_known_args() # 设置随机种子 seed_everything(args.seed, workers=True) print(f"设置随机种子: {args.seed}") # 初始化配置管理器 try: config_manager = ConfigManager(args.config, unknown) config = config_manager.config except Exception as e: print(f"加载配置失败: {e}") sys.exit(1) # 创建日志目录 logging_config = config_manager.get_logging_config() logdir, ckptdir, cfgdir = create_experiment_directories(logging_config, args.name) # 保存配置 config_manager.save_config(os.path.join(cfgdir, "config.yaml")) # 配置日志记录器 tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard")) # 配置回调函数 callbacks = setup_callbacks(config_manager, ckptdir, tb_logger) # 初始化数据模块 try: print("初始化数据模块...") data_config = config_manager.get_data_config() data_module = instantiate_from_config(data_config) data_module.setup() print("可用数据集:", list(data_module.datasets.keys())) except Exception as e: print(f"数据模块初始化失败: {str(e)}") return # 创建模型 try: model_config = config_manager.get_model_config() model_params = model_config.get("params", {}) # 创建模型实例 model = CustomLatentDiffusion(**model_config.get("params", {})) print("模型初始化成功") # 检查并转换预训练权重 ckpt_path = model_config.params.get("ckpt_path", "") if ckpt_path and os.path.exists(ckpt_path): print(f"加载预训练权重: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location="cpu") state_dict = checkpoint.get("state_dict", checkpoint) # 查找所有与conv_in.weight相关的键 conv_in_keys = [] for key in state_dict.keys(): if "conv_in.weight" in key and "first_stage_model" in key: conv_in_keys.append(key) # 转换找到的权重 for conv_in_key in conv_in_keys: if state_dict[conv_in_key].shape[1] == 3: # 原始是3通道 print(f"转换权重: {conv_in_key} 从3通道到1通道") # 取RGB三通道的平均值作为单通道权重 rgb_weights = state_dict[conv_in_key] ir_weights = rgb_weights.mean(dim=1, keepdim=True) state_dict[conv_in_key] = ir_weights print(f"转换前形状: {rgb_weights.shape}") print(f"转换后形状: {ir_weights.shape}") print(f"模型层形状: {model.first_stage_model.encoder.conv_in.weight.shape}") # 非严格模式加载(允许其他层不匹配) missing, unexpected = model.load_state_dict(state_dict, strict=False) print(f"权重加载完成: 缺失层 {len(missing)}, 不匹配层 {len(unexpected)}") if missing: print("缺失层:", missing) if unexpected: print("意外层:", unexpected) except Exception as e: print(f"模型初始化失败: {str(e)}") return print("VAE输入层形状:", model.first_stage_model.encoder.conv_in.weight.shape) # 权重转换 if ckpt_path and os.path.exists(ckpt_path): print(f"加载预训练权重: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location="cpu") state_dict = checkpoint.get("state_dict", checkpoint) # 增强:查找所有需要转换的层(包括可能的变体) conversion_keys = [] for key in state_dict.keys(): if "conv_in" in key or "conv_out" in key or "nin_shortcut" in key: if state_dict[key].ndim == 4 and state_dict[key].shape[1] == 3: conversion_keys.append(key) print(f"找到需要转换的层: {conversion_keys}") # 转换权重 for key in conversion_keys: print(f"转换权重: {key}") print(f"原始形状: {state_dict[key].shape}") # RGB权重 [out_c, in_c=3, kH, kW] rgb_weights = state_dict[key] # 转换为单通道权重 [out_c, 1, kH, kW] if rgb_weights.shape[1] == 3: ir_weights = rgb_weights.mean(dim=1, keepdim=True) state_dict[key] = ir_weights print(f"转换后形状: {state_dict[key].shape}") # 加载转换后的权重 try: # 使用非严格模式加载 missing, unexpected = model.load_state_dict(state_dict, strict=False) print(f"权重加载完成: 缺失层 {len(missing)}, 不匹配层 {len(unexpected)}") # 打印重要信息 if missing: print("缺失层:", missing[:5]) # 只显示前5个避免过多输出 if unexpected: print("意外层:", unexpected[:5]) # 特别检查conv_in层 if "first_stage_model.encoder.conv_in.weight" in missing: print("警告: conv_in.weight未加载,需要手动初始化") # 手动初始化单通道卷积层 with torch.no_grad(): model.first_stage_model.encoder.conv_in.weight.data.normal_(mean=0.0, std=0.02) print("已手动初始化conv_in.weight") except RuntimeError as e: print(f"加载权重时出错: {e}") print("尝试仅加载兼容的权重...") # 创建新的状态字典只包含兼容的键 model_state = model.state_dict() compatible_dict = {} for k, v in state_dict.items(): if k in model_state and v.shape == model_state[k].shape: compatible_dict[k] = v # 加载兼容的权重 model.load_state_dict(compatible_dict, strict=False) print(f"部分权重加载完成: {len(compatible_dict)}/{len(state_dict)}") # 配置学习率 training_config = config_manager.get_training_config() bs = data_config.params.batch_size base_lr = model_config.base_learning_rate ngpu = training_config.get("gpus", 1) accumulate_grad_batches = training_config.get("accumulate_grad_batches", 1) if args.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print(f"学习率缩放至: {model.learning_rate:.2e} = {accumulate_grad_batches} × {ngpu} × {bs} × {base_lr:.2e}") else: model.learning_rate = base_lr print(f"使用基础学习率: {model.learning_rate:.2e}") # 检查是否恢复训练 resume_from_checkpoint = None if args.resume: # 优先使用自动保存的last.ckpt last_ckpt = os.path.join(ckptdir, "last.ckpt") if os.path.exists(last_ckpt): print(f"恢复训练状态: {last_ckpt}") resume_from_checkpoint = last_ckpt else: # 回退到指定检查点 fallback_ckpt = os.path.join(current_dir, "checkpoints", "M3FD.ckpt") if os.path.exists(fallback_ckpt): print(f"警告: 使用仅含权重的检查点,训练状态将重置: {fallback_ckpt}") resume_from_checkpoint = fallback_ckpt else: print("未找到可用的检查点,从头开始训练") # 如果需要恢复训练,预处理检查点 if resume_from_checkpoint and os.path.exists(resume_from_checkpoint): try: # 预处理检查点 - 添加缺失的状态 checkpoint = preprocess_checkpoint(resume_from_checkpoint, model) # 创建新的完整检查点文件 fixed_ckpt_path = os.path.join(ckptdir, "fixed_checkpoint.ckpt") torch.save(checkpoint, fixed_ckpt_path) print(f"修复后的完整检查点已保存到: {fixed_ckpt_path}") # 使用修复后的检查点 resume_from_checkpoint = fixed_ckpt_path except Exception as e: print(f"预处理检查点失败: {e}") print("将尝试使用默认方式加载检查点") # 配置日志记录器 tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard")) # 配置回调函数 callbacks = setup_callbacks(config_manager, ckptdir, tb_logger) # 检查是否有验证集 has_validation = hasattr(data_module, 'datasets') and 'validation' in data_module.datasets # 计算训练批次数 try: train_loader = data_module.train_dataloader() num_train_batches = len(train_loader) print(f"训练批次数: {num_train_batches}") except Exception as e: print(f"计算训练批次数失败: {e}") num_train_batches = 0 # 设置训练器参数(先设置基础参数) trainer_config = { "default_root_dir": logdir, "max_epochs": training_config.max_epochs, "gpus": ngpu, "distributed_backend": "ddp" if ngpu > 1 else None, "plugins": [DDPPlugin(find_unused_parameters=False)] if ngpu > 1 else None, "precision": 16, "accumulate_grad_batches": accumulate_grad_batches, "callbacks": callbacks, "logger": tb_logger, # 添加日志记录器 "resume_from_checkpoint": resume_from_checkpoint, "fast_dev_run": args.debug, "limit_val_batches": 0 if not has_validation else 1.0, "num_sanity_val_steps": 0, # 跳过初始验证加速恢复 "log_every_n_steps": 10 # 更频繁的日志记录 } # 动态调整验证配置 if has_validation: if num_train_batches < 50: # 小数据集:使用epoch验证 trainer_config["check_val_every_n_epoch"] = 1 # 确保移除步数验证参数 if "val_check_interval" in trainer_config: del trainer_config["val_check_interval"] else: # 大数据集:使用步数验证 val_check_interval = min(2000, num_train_batches) if num_train_batches < 100: val_check_interval = max(1, num_train_batches // 4) trainer_config["val_check_interval"] = val_check_interval # 创建训练器 try: print("最终训练器配置:") for k, v in trainer_config.items(): print(f" {k}: {v}") trainer = Trainer(**trainer_config) except Exception as e: print(f"创建训练器失败: {e}") tb_logger.close() sys.exit(1) # 执行训练 try: print("开始训练...") trainer.fit(model, data_module) print("训练完成!") except KeyboardInterrupt: print("训练被用户中断") if trainer.global_rank == 0 and trainer.model is not None: trainer.save_checkpoint(os.path.join(ckptdir, "interrupted.ckpt")) except Exception as e: print(f"训练出错: {e}") if trainer.global_rank == 0 and hasattr(trainer, 'model') and trainer.model is not None: trainer.save_checkpoint(os.path.join(ckptdir, "error.ckpt")) raise finally: # 关闭日志记录器 tb_logger.close() # 打印性能分析报告 if trainer.global_rank == 0 and hasattr(trainer, 'profiler'): print("训练摘要:") print(trainer.profiler.summary()) if __name__ == "__main__": main()运行报错:模型初始化失败: Error(s) in loading state_dict for CustomLatentDiffusion: size mismatch for first_stage_model.encoder.conv_in.weight: copying a param with shape torch.Size([128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
07-04
当前代码如下: import sys import serial import serial.tools.list_ports import numpy as np from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QComboBox, QLabel, QGroupBox, QStatusBar, QTextEdit, QSplitter, QCheckBox) from PyQt5.QtCore import QTimer, Qt, QThread, pyqtSignal from pyqtgraph import PlotWidget, mkPen, mkBrush import pyqtgraph as pg from collections import deque import time import binascii class LidarWorker(QThread): """雷达数据处理线程""" new_points = pyqtSignal(list) # 信号:发射新的点云数据 raw_data_signal = pyqtSignal(str) # 信号:发射原始数据 def __init__(self, port, baudrate=153600): super().__init__() self.port = port self.baudrate = baudrate self.serial_conn = None self.running = False self._last_valid_dist = 1.0 # 上次有效距离值(用于突变检测) self.buffer = bytearray() # 确保缓冲区初始化 # 根据实际数据确认的配置 self.HEADER = b'\xb4\x4b' # 正确的包头 self.HEADER_LENGTH = 2 self.POINT_SIZE = 5 # 角度2字节 + 距离2字节 + 质量1字节 # 雷达参数 self.max_distance = 8.0 # 最大测距距离(米) self.min_distance = 0.12 # 最小测距距离(米) # 调试 self.debug_counter = 0 def run(self): """主线程函数""" if not self._connect_serial(): self.new_points.emit([]) return while self.running: if self.serial_conn.in_waiting > 0: data = self.serial_conn.read(self.serial_conn.in_waiting) self.raw_data_signal.emit(binascii.hexlify(data).decode('ascii')) self._process_buffer(data) time.sleep(0.001) self._disconnect_serial() def _connect_serial(self): """连接串口""" try: self.serial_conn = serial.Serial( port=self.port, baudrate=self.baudrate, bytesize=serial.EIGHTBITS, parity=serial.PARITY_NONE, stopbits=serial.STOPBITS_ONE, timeout=0.1 ) self.running = True return True except serial.SerialException as e: print(f"无法连接到串口: {e}") return False def _disconnect_serial(self): """断开串口连接""" if self.serial_conn and self.serial_conn.is_open: self.serial_conn.close() def _process_buffer(self, new_data): """最终版数据解析方法""" self.buffer.extend(new_data) points = [] while len(self.buffer) >= 7: # 最小完整包长度 # 增强版包头检测(允许部分数据干扰) header_pos = -1 for i in range(min(len(self.buffer) - 1, 100)): # 最多检查前100字节 if self.buffer[i] == 0xb4 and self.buffer[i + 1] == 0x4b: header_pos = i break if header_pos == -1: # 保留最后1字节可能的部分包头 self.buffer = self.buffer[-1:] if len(self.buffer) > 0 else bytearray() break # 移除包头前的无效数据 if header_pos > 0: self.buffer = self.buffer[header_pos:] if len(self.buffer) < 7: break # 提取完整数据包 packet = self.buffer[:7] self.buffer = self.buffer[7:] try: # 改进的字节顺序检测 angle_be = packet[2] * 256 + packet[3] angle_le = packet[3] * 256 + packet[2] dist_be = packet[4] * 256 + packet[5] dist_le = packet[5] * 256 + packet[4] # 测试所有可能的组合 for angle, dist in [(angle_be, dist_be), (angle_le, dist_le), (angle_be, dist_le), (angle_le, dist_be)]: actual_angle = angle * 0.01 actual_dist = dist * 0.001 if (0 <= actual_angle < 360 and 0.1 <= actual_dist <= 12.0 and abs(actual_dist - self._last_valid_dist) < 2.0): # 距离突变检查 points.append((actual_angle, actual_dist)) self._last_valid_dist = actual_dist break except Exception as e: print(f"数据包解析异常: {e}\n原始数据: {binascii.hexlify(packet)}") if points: self.new_points.emit(points) elif len(self.buffer) > 20: # 缓冲区积压过多 print(f"警告: 缓冲区积压 {len(self.buffer)}字节,执行清空") self.buffer.clear() def stop(self): """停止线程""" self.running = False self.wait() class LidarApp(QMainWindow): """主应用程序窗口""" def __init__(self): super().__init__() self.setWindowTitle("YDLIDAR X3Pro 平面地图") self.setGeometry(100, 100, 1200, 800) # 初始化UI self.init_ui() # 雷达工作线程 self.lidar_worker = None # 点云数据 self.point_cloud = [] self.map_points = [] # 用于地图显示的点 # 更新可用串口 self.update_serial_ports() # 设置定时器定期更新图形 self.plot_timer = QTimer() self.plot_timer.timeout.connect(self.update_map_display) self.plot_timer.start(50) # 20Hz刷新率 def init_ui(self): """初始化用户界面""" # 主窗口布局 main_widget = QWidget() self.setCentralWidget(main_widget) main_layout = QHBoxLayout(main_widget) main_layout.setContentsMargins(5, 5, 5, 5) # 使用分割器实现可调整布局 splitter = QSplitter(Qt.Horizontal) # 左侧控制面板 control_panel = QWidget() control_panel.setMaximumWidth(300) control_layout = QVBoxLayout() control_layout.setContentsMargins(5, 5, 5, 5) # 串口设置组 port_group = QGroupBox("串口设置") port_layout = QVBoxLayout() self.port_combo = QComboBox() self.refresh_btn = QPushButton("刷新串口列表") self.refresh_btn.clicked.connect(self.update_serial_ports) port_layout.addWidget(QLabel("选择串口:")) port_layout.addWidget(self.port_combo) port_layout.addWidget(self.refresh_btn) port_group.setLayout(port_layout) # 雷达控制组 control_group = QGroupBox("雷达控制") control_btn_layout = QVBoxLayout() self.start_btn = QPushButton("开始扫描") self.start_btn.clicked.connect(self.start_lidar) self.stop_btn = QPushButton("停止扫描") self.stop_btn.clicked.connect(self.stop_lidar) self.stop_btn.setEnabled(False) # 地图显示选项 self.accumulate_check = QCheckBox("累积模式") self.accumulate_check.setChecked(True) self.clear_map_btn = QPushButton("清空地图") self.clear_map_btn.clicked.connect(self.clear_map) control_btn_layout.addWidget(self.start_btn) control_btn_layout.addWidget(self.stop_btn) control_btn_layout.addWidget(self.accumulate_check) control_btn_layout.addWidget(self.clear_map_btn) control_group.setLayout(control_btn_layout) # 状态信息组 status_group = QGroupBox("状态信息") status_layout = QVBoxLayout() self.status_label = QLabel("状态: 未连接") self.point_count_label = QLabel("点数: 0") self.fps_label = QLabel("FPS: 0") status_layout.addWidget(self.status_label) status_layout.addWidget(self.point_count_label) status_layout.addWidget(self.fps_label) status_layout.addStretch() status_group.setLayout(status_layout) # 添加到控制面板 control_layout.addWidget(port_group) control_layout.addWidget(control_group) control_layout.addWidget(status_group) control_layout.addStretch() control_panel.setLayout(control_layout) # 右侧主显示区 right_panel = QWidget() right_layout = QVBoxLayout() # 平面地图显示 self.map_widget = pg.PlotWidget() self.map_widget.setBackground('w') self.map_widget.setTitle("平面地图 - YDLIDAR X3Pro", color='k') self.map_widget.setLabel('left', 'Y (m)') self.map_widget.setLabel('bottom', 'X (m)') self.map_widget.showGrid(x=True, y=True) self.map_widget.setXRange(-8, 8) self.map_widget.setYRange(-8, 8) self.map_widget.setAspectLocked(True) # 地图散点图 self.map_plot = pg.ScatterPlotItem( size=5, pen=mkPen(color='b', width=1), brush=mkBrush('b') ) self.map_widget.addItem(self.map_plot) # 原始数据展示 self.raw_data_display = QTextEdit() self.raw_data_display.setReadOnly(True) self.raw_data_display.setMaximumHeight(150) self.raw_data_display.setStyleSheet("font-family: monospace;") right_layout.addWidget(self.map_widget) right_layout.addWidget(QLabel("原始数据(十六进制):")) right_layout.addWidget(self.raw_data_display) right_panel.setLayout(right_layout) # 添加分割器 splitter.addWidget(control_panel) splitter.addWidget(right_panel) splitter.setStretchFactor(1, 3) # 右侧区域占用更多空间 main_layout.addWidget(splitter) # 状态栏 self.status_bar = QStatusBar() self.setStatusBar(self.status_bar) self.status_bar.showMessage("就绪") # FPS计算 self.last_update_time = time.time() self.frame_count = 0 self.current_fps = 0 def update_serial_ports(self): """更新可用串口列表""" self.port_combo.clear() ports = serial.tools.list_ports.comports() for port in ports: self.port_combo.addItem(f"{port.device}", port.device) if not ports: self.port_combo.addItem("未找到串口", None) self.start_btn.setEnabled(False) else: self.start_btn.setEnabled(True) def start_lidar(self): """启动雷达扫描""" port = self.port_combo.currentData() if not port: self.status_bar.showMessage("错误: 没有可用的串口") return # 禁用按钮 self.start_btn.setEnabled(False) self.stop_btn.setEnabled(True) self.port_combo.setEnabled(False) self.refresh_btn.setEnabled(False) # 清空数据和显示 self.point_cloud = [] if not self.accumulate_check.isChecked(): self.map_points = [] self.raw_data_display.clear() # 启动工作线程 self.lidar_worker = LidarWorker(port) self.lidar_worker.new_points.connect(self.add_new_points) self.lidar_worker.raw_data_signal.connect(self.display_raw_data) self.lidar_worker.start() self.status_label.setText("状态: 扫描中...") self.status_bar.showMessage(f"已连接到 {port},开始扫描...") def stop_lidar(self): """停止雷达扫描""" if self.lidar_worker: self.lidar_worker.stop() self.lidar_worker = None # 恢复按钮状态 self.start_btn.setEnabled(True) self.stop_btn.setEnabled(False) self.port_combo.setEnabled(True) self.refresh_btn.setEnabled(True) self.status_label.setText("状态: 已停止") self.status_bar.showMessage("扫描已停止") def clear_map(self): """清空地图数据""" self.map_points = [] self.update_map_display() def add_new_points(self, new_points): """添加新的点云数据""" if new_points: self.point_cloud = new_points # 更新当前帧数据 # 累积模式处理 if self.accumulate_check.isChecked(): self.map_points.extend(new_points) # 限制累积点数 if len(self.map_points) > 5000: self.map_points = self.map_points[-5000:] self.point_count_label.setText(f"点数: {len(self.point_cloud)}") # 计算FPS self.frame_count += 1 current_time = time.time() if current_time - self.last_update_time >= 1.0: self.current_fps = self.frame_count / (current_time - self.last_update_time) self.fps_label.setText(f"FPS: {self.current_fps:.1f}") self.last_update_time = current_time self.frame_count = 0 def display_raw_data(self, hex_data): """显示原始数据""" # 限制显示的行数,避免内存占用过高 max_lines = 50 current_text = self.raw_data_display.toPlainText() lines = current_text.split('\n') if len(lines) > max_lines: lines = lines[-max_lines:] # 添加新数据 lines.append(hex_data) self.raw_data_display.setPlainText('\n'.join(lines)) # 自动滚动到底部 scrollbar = self.raw_data_display.verticalScrollBar() scrollbar.setValue(scrollbar.maximum()) def update_map_display(self): """增强版地图显示""" if not hasattr(self, 'scatter'): self.scatter = pg.ScatterPlotItem(size=5) self.map_widget.addItem(self.scatter) self.map_widget.setXRange(-5, 5) self.map_widget.setYRange(-5, 5) display_points = self.map_points if self.accumulate_check.isChecked() else self.point_cloud if not display_points: self.scatter.setData([], []) return # 坐标转换(添加异常处理) try: angles = np.radians([p[0] for p in display_points]) distances = np.clip([p[1] for p in display_points], 0.1, 8.0) x = distances * np.cos(angles) y = distances * np.sin(angles) # 自动调整视野(保留10%边距) if len(x) > 10: # 至少有10个点才调整 x_margin = (max(x) - min(x)) * 0.1 y_margin = (max(y) - min(y)) * 0.1 self.map_widget.setXRange(min(x) - x_margin, max(x) + x_margin) self.map_widget.setYRange(min(y) - y_margin, max(y) + y_margin) # 动态颜色映射 norm_dist = (distances - 0.1) / (8.0 - 0.1) colors = [ (int(255 * (1 - v)), 0, int(255 * v), 180) for v in norm_dist ] self.scatter.setData(x=x, y=y, brush=colors, size=5) except Exception as e: print(f"地图渲染错误: {e}") def closeEvent(self, event): """窗口关闭事件""" self.stop_lidar() self.plot_timer.stop() super().closeEvent(event) if __name__ == "__main__": app = QApplication(sys.argv) # 设置全局样式 app.setStyle('Fusion') window = LidarApp() window.show() sys.exit(app.exec_()) 现在能显示点云,但这不是laser线条地图,我需要生成平面二维地图。请结合数据手册利用激光雷达进行路径规划,借助SLAM(Simultaneous Localization and Mapping)技术实现了地图的生成与更新,从而能够在未知环境中高效行进,帮我修改代码。直接给出完整代码。
最新发布
08-19
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值