Stable Diffusion做数据增强?开发者的新武器来了!
Stable Diffusion做数据增强?开发者的新武器来了!
“数据不够?那就让AI自己画!”——某位凌晨三点还在调prompt的算法工程师
当训练数据成了“稀有动物”
做CV的兄弟姐妹都懂,数据永远比idea贵。
老板一句“我要模型精度再涨5个点”,背后往往是标注团队通宵达旦地画框、打点、写标签。更惨的是,有些场景连原始图片都凑不齐:
- 医疗影像里,某种罕见病灶一年才出现几十例;
- 工业产线上,缺陷样本比996的程序员还稀缺;
- 新零售商品库,长尾SKU的货架图只能靠采购小哥手机随手拍——光照、角度、背景全靠缘分。
传统 augmentation 三板斧(旋转、裁剪、颜色抖动)在这些场景下就像用指甲刀砍大树,语义信息没变,但也没增加多少新东西。
直到某天,我盯着 Stable Diffusion 生成的“赛博朋克猫”出神,脑子里突然蹦出一个念头:
既然它能画猫,能不能画“缺陷”?
于是,这篇“血泪踩坑史”就有了开头。
为什么偏偏是 Stable Diffusion?
先别急着抄家伙,生成式模型那么多,凭啥选它?
| 模型 | 可控性 | 开源程度 | 消费级显卡友好度 | 备注 |
|---|---|---|---|---|
| StyleGAN3 | 中 | 高 | 凑合 | 画风偏“艺术”,语义控制需额外网络 |
| DALL·E 2 | 高 | 闭源 | ❌ | API 限速+钱包警告 |
| Midjourney | 高 | 闭源 | ❌ | 付费+不能本地批量 |
| Stable Diffusion | 高 | 完全开源 | RTX 3060 就能 512² 跑 batch | 社区轮子多到用不完 |
一句话:免费、本地、可批量、可微调、社区还卷。
对我们这些**“公司只给预算 0 元”**的开发者来说,它就是天降正义。
把“魔法”拆开:Stable Diffusion 到底干了啥?
“别急着念咒,先搞清楚魔杖是什么木头。”
1. 潜在空间里的“降噪游戏”
Stable Diffusion 把图像压缩到 64×64 的潜在向量(latent space),然后在这块“小画布”上做扩散——前向加噪、反向去噪。
好处?
- 比直接操作像素省显存,512² 图在 8G 显存里能跑 batch=8;
- latent 空间天生带“语义坐标”,文本 embedding 像遥控器,往哪儿走它都听得懂。
2. 提示词 = 遥控器的“按钮组合”
正向 prompt:a photo of cracked phone screen, close-up, industrial inspection, 4K, sharp
负向 prompt:cartoon, painting, lowres, blurry, extra fingers
负向 prompt 是隐藏宝藏:把“不想要的”写进去,比单纯堆正向词更能减少废图。
3. ControlNet:给“画家”一把尺子
纯文本容易“抽卡”,ControlNet 把 Canny 边缘、深度图、语义分割 mask 变成“草图”,让生成结果结构不变、纹理随便换。
做数据增强时,原图边缘图 + 随机 prompt = 同一结构不同外观,完美。
4. LoRA:不煮大锅饭,只开小灶
全量微调 4 GB 模型?老板不给显卡。
LoRA 把权重更新拆成两个小矩阵,训练量降到 1/10,10 张图 10 分钟就能学会“某种裂纹风格”,迁移学习神器。
搭一条“可控”的增强流水线
“没有流程的生成,都是玄学。”
下面给出一条工业界能落地的 Python 流水线,每一步都能打断点 debug,拒绝黑箱。
0. 环境一键复现
# 建议用 conda,别问,问就是省头发
conda create -n sdaug python=3.10
conda activate sdaug
pip install diffusers==0.21.0 transformers accelerate xformers opencv-python safetensors
# 显卡>=16G 可不开 xformers,<16G 建议加上,省显存
1. 原图→边缘图:保留结构
# canny_extract.py
import cv2
import os
def extract_canny(img_path, low=100, high=200, output_size=512):
img = cv2.imread(img_path)
img = cv2.resize(img, (output_size, output_size))
canny = cv2.Canny(img, low, high)
# 扩通道,适配 ControlNet 输入
canny = cv2.cvtColor(canny, cv2.COLOR_GRAY2RGB)
return canny
# 批量处理
os.makedirs("canny_dir", exist_ok=True)
for f in os.listdir("raw_images"):
canny = extract_canny(f"raw_images/{f}")
cv2.imwrite(f"canny_dir/{f}", canny)
小贴士:Canny 阈值别手抖,低阈值太高会把细节弄丢,经验值 50/150 起步,每张图都用同样阈值,保证后续对齐。
2. prompt 模板:把“随机”装进笼子里
# prompt_bank.py
templates = {
"crack": [
"a photo of {defect} on {object}, industrial scene, {lighting}, 4K, sharp, no text",
"close-up shot of {defect} defect, metal surface, {lighting}, realistic, high contrast"
],
"lighting": ["under factory LED light", "natural daylight", "dim warehouse light", "fluorescent tube light"],
"object": ["aluminum panel", "steel plate", "phone screen", "car bumper"]
}
def sample_prompt(defect="crack"):
import random
t = random.choice(templates[defect])
lighting = random.choice(templates["lighting"])
obj = random.choice(templates["object"])
return t.format(defect=defect, lighting=lighting, object=obj)
模板化 = 可复现 + 可单元测试,别小看这一步,后期排查语义漂移全靠它。
3. 图像→图像:把边缘图喂给 Stable Diffusion
# sd_aug.py
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import os
# 1. 加载 ControlNet
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet,
torch_dtype=torch.float16
).to("cuda")
# 2. 内存优化三板斧
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()
pipe.enable_model_cpu_offload() # batch 大时开
# 3. 批量生成
os.makedirs("aug_images", exist_ok=True)
for idx, canny_file in enumerate(os.listdir("canny_dir")):
canny_image = load_image(f"canny_dir/{canny_file}")
prompt = sample_prompt(defect="crack")
negative = "cartoon, painting, lowres, blurry, extra fingers, text, watermark"
out = pipe(
prompt=prompt,
negative_prompt=negative,
image=canny_image,
num_inference_steps=30,
guidance_scale=7.5,
generator=torch.Generator().manual_seed(42 + idx), # 可复现
strength=0.9 # 0~1,越大越偏离原图
).images[0]
out.save(f"aug_images/{idx:04d}.jpg")
strength 参数是灵魂:
- 0.7 以下:基本只是“重新打光”;
- 0.9 左右:结构保留但纹理大换血;
- 1.0:放飞自我,可能把裂纹画成涂鸦。
4. 自动过滤:别让“垃圾”进数据集
# filter.py
from transformers import CLIPProcessor, CLIPModel
import torch, os, json
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def clip_score(image, text):
inputs = proc(text=[text], images=image, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = clip(**inputs)
# 余弦相似度
logits = outputs.logits_per_image
return logits.item()
threshold = 28 # 经验值,按业务调
manifest = []
for imgf in os.listdir("aug_images"):
from PIL import Image
img = Image.open(f"aug_images/{imgf}")
score = clip_score(img, "a photo of cracked phone screen")
if score >= threshold:
manifest.append({"file": imgf, "clip": score})
json.dump(manifest, open("valid_images.json", "w"), ensure_ascii=False, indent=2)
print(f"过滤后剩余 {len(manifest)} 张,淘汰率 {1-len(manifest)/len(os.listdir('aug_images')):.2%}")
CLIP 分数不是圣旨,只能当“初筛”。后续还要让业务分类器回测,看加入这些图后 validation 涨不涨,再决定要不要下调阈值。
三个真实到“掉头发”的落地案例
1. 医疗影像:给罕见病灶“加戏”
背景:某三甲放射科,早期肺结节 CT 只有 87 张阳性,阴性 3000+,模型快把阳性当成“外星人”。
方案:
- 用
nnUNet把原图结节 mask 抠出来→生成深度图→喂 ControlNet; - prompt 模板里加入“-1mm 薄层、低剂量、40 岁患者”等医学关键词,让生成图自带 CT 质感;
- 生成 800 张后,请主任肉眼筛掉 120 张“不像病灶”的废图(医生眼光毒辣,一眼看穿伪影)。
结果:
- 召回率从 0.61 → 0.78,假阳性降 35%;
- 论文投 MICCAI,审稿人唯一质疑:“伦理批准呢?”——生成数据也要补伦理批件,别踩坑。
2. 工业质检:把“缺陷”搬到不同产线
背景:手机盖板玻璃,裂纹样本 214 张,客户要求识别 5 条产线、3 种光照、4 种角度共 60 种工况。
方案:
- 用 Blender 批量渲染虚拟边缘图(裂纹形状固定,角度/光照随意调),一天造 2W 张边缘图;
- 边缘图 + ControlNet 生成真实纹理,prompt 里随机“factory LED/natural light/fluorescent”;
- 训练前用传统增强 + 生成图 1:1 混合,防止过拟合“AI 画风”。
结果:
- 上线后客户现场采集 1000 张实测,精度 96.4% → 98.1%;
- 客户爸爸一句“效果不错”,项目经理终于不用通宵陪产线了。
3. 零售 SKU:让长尾商品“摆满货架”
背景:新零售货架审核,长尾 SKU 只有 1 张官网渲染图,现场拍到的角度、光照、遮挡千奇百怪,模型直接“脸盲”。
方案:
- 官网图抠透明 PNG→生成 360° 旋转 8 角度→深度图;
- ControlNet 生成货架背景,prompt 里随机“supermarket/warehouse/convenience store”;
- 用复制-粘贴+泊松融合把商品贴回实景,再跑一遍 CLIP 过滤,去掉“悬浮”伪影。
结果:
- 单 SKU 从 1 张扩到 120 张,现场实测召回提升 27%;
- 运营小姐姐终于不用再“求爷爷告奶奶”地找门店拍照。
生成数据不是“仙丹”:评估与清洗
“AI 生成的图,自己都不一定相信。”
1. 人工定性:三张表搞定
| 维度 | 评分 1~5 | 说明 |
|---|---|---|
| 语义一致性 | 裂纹还是划痕? | 业务专家最懂 |
| 外观自然度 | 有无伪影? | 肉眼即可 |
| 多样性 | 10 张图是否雷同? | 打眼一看就知 |
经验:找 3 个标注员,每人 100 张,Kappa>0.75 才算一致,否则重新培训。
2. 自动定量:CLIP + FID + 分类器回测
# fid_score.py
# 先 pip install pytorch-fid
import subprocess
real_path = "raw_images"
fake_path = "aug_images"
subprocess.run(f"python -m pytorch_fid {real_path} {fake_path}", shell=True)
FID<50 基本可用,<30 属于“以假乱真”,>100 直接回炉。
3. 分类器回测:一图胜千言
- 基线模型:只用真实图;
- 实验组:真实+生成(比例 1:1、1:2、2:1 都试);
- 观察 validation 曲线:
- 涨点 → 继续加;
- 掉点 → 过滤 or 降权;
- 过拟合 → 早停 + 正则。
踩坑实录:我们被 Stable Diffusion 坑过的 7 个大坑
| 坑 | 症状 | 排查 | 缓解 |
|---|---|---|---|
| 语义漂移 | 要“裂纹”却出“划痕” | prompt 模板太泛 | 加“close-up, macro”+ 负向“scratch” |
| 模式崩溃 | 100 张图一个姿势 | seed 区间太小 | 把 seed 打散 + strength 0.7~0.9 随机 |
| 类别混淆 | 把“污渍”当成“裂纹” | CLIP 分数虚高 | 加分类器二筛,把置信度<0.8 的踢掉 |
| 显存爆炸 | 凌晨 3 点 OOM | batch=8 太大 | enable_model_cpu_offload + batch=2 |
| 色差灾难 | 生成图偏蓝 | 原图白平衡被改 | 生成后统一做颜色校正(cv2.COLOR_BGR2LAB + 直方图匹配) |
| 伪影“油画” | 纹理像糊墙 | steps<20 | 至少 30 步,开 xformers 也不省这一步 |
| 伦理踩雷 | 生成病人 CT 被质疑 | 未脱敏 | 所有医疗数据先脱敏+伦理批件,生成图也要走审查 |
效率与质量双赢的“猥琐技巧”
1. LoRA 10 分钟速成
# 准备 20 张目标风格图,统一 512²
accelerate launch train_network.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
--dataset_config="toml/裂纹.toml" \
--output_dir="lora/crack_style" \
--network_module=networks.lora \
--network_dim=32 \
--max_train_epochs=10 \
--learning_rate=1e-4 \
--unet_lr=1e-4 \
--text_encoder_lr=1e-5
dim=32 足够,再大就是浪费显卡。
训练完拿到crack_style.safetensors,inference 时一句--network_module就能加载,比全量微调省 90% 时间。
2. 批量生成显存“三件套”
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()
pipe.enable_model_cpu_offload() # batch>4 必开
亲测 RTX 3060 12G,512² 开 batch=6 稳如老狗,再往上就 OOM。
3. OpenCV 后处理去伪影
# denoise.py
import cv2
def remove_oil_painting_artifact(img_path):
img = cv2.imread(img_path)
# 快速去噪同时保留边缘
dst = cv2.edgePreservingFilter(img, flags=1, sigma_s=60, sigma_r=0.4)
# 轻微锐化
kernel = np.array([[0,-1,0], [-1,5,-1], [0,-1,0]], np.float32)
dst = cv2.filter2D(dst, -1, kernel=kernel)
cv2.imwrite(img_path, dst)
后处理跑 1000 张图只要 30 秒,FID 能再降 3~5 个点,肉眼可见更“真实”。
4. 生成-验证-反馈闭环
# loop.py
while True:
generate() # 生成
filter_clip() # 自动筛
retrain() # 重训模型
eval() # 看指标
if metric > target:
break
else:
update_prompt() # 把 false positive 样例反向写进负向 prompt
把“踩坑”自动写进模板,越跑越聪明,这才是“可持续卷”。
清醒一点:生成数据不是万能药
- 真实数据永远是大爷,生成数据只能当“补剂”;
- 比例别超过 1:1,除非你想看模型在真实场景“翻车”;
- 生成图分布再真,也缺“长尾噪声”,建议搭配主动学习——把模型最难分辨的样本送回人工标注,形成“真实→生成→真实”的飞轮。
尾声:下一个数据瓶颈,让 AI 自己“画”过去
Stable Diffusion 不是玩具,它是程序员手里的“数据 3D 打印机”。
当你再一次为“样本不足”而抓狂时,不妨打开 IDE,写几行 prompt,让 GPU 替你去“拍照”。
也许明天早上,你会收获一文件夹“新鲜出炉”的裂纹、病灶、SKU,而它们的名字,叫 hope_001.jpg ~ hope_999.jpg。
程序员不是魔法师,但我们可以用代码,把“不可能”变成“再跑一轮 epoch”。
祝你下一个 epoch,数据管够,显存不爆,老板不催,论文接收。
生成愉快,晚安。

1032

被折叠的 条评论
为什么被折叠?



