LaMa训练数据生成:随机掩码与真实场景的结合

LaMa训练数据生成:随机掩码与真实场景的结合

【免费下载链接】lama 🦙 LaMa Image Inpainting, Resolution-robust Large Mask Inpainting with Fourier Convolutions, WACV 2022 【免费下载链接】lama 项目地址: https://gitcode.com/GitHub_Trending/la/lama

引言:数据生成在图像修复中的关键作用

图像修复(Image Inpainting)技术的性能高度依赖训练数据的质量与多样性。LaMa(Large Mask Inpainting with Fourier Convolutions)作为2022年WACV提出的先进模型,其核心优势在于对大尺寸掩码的鲁棒处理能力。然而,这种能力的实现离不开精心设计的训练数据生成策略——随机掩码与真实场景数据的有机结合。本文将系统剖析LaMa的数据生成 pipeline,从掩码生成机制、真实场景数据处理到混合训练策略,全方位展示如何构建高效的图像修复训练数据集。

随机掩码生成:从参数化设计到多样化覆盖

1. 掩码生成器架构

LaMa采用模块化的掩码生成架构,通过MixedMaskGenerator类实现多种掩码类型的随机组合。其核心设计思想是:通过概率分布控制不同掩码类型的生成比例,模拟真实世界中可能出现的各种破损情况。

class MixedMaskGenerator:
    def __init__(self, irregular_proba=1, box_proba=0.3, segm_proba=0, ...):
        self.probas = np.array([irregular_proba, box_proba, segm_proba, ...])
        self.probas /= self.probas.sum()  # 归一化概率分布
        
    def __call__(self, img):
        kind = np.random.choice(len(self.probas), p=self.probas)
        return self.gens[kind](img)  # 根据概率选择掩码生成器

2. 三种基础随机掩码类型

LaMa定义了细(Thin)、中(Medium)、粗(Thick) 三种掩码类型,通过YAML配置文件精确控制生成参数。以下是256x256尺寸下的核心参数对比:

掩码类型线条数量 (min/max)最大宽度最大长度矩形概率面积占比
细 (Thin)4-5010px40px0%≤50%
中 (Medium)4-550px100px30%≤50%
粗 (Thick)1-5100px200px30%≤50%

细掩码配置示例(random_thin_256.yaml):

generator_kind: random
mask_generator_kwargs:
  irregular_proba: 1  # 仅生成不规则线条
  irregular_kwargs:
    min_times: 4       # 最少4条线
    max_times: 50      # 最多50条线
    max_width: 10      # 线条最大宽度10px
    max_len: 40        # 线条最大长度40px
  box_proba: 0         # 不生成矩形
max_tamper_area: 0.5   # 最大遮挡面积50%

3. 掩码生成算法实现

LaMa通过saicinpainting/training/data/masks.py实现了多种掩码生成器,核心逻辑如下:

class MixedMaskGenerator:
    def __init__(self, irregular_proba=1, box_proba=0.3, segm_proba=0):
        self.probas = np.array([irregular_proba, box_proba, segm_proba])
        self.probas /= self.probas.sum()
        self.gens = [
            RandomIrregularMaskGenerator(),  # 不规则线条
            RandomRectangleMaskGenerator(),  # 矩形框
            RandomSegmentationMaskGenerator() # 分割掩码
        ]
    
    def __call__(self, img):
        # 按概率选择掩码类型
        kind = np.random.choice(len(self.probas), p=self.probas)
        return self.gens[kind](img)

不规则线条生成核心代码

def make_random_irregular_mask(shape, max_angle=4, max_len=40, max_width=10):
    mask = np.zeros(shape, np.float32)
    for _ in range(np.random.randint(4, 50)):  # 随机线条数量
        angle = np.random.randint(max_angle)  # 随机角度
        length = np.random.randint(max_len)   # 随机长度
        brush_w = np.random.randint(max_width)# 随机宽度
        # 随机起点和终点
        start_x, start_y = np.random.randint(0, shape[1]), np.random.randint(0, shape[0])
        end_x = np.clip(start_x + length * np.sin(angle), 0, shape[1])
        end_y = np.clip(start_y + length * np.cos(angle), 0, shape[0])
        cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
    return mask[None, ...]  # 添加通道维度

真实场景数据处理:从原始图像到训练样本

1. 数据集准备流程

LaMa支持CelebA-HQ和Places等真实场景数据集,通过fetch_data目录下的脚本自动化处理流程:

CelebA-HQ处理示例(celebahq_dataset_prepare.sh):

# 创建目录并解压
mkdir celeba-hq-dataset
unzip data256x256.zip -d celeba-hq-dataset/

# 重命名文件(00001.jpg → 0.jpg)
for i in `echo {00001..30000}`; do
    mv "celeba-hq-dataset/data256x256/$i.jpg" "celeba-hq-dataset/data256x256/$[10#$i - 1].jpg"
done

# 划分训练/验证集
cat fetch_data/train_shuffled.flist | shuf > temp_train.flist
head -n 2000 temp_train.flist > val_shuffled.flist  # 2000张验证集
tail -n +2001 temp_train.flist > train_shuffled.flist  # 剩余训练集

2. Places数据集掩码生成

Places数据集通过gen_mask_dataset.py批量生成掩码,支持多尺度掩码类型:

# 生成512x512粗掩码
python3 bin/gen_mask_dataset.py \
  configs/data_gen/random_thick_512.yaml \
  places_standard_dataset/evaluation/hires \
  places_standard_dataset/evaluation/random_thick_512/

生成流程包含三个关键步骤:

  1. 图像加载:读取高分辨率原始图像
  2. 掩码生成:根据配置文件参数生成随机掩码
  3. 数据对保存:存储原始图像-掩码对供训练使用

数据混合策略:随机与真实的协同训练

1. 训练数据配置

LaMa通过YAML配置文件定义数据加载和混合策略,典型配置如configs/training/data/abl-04-256-mh-dist.yaml

batch_size: 10
val_batch_size: 2
num_workers: 3

train:
  indir: ${location.data_root_dir}/train
  out_size: 256  # 输出尺寸256x256
  mask_gen_kwargs:
    irregular_proba: 1  # 不规则线条概率100%
    irregular_kwargs:
      max_len: 200       # 最大线条长度200px
    box_proba: 1         # 矩形概率100%
    box_kwargs:
      bbox_min_size: 30  # 矩形最小尺寸30px
      max_times: 4       # 最多4个矩形
  transform_variant: distortions  # 应用畸变数据增强

2. 数据加载流水线

训练时的数据加载流程如下: mermaid

关键增强策略

  • 随机裁剪:确保输入尺寸一致
  • 色彩抖动:亮度、对比度随机调整
  • 水平翻转:50%概率水平翻转
  • 掩码混合:随机选择不规则线条或矩形掩码

3. 分割掩码融合

对于真实场景掩码,LaMa使用基于Detectron2的实例分割生成器:

class RandomSegmentationMaskGenerator:
    def __init__(self, confidence_threshold=0.5):
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file(
            "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
        self.predictor = DefaultPredictor(self.cfg)
    
    def __call__(self, img):
        # 运行全景分割
        panoptic_seg, segments_info = self.predictor(img)["panoptic_seg"]
        # 提取实例掩码
        masks = [self._get_mask(panoptic_seg, seg["id"]) 
                 for seg in segments_info if seg["isthing"]]
        return np.random.choice(masks)  # 随机选择一个实例掩码

工程实践:从配置到部署

1. 数据生成命令行工具

LaMa提供gen_mask_dataset.py工具批量生成训练数据:

# 基本用法
python3 bin/gen_mask_dataset.py \
  <配置文件路径> \
  <输入图像目录> \
  <输出目录> \
  [--num-workers NUM_WORKERS]  # 并行处理进程数

2. 目录结构规范

推荐的数据集目录结构:

places_standard_dataset/
├── evaluation/
│   ├── hires/              # 原始高分辨率图像
│   ├── random_thin_256/    # 256细掩码
│   ├── random_medium_256/  # 256中掩码
│   └── random_thick_256/   # 256粗掩码
└── train/                  # 训练集
    ├── images/             # 训练图像
    └── masks/              # 对应掩码

总结与展望

LaMa通过参数化随机掩码生成真实场景数据处理灵活的混合策略,构建了高效的图像修复训练数据生成 pipeline。这种方法的核心优势在于:

  1. 多样性:支持细/中/粗多类型掩码,覆盖不同破损场景
  2. 真实性:结合CelebA-HQ/Places等真实数据集
  3. 灵活性:通过配置文件轻松调整掩码参数和数据分布

未来可探索的改进方向包括:

  • 基于语义的智能掩码生成
  • 动态难度调整的掩码策略
  • 跨数据集的迁移学习能力

通过本文介绍的方法,开发者可以构建高质量的图像修复训练数据,为模型性能优化奠定基础。


收藏与关注:本文代码示例和配置文件均来自LaMa官方仓库,完整实现可参考:
git clone https://gitcode.com/GitHub_Trending/la/lama

下期预告:LaMa模型架构深度解析——傅里叶卷积的图像修复应用

【免费下载链接】lama 🦙 LaMa Image Inpainting, Resolution-robust Large Mask Inpainting with Fourier Convolutions, WACV 2022 【免费下载链接】lama 项目地址: https://gitcode.com/GitHub_Trending/la/lama

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值