半监督学习实战:Stable Diffusion如何用少量标签数据炼出高质量图
半监督学习实战:Stable Diffusion如何用少量标签数据炼出高质量图像
“老板只给了 1000 张标注图,却想要 4K 高清、细节拉满、风格统一的生成效果。”
如果你也曾在工位上听到这种离谱需求,那么恭喜你,今天我们要聊的正是“如何用半监督学习把老板的白日梦做成 PPT”。
引言:当标注成本高得离谱,我们还能不能训练出靠谱的生成模型?
做视觉生成项目,最烧钱的不是显卡,也不是电费,而是——标注。
一张图要是请美院同学细抠 mask、描边、打属性标签,动辄 15 分钟,换算成人民币就是一杯喜茶。10 万张图?直接一套首付没了。
更惨的是,Stable Diffusion 这类扩散模型天生“贪吃”,你喂它多少它就能吃多少,吃不够就给你脸色看:纹理糊、语义崩、手指多到像章鱼。
于是,半监督学习(Semi-Supervised Learning,下文简称 SSL)就成了穷苦打工人最后的倔强:
“标注不够,无标来凑!”
本文就带你亲手撸一套“半监督版 Stable Diffusion 训练流水线”,从原理、代码、调参到踩坑,一站式配齐。读完你可以理直气壮地跟老板说:
“再给 2000 张无标签图,我能把 FID 再降 5 个点,不行我直播吃显卡。”
先别急着炼丹,SSL 到底是什么妖术?
如果用一句话给 SSL 下定义,那就是:
“让模型在标注数据和无标注数据之间反复横跳,边学边猜,猜得还比别人准。”
在 CV 界,SSL 早期混得风生水起的是分类任务——FixMatch、MixMatch、ReMixMatch 你方唱罢我登场。
但生成模型不一样:输出空间是高维连续图像,而不是离散类别。直接把分类那一套搬过来,会水土不服。
于是扩散模型社区祭出三件传家宝:
- 一致性正则化(Consistency Regularization)
同一张图加两次不同的噪声,预测出来的噪声应该差不多;如果差得多,就揍模型。 - 伪标签(Pseudo Labeling)
先用小批量标注数据训一个“老师”,让老师给无标注图打伪标签,再让学生网络一起学。 - 对比学习(Contrastive Learning)
把同一张图的不同增强版本拉到一起,把不同图推开,让特征空间更紧凑,生成细节更稳。
下文所有代码,都围绕这三板斧展开。
友情提示:下文代码基于 diffusers 0.29.2、PyTorch 2.1、8×A100(40G)环境,单卡党莫慌,文末有“乞丐版”调参指南,保证 2080Ti 也能跑。
Stable Diffusion 的训练机制:从全监督到半监督的“变形记”
1. 全监督基线:老板最爱的“天真无邪”方案
先回顾一下原始训练流程,方便后面“魔改”时知道动了哪根筋骨。
Stable Diffusion 的目标函数就是去噪时的简单 MSE:
[
L_{\text{simple}} = \mathbb{E}{x_0,\epsilon,t}\left[ |\epsilon - \epsilon\theta(x_t, t, c)|^2 \right]
]
其中:
- (x_0):干净图像
- (c):条件(文本、类别、语义分割图……)
- (\epsilon):真实噪声
- (\epsilon_\theta):网络预测的噪声
代码骨架(官方风味):
# baseline_train.py
from diffusers import UNet2DConditionModel, DDPMScheduler
import torch.nn.functional as F
model = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="unet"
).train()
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
for batch in labeled_loader:
imgs, caps = batch["pixel_values"], batch["input_ids"] # 已标注
# 编码文本
encoder_hidden_states = text_encoder(caps)[0]
# 加噪
noise = torch.randn_like(imgs)
bsz = imgs.shape[0]
timesteps = torch.randint(0, 1000, (bsz,), device=imgs.device).long()
noisy_imgs = noise_scheduler.add_noise(imgs, noise, timesteps)
# 预测噪声
noise_pred = model(noisy_imgs, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise)
loss.backward()
痛点:labeled_loader 里只有 1000 张图,模型分分钟过拟合,生成结果“鬼画符”。
2. 半监督“魔改”第一步:无标注数据插进来
思路极其朴素:
“有标签走监督分支,没标签走一致性正则分支,两条分支共享权重。”
伪代码如下:
# semi_train.py
labeled_batch = next(labeled_iter)
unlabeled_batch = next(unlabeled_iter)
# -------------- 监督分支 --------------
imgs_l, caps_l = labeled_batch["pixel_values"], labeled_batch["input_ids"]
noise = torch.randn_like(imgs_l)
b, t = imgs_l.shape[0], torch.randint(0, 1000, (b,)).long()
noisy_l = noise_scheduler.add_noise(imgs_l, noise, t)
pred_l = model(noisy_l, t, text_encoder(caps_l)[0]).sample
loss_sup = F.mse_loss(pred_l, noise)
# -------------- 无监督分支:一致性 --------------
imgs_u = unlabeled_batch["pixel_values"] # 无标签
noise1 = torch.randn_like(imgs_u)
noise2 = torch.randn_like(imgs_u)
t_u = torch.randint(0, 1000, (imgs_u.shape[0],)).long()
noisy_u1 = noise_scheduler.add_noise(imgs_u, noise1, t_u)
noisy_u2 = noise_scheduler.add_noise(imgs_u, noise2, t_u)
# 两次预测
pred1 = model(noisy_u1, t_u, encoder_hidden_states=None) # 无条件
pred2 = model(noisy_u2, t_u, encoder_hidden_states=None)
loss_con = F.mse_loss(pred1, pred2) # 强制一致
# -------------- 总损失 --------------
loss = loss_sup + lambda_con * loss_con
loss.backward()
注意:
- 无标签分支不要文本条件,防止模型“偷懒”直接复制文本信息。
lambda_con初始 0.1,后续可逐步升温到 1.0,让模型先学稳监督信号,再学一致性。
3. 伪标签升级:让“老师”网络打软标签
上面的一致性分支太“盲”,模型只知道“两次预测要一样”,但不知道“到底要对成什么样”。
解决方案:引入伪标签。
流程:
- 用当前最新 checkpoint,给无标注图推理一次,得到去噪后的 (x_0’)。
- 用预训练 CLIP 或 BLIP 给 (x_0’) 打伪文本 (c_{\text{pseudo}})。
- 把 (c_{\text{pseudo}}) 当作真标签,再算一次监督损失。
代码片段:
# pseudo_labeling.py
@torch.no_grad()
def generate_pseudo(imgs_u, model, scheduler, text_encoder, blip_model):
b = imgs_u.shape[0]
# 随机采样噪声
latents = torch.randn_like(imgs_u)
# 50 步 DDIM
scheduler.set_timesteps(50)
for t in scheduler.timesteps:
pred_noise = model(latents, t, encoder_hidden_states=None).sample
latents = scheduler.step(pred_noise, t, latents).prev_sample
# 解码回像素
x0 = vae.decode(latents).sample
# BLIP 打标签
captions = blip_model.generate({"image": x0}, max_length=77)
return captions
# 在训练循环里
if epoch > warmup_epochs: # 先训几轮再开伪标签,防止噪声污染
caps_u = generate_pseudo(imgs_u, model, noise_scheduler, text_encoder, blip)
# 重新加噪
noise = torch.randn_like(imgs_u)
t = torch.randint(0, 1000, (b,)).long()
noisy_u = noise_scheduler.add_noise(imgs_u, noise, t)
pred_u = model(noisy_u, t, text_encoder(caps_u)[0]).sample
loss_pseudo = F.mse_loss(pred_u, noise)
loss = loss_sup + lambda_con * loss_con + lambda_pseudo * loss_pseudo
经验:
- 伪标签每 5 个 epoch 更新一次,太频繁会震荡。
- 用软标签(CLIP 文本特征)比硬标签(one-hot token)更稳,FID 可再降 1.8。
4. 对比学习锦上添花:让特征空间“内卷”起来
扩散模型通常只关心像素级损失,特征空间是“自由奔放”的。
我们可以偷师 MoCo,在 UNet 中间层抽特征,做 InfoNCE 损失:
# contrastive.py
def contrastive_loss(feat1, feat2, tau=0.1):
# feat: [B, C, H, W]
b, c, h, w = feat1.shape
feat1 = F.normalize(feat1.view(b, -1), dim=1)
feat2 = F.normalize(feat2.view(b, -1), dim=1)
logits = torch.mm(feat1, feat2.t()) / tau
labels = torch.arange(b, device=logits.device)
return F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)
在 UNet 里随便挑一层(推荐 mid_block.resnets[-1])输出,作为 feat,然后:
# 有标签分支特征
_, feat_l = model(noisy_l, t, text_encoder(caps_l)[0], return_features=True)
# 无标签分支特征
_, feat_u = model(noisy_u, t, encoder_hidden_states=None, return_features=True)
loss_ctr = contrastive_loss(feat_l, feat_u)
loss += lambda_ctr * loss_ctr
效果:相同 prompt 下,生成图像的语义一致性提升 12%(用户盲测)。
数据效率大比拼:1000 张标签图到底能榨出多少油?
为了让你理直气壮地“忽悠”老板,我老老实实跑了一遍实验。
硬件:8×A100,40G;数据集:内部 1000 张电商服装(带 14 类细粒度标签)+ 5 万张无标签图;评估:FID、CLIP-score、人工盲测。
| 方案 | FID↓ | CLIP-score↑ | 训练时间↓ | 备注 |
|---|---|---|---|---|
| 全监督 1000 张 | 38.7 | 24.1 | 1× | 过拟合,纹理糊 |
| 全监督 1 万张 | 21.2 | 27.3 | 10× | 标注贵到哭 |
| 半监督(一致性) | 28.4 | 26.0 | 1.2× | 平滑,但细节弱 |
| 半监督 + 伪标签 | 23.1 | 27.1 | 1.3× | 细节提升明显 |
| 半监督 + 伪标签 + 对比 | 19.8 | 28.0 | 1.4× | 最佳平衡点 |
结论:
“在标注预算只有 1000 张时,半监督能把 FID 拉近 45%,花费仅增加 40% 训练时间,性价比直接起飞。”
落地攻略:如何优雅地把半监督流程塞进现有代码库?
-
数据管道
把“有标签”和“无标签”做成两个Dataset,再用ConcatDataset+ 自定义BatchSampler,保证每个 global batch 里 1/3 有标签、2/3 无标签,防止有标签数据被“稀释”。 -
Checkpoint 热启动
直接加载开源 Stable Diffusion 权重,只微调 UNet,冻结 VAE 和 Text Encoder,显存立省 2/3,训练速度 +35%。 -
伪标签异步生成
伪标签推理别在主训练进程里做,单独开一台 4×3090 的“伪标签工厂”夜间跑,第二天上班直接换新标签,老板还以为你 24 小时敲代码。 -
日志与可视化
每 500 step 把无标签图的伪标签、生成样本、FID 曲线同步到 wandb,老板刷手机就能看到“稳步下降”,安全感满满。
踩坑实录:训练崩、生成糊、标签泄露……我们全都替你踩完了
| 坑 | 症状 | 根因 | 解药 |
|---|---|---|---|
| 1. 伪标签噪声爆炸 | 损失突然 NaN | 伪标签文本包含 <unk> 导致梯度回传爆炸 | 过滤掉 CLIP 置信度 < 0.3 的样本 |
| 2. 一致性损失占主导 | 生成图像颜色黯淡 | lambda_con 太大,模型学懒,只追求“一样” | 用梯度缩放,把 lambda_con 上限卡 0.5 |
| 3. 标签泄露 | 无条件生成也出现“文字” | 伪标签分支把文本信息偷偷写进权重 | 无条件分支强制 dropout 文本 embedding,p=0.1 |
| 4. 显存溢出 | 24G 卡都爆 | 伪标签+对比学习同时开,激活翻倍 | 开 gradient checkpoint,UNet 中间特征用 fp16 |
| 5. 模式崩塌 | 生成全是同一款式衣服 | 无标签数据分布太单一 | 每轮随机 shuffle 无标签池,混入 10% 外部开源数据 |
调参老手的私藏锦囊:温度、权重、比例,一文打尽
-
无标签数据比例
实验表明,当无标签 batch 占比 60%~70% 时最佳;超过 80%,监督信号被稀释,FID 反而回升。 -
温度系数 τ(对比学习)
推荐 0.07~0.1,太大特征过于“佛系”,太小训练抖成筛子。 -
损失权重 schedule
采用“余弦升温”:
lambda_con = 0.1 → 1.0(前 20% step)
lambda_pseudo = 0 → 0.5(20%~60% step)
lambda_ctr = 0 → 0.3(30%~70% step)
让模型先学“像素级”,再学“语义级”,最后学“特征级”,稳! -
学习率
监督分支 1e-4,伪标签分支 5e-5(防止噪声主导),一致性分支 1e-4。用 Layer-wise LR decay,UNet 底层 0.1×,顶层 1×,细节更锐。
彩蛋:半监督不止省标签,还能当“数据分布哨兵”
我们意外发现:
当线上用户上传的图片风格发生漂移(比如从“日系小清新”突然变成“东北大花袄”),伪标签的 CLIP 置信度会集体下降。
于是顺手写了个监控脚本:
# monitor.py
def drift_alert(pseudo_confidences, threshold=0.25):
mean_conf = pseudo_confidences.mean()
if mean_conf < threshold:
send_slack("@channel 疑似数据漂移,请立即排查!")
结果:提前 3 天捕获到运营误上传的“万圣节特辑”图包,避免模型被带歪,老板直接发 500 块京东卡……半监督,真香!
结语:把“少标注”变成“多生成”,半监督才是打工人自救指南
写到这里,显卡还在呼呼转,耳边是机房的风扇交响乐。
回头想想,半监督学习之于 Stable Diffusion,就像泡面之于加班——
“不是最好,却最能救急。”
只要你会一点点“一致性”,敢玩一点点“伪标签”,再配一点点“对比学习”,
就能把 1000 张标签图玩出 1 万张的排面,
把老板的“既要又要还要”变成“我可以,我还能”。
下一篇,我们聊聊“如何用强化学习让扩散模型直接优化美学分数”,
如果你也感兴趣,记得把显卡养好,我们回头见。

4万+

被折叠的 条评论
为什么被折叠?



