VmambaIR 代码复现

1、环境

conda create -n vmambair python=3.9
conda activate vmambair
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
cd kernels/selective_scan
pip install .
cd Deraining
bash pip.sh
pip install timm fvcore

2、trian

创建文件夹 Options

创建文件 Deraining.yml

# general settings
name: Deraining_Restormer  # 项目名称或配置名称
model_type: ImageCleanModel  # 模型类型,这里是图像清理模型
scale: 1  # 缩放比例(通常用于超分辨率任务,这里是1,表示不缩放)
num_gpu: 1  # 训练时使用的GPU数量,设置为0表示CPU模式
manual_seed: 100  # 随机种子,用于保证实验的可复现性

# dataset and data loader settings
datasets:
  train:
    name: TrainSet  # 训练集的名称
    type: Dataset_PairedImage  # 数据集类型,这里是配对图像数据集
    dataroot_gt: /root/autodl-tmp/VmambaIR-main/Deraining/Dataset/train/B  # 高质量(Ground Truth)图像的路径
    dataroot_lq: /root/autodl-tmp/VmambaIR-main/Deraining/Dataset/train/A  # 低质量(输入)图像的路径
    geometric_augs: true  # 是否开启几何增强(如旋转、翻转等)

#    filename_tmpl: '{}'  # 文件名模板,用于匹配图像文件名
    io_backend:
      type: disk  # 输入输出后端类型,这里是磁盘存储

    # data loader
    use_shuffle: true  # 是否在每个epoch随机打乱数据
    num_worker_per_gpu: 4  # 每个GPU的Dataloader线程数
    batch_size_per_gpu: 2  # 每个GPU的批量大小

#    ### -------------Progressive training--------------------------
#    mini_batch_sizes: [8,5,4,2,1,1]  # 进行渐进式训练时,每个阶段的批量大小
#    iters: [92000,64000,48000,36000,36000,24000]  # 进行渐进式训练时,每个阶段的迭代次数
#    gt_size: 384  # 渐进式训练中的最大Patch大小
#    gt_sizes: [128,160,192,256,320,384]  # 渐进式训练中各个阶段的Patch大小
#    ### ------------------------------------------------------------

    ## ------- Training on single fixed-patch size 128x128---------
#     如果需要固定Patch大小,可以取消下面的注释
    mini_batch_sizes: [8]
    iters: [300000]
    gt_size: 256
    gt_sizes: [256]
    ## ------------------------------------------------------------

    dataset_enlarge_ratio: 1  # 数据集扩大的倍数
    prefetch_mode: ~  # 预取模式(暂未配置)

  val:
    name: ValSet  # 验证集的名称
    type: Dataset_PairedImage  # 数据集类型
    dataroot_gt: /root/autodl-tmp/VmambaIR-main/Deraining/Dataset/val/B  # 验证集的高质量图像路径
    dataroot_lq: /root/autodl-tmp/VmambaIR-main/Deraining/Dataset/val/A  # 验证集的低质量图像路径
    io_backend:
      type: disk  # 输入输出后端类型

# network structures
network_g:
  type: Restormer  # 网络类型,这里是Restormer
  inp_channels: 3  # 输入通道数
  out_channels: 3  # 输出通道数
  dim: 48  # 网络的维度
  num_blocks: [4,6,6,8]  # 各个残差块的数量
  num_refinement_blocks: 4  # 用于精炼的残差块数量
  heads: [1,2,4,8]  # 注意力头的数量
  ffn_expansion_factor: 2.66  # 前馈网络的扩展因子
  bias: False  # 是否使用偏置
  LayerNorm_type: WithBias  # 使用的LayerNorm类型
  dual_pixel_task: False  # 是否开启双像素任务模式


# path
path:
  pretrain_network_g: ~  # 预训练模型路径(暂未配置)
  strict_load_g: true  # 加载预训练模型时是否严格匹配网络结构
  resume_state: ~  # 恢复训练状态的路径(暂未配置)

# training settings
train:
  total_iter: 300  # 总训练迭代次数
  warmup_iter: -1  # 预热迭代次数(-1表示无预热)
  use_grad_clip: true  # 是否启用梯度裁剪

  # Split 300k iterations into two cycles.
  # 1st cycle: fixed 3e-4 LR for 92k iters.
  # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
  scheduler:
    type: CosineAnnealingRestartCyclicLR  # 学习率调度器类型,这里是余弦退火重启周期学习率
    periods: [92000, 208000]  # 每个周期的迭代次数
    restart_weights: [1,1]  # 每个周期的重启权重
    eta_mins: [0.0003,0.000001]  # 每个周期的最小学习率

  mixing_augs:
    mixup: false  # 是否启用Mixup数据增强
    mixup_beta: 1.2  # Mixup数据增强的Beta分布参数
    use_identity: true  # 是否使用恒等混合(Mixup的一种变体)

  optim_g:
    type: AdamW  # 优化器类型,这里是AdamW
    lr: !!float 3e-4  # 初始学习率
    weight_decay: !!float 1e-4  # 权重衰减
    betas: [0.9, 0.999]  # Adam优化器的Beta参数

  # losses
  pixel_opt:
    type: L1Loss  # 像素级损失类型,这里是L1损失
    loss_weight: 1  # 损失的权重系数
    reduction: mean  # 损失的缩减方式

# validation settings
val:
  window_size: 8  # 推理时的窗口大小
  val_freq: !!float 4e3  # 验证频率(验证间隔次数)
  save_img: false  # 是否保存验证图像
  rgb2bgr: true  # 是否将RGB图像转换为BGR格式
  use_image: true  # 是否使用图像进行验证
  max_minibatch: 8  # 验证时的最大Mini-Batch大小

  metrics:
    psnr:  # PSNR(峰值信噪比)指标
      type: calculate_psnr  # PSNR计算方式
      crop_border: 0  # 裁剪边界大小
      test_y_channel: true  # 是否仅测试Y通道


# logging settings
logger:
  print_freq: 1000  # 训练日志打印频率
  save_checkpoint_freq: !!float 4e3  # 检查点保存频率(每隔多少次迭代保存一次模型)
  use_tb_logger: true  # 是否使用TensorBoard记录日志
  wandb:
    project: ~  # Weights & Biases项目名称(暂未配置)
    resume_id: ~  # 恢复训练的W&B ID(暂未配置)

# dist training settings
dist_params:
  backend: nccl  # 分布式训练的后端类型
  port: 29500  # 分布式训练的端口号

启动命令:

python basicsr/train.py -opt /root/autodl-tmp/VmambaIR-main/Deraining/Options/Deraining.yml

 运行成功:

3、test

python test.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值