LaMa训练数据生成:随机掩码与真实场景的结合
引言:数据生成在图像修复中的关键作用
图像修复(Image Inpainting)技术的性能高度依赖训练数据的质量与多样性。LaMa(Large Mask Inpainting with Fourier Convolutions)作为2022年WACV提出的先进模型,其核心优势在于对大尺寸掩码的鲁棒处理能力。然而,这种能力的实现离不开精心设计的训练数据生成策略——随机掩码与真实场景数据的有机结合。本文将系统剖析LaMa的数据生成 pipeline,从掩码生成机制、真实场景数据处理到混合训练策略,全方位展示如何构建高效的图像修复训练数据集。
随机掩码生成:从参数化设计到多样化覆盖
1. 掩码生成器架构
LaMa采用模块化的掩码生成架构,通过MixedMaskGenerator类实现多种掩码类型的随机组合。其核心设计思想是:通过概率分布控制不同掩码类型的生成比例,模拟真实世界中可能出现的各种破损情况。
class MixedMaskGenerator:
def __init__(self, irregular_proba=1, box_proba=0.3, segm_proba=0, ...):
self.probas = np.array([irregular_proba, box_proba, segm_proba, ...])
self.probas /= self.probas.sum() # 归一化概率分布
def __call__(self, img):
kind = np.random.choice(len(self.probas), p=self.probas)
return self.gens[kind](img) # 根据概率选择掩码生成器
2. 三种基础随机掩码类型
LaMa定义了细(Thin)、中(Medium)、粗(Thick) 三种掩码类型,通过YAML配置文件精确控制生成参数。以下是256x256尺寸下的核心参数对比:
| 掩码类型 | 线条数量 (min/max) | 最大宽度 | 最大长度 | 矩形概率 | 面积占比 |
|---|---|---|---|---|---|
| 细 (Thin) | 4-50 | 10px | 40px | 0% | ≤50% |
| 中 (Medium) | 4-5 | 50px | 100px | 30% | ≤50% |
| 粗 (Thick) | 1-5 | 100px | 200px | 30% | ≤50% |
细掩码配置示例(random_thin_256.yaml):
generator_kind: random
mask_generator_kwargs:
irregular_proba: 1 # 仅生成不规则线条
irregular_kwargs:
min_times: 4 # 最少4条线
max_times: 50 # 最多50条线
max_width: 10 # 线条最大宽度10px
max_len: 40 # 线条最大长度40px
box_proba: 0 # 不生成矩形
max_tamper_area: 0.5 # 最大遮挡面积50%
3. 掩码生成算法实现
LaMa通过saicinpainting/training/data/masks.py实现了多种掩码生成器,核心逻辑如下:
class MixedMaskGenerator:
def __init__(self, irregular_proba=1, box_proba=0.3, segm_proba=0):
self.probas = np.array([irregular_proba, box_proba, segm_proba])
self.probas /= self.probas.sum()
self.gens = [
RandomIrregularMaskGenerator(), # 不规则线条
RandomRectangleMaskGenerator(), # 矩形框
RandomSegmentationMaskGenerator() # 分割掩码
]
def __call__(self, img):
# 按概率选择掩码类型
kind = np.random.choice(len(self.probas), p=self.probas)
return self.gens[kind](img)
不规则线条生成核心代码:
def make_random_irregular_mask(shape, max_angle=4, max_len=40, max_width=10):
mask = np.zeros(shape, np.float32)
for _ in range(np.random.randint(4, 50)): # 随机线条数量
angle = np.random.randint(max_angle) # 随机角度
length = np.random.randint(max_len) # 随机长度
brush_w = np.random.randint(max_width)# 随机宽度
# 随机起点和终点
start_x, start_y = np.random.randint(0, shape[1]), np.random.randint(0, shape[0])
end_x = np.clip(start_x + length * np.sin(angle), 0, shape[1])
end_y = np.clip(start_y + length * np.cos(angle), 0, shape[0])
cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
return mask[None, ...] # 添加通道维度
真实场景数据处理:从原始图像到训练样本
1. 数据集准备流程
LaMa支持CelebA-HQ和Places等真实场景数据集,通过fetch_data目录下的脚本自动化处理流程:
CelebA-HQ处理示例(celebahq_dataset_prepare.sh):
# 创建目录并解压
mkdir celeba-hq-dataset
unzip data256x256.zip -d celeba-hq-dataset/
# 重命名文件(00001.jpg → 0.jpg)
for i in `echo {00001..30000}`; do
mv "celeba-hq-dataset/data256x256/$i.jpg" "celeba-hq-dataset/data256x256/$[10#$i - 1].jpg"
done
# 划分训练/验证集
cat fetch_data/train_shuffled.flist | shuf > temp_train.flist
head -n 2000 temp_train.flist > val_shuffled.flist # 2000张验证集
tail -n +2001 temp_train.flist > train_shuffled.flist # 剩余训练集
2. Places数据集掩码生成
Places数据集通过gen_mask_dataset.py批量生成掩码,支持多尺度掩码类型:
# 生成512x512粗掩码
python3 bin/gen_mask_dataset.py \
configs/data_gen/random_thick_512.yaml \
places_standard_dataset/evaluation/hires \
places_standard_dataset/evaluation/random_thick_512/
生成流程包含三个关键步骤:
- 图像加载:读取高分辨率原始图像
- 掩码生成:根据配置文件参数生成随机掩码
- 数据对保存:存储原始图像-掩码对供训练使用
数据混合策略:随机与真实的协同训练
1. 训练数据配置
LaMa通过YAML配置文件定义数据加载和混合策略,典型配置如configs/training/data/abl-04-256-mh-dist.yaml:
batch_size: 10
val_batch_size: 2
num_workers: 3
train:
indir: ${location.data_root_dir}/train
out_size: 256 # 输出尺寸256x256
mask_gen_kwargs:
irregular_proba: 1 # 不规则线条概率100%
irregular_kwargs:
max_len: 200 # 最大线条长度200px
box_proba: 1 # 矩形概率100%
box_kwargs:
bbox_min_size: 30 # 矩形最小尺寸30px
max_times: 4 # 最多4个矩形
transform_variant: distortions # 应用畸变数据增强
2. 数据加载流水线
训练时的数据加载流程如下:
关键增强策略:
- 随机裁剪:确保输入尺寸一致
- 色彩抖动:亮度、对比度随机调整
- 水平翻转:50%概率水平翻转
- 掩码混合:随机选择不规则线条或矩形掩码
3. 分割掩码融合
对于真实场景掩码,LaMa使用基于Detectron2的实例分割生成器:
class RandomSegmentationMaskGenerator:
def __init__(self, confidence_threshold=0.5):
self.cfg = get_cfg()
self.cfg.merge_from_file(model_zoo.get_config_file(
"COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
self.predictor = DefaultPredictor(self.cfg)
def __call__(self, img):
# 运行全景分割
panoptic_seg, segments_info = self.predictor(img)["panoptic_seg"]
# 提取实例掩码
masks = [self._get_mask(panoptic_seg, seg["id"])
for seg in segments_info if seg["isthing"]]
return np.random.choice(masks) # 随机选择一个实例掩码
工程实践:从配置到部署
1. 数据生成命令行工具
LaMa提供gen_mask_dataset.py工具批量生成训练数据:
# 基本用法
python3 bin/gen_mask_dataset.py \
<配置文件路径> \
<输入图像目录> \
<输出目录> \
[--num-workers NUM_WORKERS] # 并行处理进程数
2. 目录结构规范
推荐的数据集目录结构:
places_standard_dataset/
├── evaluation/
│ ├── hires/ # 原始高分辨率图像
│ ├── random_thin_256/ # 256细掩码
│ ├── random_medium_256/ # 256中掩码
│ └── random_thick_256/ # 256粗掩码
└── train/ # 训练集
├── images/ # 训练图像
└── masks/ # 对应掩码
总结与展望
LaMa通过参数化随机掩码生成、真实场景数据处理和灵活的混合策略,构建了高效的图像修复训练数据生成 pipeline。这种方法的核心优势在于:
- 多样性:支持细/中/粗多类型掩码,覆盖不同破损场景
- 真实性:结合CelebA-HQ/Places等真实数据集
- 灵活性:通过配置文件轻松调整掩码参数和数据分布
未来可探索的改进方向包括:
- 基于语义的智能掩码生成
- 动态难度调整的掩码策略
- 跨数据集的迁移学习能力
通过本文介绍的方法,开发者可以构建高质量的图像修复训练数据,为模型性能优化奠定基础。
收藏与关注:本文代码示例和配置文件均来自LaMa官方仓库,完整实现可参考:
git clone https://gitcode.com/GitHub_Trending/la/lama
下期预告:LaMa模型架构深度解析——傅里叶卷积的图像修复应用
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



