SUPIR模型训练全攻略:从数据准备到Loss优化

SUPIR模型训练全攻略:从数据准备到Loss优化

【免费下载链接】SUPIR 【免费下载链接】SUPIR 项目地址: https://gitcode.com/gh_mirrors/su/SUPIR

引言

在图像修复(Image Restoration)领域,如何平衡模型性能与计算效率一直是研究的核心挑战。SUPIR(Scaling Up to Excellence: Practicing Model Scaling for Photo-Realistic Image Restoration In the Wild)作为CVPR2024的最新成果,通过创新的模型架构设计和训练策略,实现了真实场景下照片级图像修复的突破。本文将系统讲解SUPIR模型的训练全流程,从环境配置、数据准备到高级Loss优化策略,帮助研究者和工程师快速掌握这一前沿技术。

环境配置与依赖安装

系统要求

SUPIR模型训练对硬件环境有较高要求,推荐配置如下:

  • GPU:NVIDIA A100 (80GB) 或同等算力的GPU
  • CPU:16核及以上
  • 内存:128GB及以上
  • 存储:至少500GB可用空间(用于存储数据集和模型检查点)
  • 操作系统:Linux (Ubuntu 20.04+)

依赖安装

# 克隆仓库
git clone https://gitcode.com/gh_mirrors/su/SUPIR
cd SUPIR

# 创建并激活conda环境
conda create -n SUPIR python=3.8 -y
conda activate SUPIR

# 安装依赖包
pip install --upgrade pip
pip install -r requirements.txt

关键依赖版本说明:

  • PyTorch:2.1.0及以上
  • CUDA:11.7及以上
  • xFormers:0.0.20及以上(用于高效注意力计算)
  • OpenCV:4.7.0及以上(用于图像处理)

预训练模型下载

SUPIR训练需要以下预训练模型:

模型名称用途下载地址
SDXL base 1.0_0.9vae基础扩散模型HuggingFace
LLaVA v1.5 13B视觉语言模型HuggingFace
CLIP ViT-bigG-14图像编码器HuggingFace

下载完成后,修改CKPT_PTH.py文件,指定模型路径:

# CKPT_PTH.py
LLAVA_CLIP_PATH = "/path/to/clip-vit-large-patch14-336"
LLAVA_MODEL_PATH = "/path/to/llava-v1.5-13b"
SDXL_CLIP1_PATH = "/path/to/clip-vit-large-patch14"
SDXL_CLIP2_CACHE_DIR = "/path/to/CLIP-ViT-bigG-14-laion2B-39B-b160k"

数据准备

数据集结构

SUPIR训练推荐使用如下数据集结构:

dataset/
├── train/
│   ├── input/       # 低质量输入图像
│   └── target/      # 高质量目标图像
└── val/
    ├── input/
    └── target/

数据预处理

数据预处理步骤包括:

  1. 图像尺寸调整:统一调整为512×512或1024×1024
  2. 颜色空间转换:转换为RGB格式
  3. 数据增强:包括随机裁剪、翻转等
  4. 标准化:将像素值归一化到[-1, 1]范围

以下是数据预处理的Python代码示例:

import cv2
import numpy as np
from torch.utils.data import Dataset

class SRDataset(Dataset):
    def __init__(self, input_dir, target_dir, size=512):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.size = size
        self.image_names = os.listdir(input_dir)
        
    def __len__(self):
        return len(self.image_names)
    
    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        input_path = os.path.join(self.input_dir, img_name)
        target_path = os.path.join(self.target_dir, img_name)
        
        # 读取图像
        input_img = cv2.imread(input_path)
        target_img = cv2.imread(target_path)
        
        # 转换为RGB
        input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
        target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)
        
        # 调整大小
        input_img = cv2.resize(input_img, (self.size, self.size))
        target_img = cv2.resize(target_img, (self.size, self.size))
        
        # 归一化
        input_img = (input_img.astype(np.float32) / 127.5) - 1.0
        target_img = (target_img.astype(np.float32) / 127.5) - 1.0
        
        # 转换为Tensor
        input_tensor = torch.from_numpy(input_img).permute(2, 0, 1)
        target_tensor = torch.from_numpy(target_img).permute(2, 0, 1)
        
        return {"input": input_tensor, "target": target_tensor}

数据加载配置

修改训练配置文件options/SUPIR_v0.yaml,设置数据相关参数:

data:
  target: sgm.data.DataModuleFromConfig
  params:
    batch_size: 16
    num_workers: 8
    train:
      target: your.dataset.SRDataset
      params:
        input_dir: dataset/train/input
        target_dir: dataset/train/target
        size: 512
    validation:
      target: your.dataset.SRDataset
      params:
        input_dir: dataset/val/input
        target_dir: dataset/val/target
        size: 512

模型架构解析

SUPIR整体架构

SUPIR采用两阶段级联架构,结合了自编码器和扩散模型:

mermaid

关键组件详解

  1. VAE自编码器

    • 作用:将图像压缩到潜空间,减少计算量
    • 配置:4通道输出,采用KL散度损失
  2. 扩散模型

    • 网络结构:U-Net架构,包含空间注意力和通道注意力
    • 时间步编码:采用正弦位置编码
    • 条件控制:通过CLIP特征引导图像生成
  3. 条件控制器

    • 输入:文本提示和CLIP图像特征
    • 输出:控制扩散过程的条件向量

训练配置详解

基础训练参数

修改options/SUPIR_v0.yaml配置文件,设置基础训练参数:

model:
  target: SUPIR.models.SUPIR_model.SUPIRModel
  params:
    ae_dtype: bf16          # 自编码器数据类型
    diffusion_dtype: fp16   # 扩散模型数据类型
    scale_factor: 0.13025   # 潜空间缩放因子
    disable_first_stage_autocast: True  # 禁用第一阶段自动混合精度
    
    network_config:
      target: SUPIR.modules.SUPIR_v0.LightGLVUNet
      params:
        in_channels: 4       # 输入通道数
        out_channels: 4      # 输出通道数
        model_channels: 320  # 基础通道数
        attention_resolutions: [4, 2]  # 注意力分辨率
        num_res_blocks: 2     # 残差块数量
        channel_mult: [1, 2, 4]  # 通道倍增因子
        transformer_depth: [1, 2, 10]  # Transformer深度

采样器配置

SUPIR支持多种采样器,用于控制扩散过程:

sampler_config:
  target: sgm.modules.diffusionmodules.sampling.RestoreEDMSampler
  params:
    num_steps: 50           # 采样步数
    restore_cfg: 4.0        # 恢复CFG尺度
    s_churn: 5              # 噪声扰动强度
    s_noise: 1.01           # 噪声缩放因子

优化器配置

optimizer:
  target: torch.optim.AdamW
  params:
    lr: 2e-5                # 初始学习率
    weight_decay: 1e-4      # 权重衰减
    betas: [0.9, 0.999]     # AdamW动量参数

训练过程详解

训练命令

使用以下命令启动训练:

CUDA_VISIBLE_DEVICES=0,1,2,3 python llava/train/train.py \
  --model_name_or_path liuhaotian/llava-v1.5-13b \
  --data_path ./dataset/train.json \
  --image_folder ./dataset/train/input \
  --vision_tower openai/clip-vit-large-patch14 \
  --tune_mm_mlp_adapter True \
  --mm_vision_select_layer -1 \
  --mm_use_im_start_end False \
  --mm_use_im_patch_token True \
  --bf16 True \
  --output_dir ./checkpoints/supir-v0 \
  --num_train_epochs 10 \
  --per_device_train_batch_size 4 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 4 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 5000 \
  --save_total_limit 5 \
  --learning_rate 2e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --tf32 True \
  --model_max_length 2048 \
  --gradient_checkpointing True \
  --lazy_preprocess True \
  --report_to wandb

学习率调度

SUPIR采用余弦退火学习率调度:

# sgm/lr_scheduler.py
class LambdaWarmUpCosineScheduler:
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps):
        self.lr_warm_up_steps = warm_up_steps  # 预热步数
        self.lr_start = lr_start              # 初始学习率
        self.lr_min = lr_min                  # 最小学习率
        self.lr_max = lr_max                  # 最大学习率
        self.lr_max_decay_steps = max_decay_steps  # 衰减总步数
    
    def __call__(self, n):
        if n < self.lr_warm_up_steps:
            # 线性预热
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
        else:
            # 余弦退火
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
        return lr

学习率曲线如下:

mermaid

Loss函数优化

Loss函数类型

SUPIR支持多种Loss函数,可根据任务需求选择:

  1. L2 Loss

    • 适用场景:通用图像恢复
    • 特点:稳定但容易产生模糊
  2. L1 Loss

    • 适用场景:保留边缘和细节
    • 特点:对异常值更鲁棒
  3. LPIPS Loss

    • 适用场景:感知质量优化
    • 特点:基于预训练网络的感知损失

Loss配置

在配置文件中设置Loss类型:

loss:
  target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
  params:
    type: "lpips"  # 可选: l2, l1, lpips
    sigma_sampler_config:
      target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
      params:
        rho: 7.0
        sigma_min: 0.002
        sigma_max: 80.0
    offset_noise_level: 0.0  # 偏移噪声水平

混合Loss策略

对于复杂场景,推荐使用混合Loss策略:

# 自定义混合Loss
class MixedLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.l2_loss = nn.MSELoss()
        self.lpips_loss = LPIPS().eval()
        
    def forward(self, pred, target):
        l2 = self.l2_loss(pred, target)
        lpips = self.lpips_loss(pred, target).mean()
        return 0.5 * l2 + 0.5 * lpips

训练监控与评估

指标监控

使用WandB监控训练过程:

wandb login

训练命令中添加WandB参数:

--report_to wandb \
--run_name supir-training-experiment \
--wandb_project supir-image-restoration

关键监控指标:

  • 训练Loss:L2/L1/LPIPS损失
  • 评估指标:PSNR、SSIM、LPIPS
  • 生成图像:定期保存生成结果

评估策略

设置定期评估,监控模型性能:

validation:
  target: sgm.validation.Validation
  params:
    validate_every: 5000  # 每5000步评估一次
    n_validation_samples: 100  # 评估样本数量
    save_samples: True  # 保存样本图像
    metrics:
      - target: sgm.metrics.PSNR
      - target: sgm.metrics.SSIM
      - target: sgm.metrics.LPIPS

模型优化与加速

混合精度训练

SUPIR支持混合精度训练,减少显存占用:

training:
  precision: "bf16"  # 可选: fp32, fp16, bf16
  gradient_accumulation_steps: 4  # 梯度累积
  max_grad_norm: 1.0  # 梯度裁剪

显存优化

对于显存有限的情况,可采用以下策略:

  1. 模型并行:将模型拆分到多个GPU
  2. 梯度检查点:牺牲部分计算速度换取显存节省
  3. 低精度优化:使用fp16或bf16数据类型

配置示例:

# 低显存模式训练
python gradio_demo.py --loading_half_params --use_tile_vae --load_8bit_llava

常见问题与解决方案

训练不稳定

问题:Loss波动大,生成图像质量不稳定。 解决方案

  • 降低学习率,从2e-5调整为1e-5
  • 增加梯度裁剪,设置max_grad_norm: 0.5
  • 使用更大的批量大小,或启用梯度累积

显存不足

问题:训练过程中出现"CUDA out of memory"错误。 解决方案

  • 降低批量大小,从16调整为8
  • 启用混合精度训练
  • 使用梯度检查点:use_checkpoint: True

生成图像模糊

问题:输出图像过于模糊,细节丢失。 解决方案

  • 从L2 Loss切换为L1 Loss
  • 增加条件控制强度:s_stage2: 1.2
  • 调整采样参数:减少edm_steps,增加s_cfg

结论与展望

SUPIR通过创新的级联扩散架构和灵活的训练配置,实现了高质量的图像恢复。本文详细介绍了从环境配置、数据准备到模型训练、Loss优化的全流程。未来工作可关注以下方向:

  1. 多模态引导:结合文本和图像指导进行更精细的图像恢复
  2. 模型压缩:减小模型体积,实现实时推理
  3. 领域适应:针对特定场景(如医学影像、遥感图像)优化模型

通过不断优化训练策略和模型架构,SUPIR有望在更多实际应用场景中发挥作用,推动图像恢复技术的进一步发展。

【免费下载链接】SUPIR 【免费下载链接】SUPIR 项目地址: https://gitcode.com/gh_mirrors/su/SUPIR

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值