Abstract
The core function of Stable Diffusion WebUI is generating high-quality images, and a complex processing pipeline sits behind it. This article analyzes the core image generation flow inside the WebUI, covering the entire path from user interface interaction through model inference to the final image output.
Keywords: Stable Diffusion, WebUI, image generation, source code analysis, deep learning
1. Introduction
As one of the most popular AI image generation tools, Stable Diffusion WebUI's core job is to generate images from user-supplied parameters. This process involves many cooperating components, including model loading, parameter handling, sampler scheduling, and VAE encoding/decoding. Understanding the core generation flow matters for performance tuning, debugging, and extension development alike.
This article analyzes the complete generation flow from the source code's perspective, helping developers and advanced users better understand and use the WebUI.
2. Core Processing Class Structure
2.1 The StableDiffusionProcessing Base Class
[StableDiffusionProcessing](file:///E:/project/stable-diffusion-webui/modules/processing.py#L79-L163) is the base class for all image generation processing classes and defines their common attributes and methods:
```python
@dataclass(repr=False)
class StableDiffusionProcessing:
    sd_model: object = None
    outpath_samples: str = None
    outpath_grids: str = None
    prompt: str = ""
    prompt_for_display: str = None
    negative_prompt: str = ""
    styles: list[str] = None
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0
    # ... more attributes
```
The class uses Python's dataclass decorator, which automatically provides a constructor and other common methods. It holds all the basic parameters needed for image generation, such as prompts, seeds, dimensions, and sampler settings.
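Because it is a dataclass, a processing object can be constructed with keyword arguments and default values. A minimal, self-contained illustration of the pattern, using a simplified stand-in with only the fields shown above:

```python
from dataclasses import dataclass

@dataclass(repr=False)
class DemoProcessing:
    # Simplified stand-in for StableDiffusionProcessing, for illustration only
    sd_model: object = None
    prompt: str = ""
    negative_prompt: str = ""
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0

p = DemoProcessing(prompt="a cat sitting on a bench", seed=42)
print(p.prompt, p.seed)       # a cat sitting on a bench 42
print(repr(p.negative_prompt))  # '' (the default)
```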
2.2 The Text-to-Image Processing Class
[StableDiffusionProcessingTxt2Img](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1241-L1313) inherits from [StableDiffusionProcessing](file:///E:/project/stable-diffusion-webui/modules/processing.py#L79-L163) and handles text-to-image generation:
```python
@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    enable_hr: bool = False
    denoising_strength: float = 0.75
    firstphase_width: int = 0
    firstphase_height: int = 0
    hr_scale: float = 2.0
    # ... hires-fix related parameters
```
2.3 The Image-to-Image Processing Class
[StableDiffusionProcessingImg2Img](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1623-L1648) also inherits from [StableDiffusionProcessing](file:///E:/project/stable-diffusion-webui/modules/processing.py#L79-L163) and handles image-to-image transformations:
```python
@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    init_images: list = None
    resize_mode: int = 0
    denoising_strength: float = 0.75
    image_cfg_scale: float = None
    mask: Any = None
    # ... image-processing related parameters
```
3. The Main Image Generation Flow
3.1 The process_images Function
The [process_images](file:///E:/project/stable-diffusion-webui/modules/processing.py#L874-L923) function is the entry point for image generation and coordinates the whole process:
```python
def process_images(p: StableDiffusionProcessing) -> Processed:
    if p.scripts is not None:
        p.scripts.before_process(p)

    # Stash the settings that are about to be overridden
    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}

    try:
        # Handle the model checkpoint override
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        # Apply the overridden settings
        for k, v in p.override_settings.items():
            opts.set(k, v, is_api=True, run_callbacks=False)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        # Apply the Token Merging optimization
        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        # Fix invalid sampler and scheduler choices
        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)

        # Run the actual image generation
        with profiling.Profiler():
            res = process_images_inner(p)

    finally:
        # Restore settings
        sd_models.apply_token_merging(p.sd_model, 0)

        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res
```
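At the heart of this function is a stash-and-restore pattern: the original values of every overridden option are saved before the run and written back in the finally block, even when generation fails. The same pattern in a self-contained form (plain dicts stand in for the real opts object, which is an assumption of this sketch):

```python
opts = {"sd_vae": "vae-a.pt", "CLIP_stop_at_last_layers": 1}   # stand-in for the real opts
override_settings = {"CLIP_stop_at_last_layers": 2}

stored_opts = {k: opts[k] for k in override_settings if k in opts}
try:
    opts.update(override_settings)
    print("during:", opts["CLIP_stop_at_last_layers"])   # 2
finally:
    opts.update(stored_opts)                             # restore, even on error
print("after:", opts["CLIP_stop_at_last_layers"])        # 1
```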
3.2 The process_images_inner Function
The [process_images_inner](file:///E:/project/stable-diffusion-webui/modules/processing.py#L925-L1083) function contains the core image generation logic:
```python
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    # Initialize device and seeds
    devices.torch_gc()
    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    # Fill in face restoration and tiling options
    if p.restore_faces is None:
        p.restore_faces = opts.face_restoration

    if p.tiling is None:
        p.tiling = opts.tiling

    # Let the model correct the requested dimensions
    if hasattr(shared.sd_model, 'fix_dimensions'):
        p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)

    # Record model and VAE info
    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    # Apply circular (tiling) convolution padding
    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # Fill fields from options and set up prompts
    p.fill_fields_from_opts()
    p.setup_prompts()

    # Expand the seeds
    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    # Load textual inversion embeddings
    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    # Run script hooks
    if p.scripts is not None:
        p.scripts.process(p)

    # Output containers
    infotexts = []
    output_images = []

    # Main generation loop
    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # Load the live-preview approximation model
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        # Set the job count
        if state.job_count == -1:
            state.job_count = p.n_iter

        # Iterate over batches
        for n in range(p.n_iter):
            p.iteration = n

            # Handle skip and interrupt requests
            if state.skipped:
                state.skipped = False

            if state.interrupted or state.stopping_generation:
                break

            # Reload model weights (they may have been changed, e.g. by a refiner)
            sd_models.reload_model_weights()

            # Slice out this batch's parameters
            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            # Set up the random number generator
            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds,
                                 subseeds=p.subseeds, subseed_strength=p.subseed_strength,
                                 seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            # Run before-batch script hooks
            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # Parse extra-network (LoRA, hypernetwork, ...) prompts
            p.parse_extra_network_prompts()

            # Activate extra networks
            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            # Run per-batch script hooks
            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # Set up conditioning
            p.setup_conds()

            # Update extra generation parameters
            p.extra_generation_params.update(model_hijack.extra_generation_params)

            # Save the parameters file
            if n == 0 and not cmd_opts.no_prompt_history:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [])
                    file.write(processed.infotext(p, 0))

            # Collect comments
            for comment in model_hijack.comments:
                p.comment(comment)

            # Update the job status
            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            # Apply the alpha-schedule override
            sd_models.apply_alpha_schedule_override(p.sd_model, p)

            # Run sampling
            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds,
                                        subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            # Run post-sample script hooks
            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = ps.samples

            # Decode the samples
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                devices.test_for_nans(samples_ddim, "unet")
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            # Convert the decoded samples to the [0, 1] range
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            # Free memory
            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()
            state.nextjob()
            # Run post-batch script hooks
            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            # Helper for building the generation-parameters text
            def infotext(index=0, use_main_prompt=False):
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

            save_samples = p.save_samples()

            # Process each image in the batch
            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                # Face restoration
                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i],
                                          opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()
                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                # Per-image postprocessing script hooks
                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                mask_for_overlay = getattr(p, "mask_for_overlay", None)

                # Pick the overlay image for inpainting
                if not shared.opts.overlay_inpaint:
                    overlay_image = None
                elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
                    overlay_image = p.overlay_images[i]
                else:
                    overlay_image = None

                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                # Color correction
                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i],
                                          opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                # Apply the overlay
                image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)

                # Post-composite script hooks
                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image_after_composite(p, pp)
                    image = pp.image

                # Save the image
                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                # Attach PNG info
                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                # Mask-related outputs
                if mask_for_overlay is not None:
                    if opts.return_mask or opts.save_mask:
                        image_mask = mask_for_overlay.convert('RGB')
                        if save_samples and opts.save_mask:
                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i],
                                              opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                        if opts.return_mask:
                            output_images.append(image_mask)

                    if opts.return_mask_composite or opts.save_mask_composite:
                        image_mask_composite = Image.composite(
                            original_denoised_image.convert('RGBA').convert('RGBa'),
                            Image.new('RGBa', image.size),
                            images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')
                        ).convert('RGBA')
                        if save_samples and opts.save_mask_composite:
                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i],
                                              opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                        if opts.return_mask_composite:
                            output_images.append(image_mask_composite)

            # Free memory
            del x_samples_ddim
            devices.torch_gc()
        # Grid image handling
        if not infotexts:
            infotexts.append(Processed(p, []).infotext(p, 0))

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format,
                                  info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    # Deactivate extra networks
    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    # Build the result object
    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotexts[0],
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    # Run the final postprocess script hook
    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
```
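Both seeds pass through get_fixed_seed before being expanded, which turns the special value -1 into a concrete random seed so that a run can be reproduced from its recorded parameters. A simplified sketch of that behavior (the real helper also accepts string input):

```python
import random

def get_fixed_seed(seed):
    # -1 or an empty value means "pick a seed at random"; resolving it here
    # means the recorded infotext contains a concrete, reproducible seed.
    if seed is None or seed == '' or seed == -1:
        return int(random.randrange(4294967294))
    return seed

print(get_fixed_seed(-1))   # a random seed, e.g. 2833915224
print(get_fixed_seed(42))   # 42
```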
4. The Text-to-Image Flow in Detail
4.1 Initialization
Text-to-image initialization is implemented in the [StableDiffusionProcessingTxt2Img.init](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1315-L1420) method:
```python
def init(self, all_prompts, all_seeds, all_subseeds):
    if self.enable_hr:
        # Record hires-fix related parameters
        self.extra_generation_params["Denoising strength"] = self.denoising_strength

        # Handle the hires-fix checkpoint
        if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
            self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

            if self.hr_checkpoint_info is None:
                raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

            self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

        # Record the hires-fix sampler and scheduler
        if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
            self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

        # Record the hires-fix prompts
        def get_hr_prompt(p, index, prompt_text, **kwargs):
            hr_prompt = p.all_hr_prompts[index]
            return hr_prompt if hr_prompt != prompt_text else None

        def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):
            hr_negative_prompt = p.all_hr_negative_prompts[index]
            return hr_negative_prompt if hr_negative_prompt != negative_prompt else None

        self.extra_generation_params["Hires prompt"] = get_hr_prompt
        self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt

        # Other hires-fix parameters
        self.extra_generation_params["Hires schedule type"] = None

        if self.hr_scheduler is None:
            self.hr_scheduler = self.scheduler

        # Choose the latent upscale mode
        self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")

        # Compute the target resolution
        self.calculate_target_resolution()

        # Refine the progress bar's total step count
        if not state.processing_has_refined_job_count:
            if state.job_count == -1:
                state.job_count = self.n_iter

            if getattr(self, 'txt2img_upscale', False):
                total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
            else:
                total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count

            shared.total_tqdm.updateTotal(total_steps)
            state.job_count = state.job_count * 2
            state.processing_has_refined_job_count = True

        # Record hires-fix steps and upscaler
        if self.hr_second_pass_steps:
            self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

        if self.hr_upscaler is not None:
            self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
```
4.2 Sampling
Text-to-image sampling is implemented in the [StableDiffusionProcessingTxt2Img.sample](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1422-L1503) method:
```python
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    # Create the sampler
    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

    if self.firstpass_image is not None and self.enable_hr:
        # A first-pass image was already supplied
        if self.latent_scale_mode is None:
            # Use the image directly
            image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
            image = np.moveaxis(image, 2, 0)

            samples = None
            decoded_samples = torch.asarray(np.expand_dims(image, 0))
        else:
            # Encode the image into latent space
            image = np.array(self.firstpass_image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            image = torch.from_numpy(np.expand_dims(image, axis=0))
            image = image.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

            samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
            decoded_samples = None
            devices.torch_gc()
    else:
        # Normal generation
        x = self.rng.next()

        if self.scripts is not None:
            self.scripts.process_before_every_sampling(
                p=self,
                x=x,
                noise=x,
                c=conditioning,
                uc=unconditional_conditioning
            )

        # Run sampling
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x

    # Without hires fix, return the samples directly
    if not self.enable_hr:
        return samples

    devices.torch_gc()

    # Decode the samples for image-space upscaling
    if self.latent_scale_mode is None:
        decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
    else:
        decoded_samples = None

    # Reload model weights if a separate hires-fix checkpoint was selected
    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=self.hr_checkpoint_info)

    # Run the hires-fix pass
    return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
```
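Note the two value ranges used above: the direct-image branch maps pixels into [-1, 1] (image / 255 * 2 - 1), while the VAE encode path keeps [0, 1]. A self-contained check of the round trip between the two conventions:

```python
import numpy as np

pixels = np.array([0, 128, 255], dtype=np.float32)

as_model_range = pixels / 255.0 * 2.0 - 1.0    # [0, 255] -> [-1, 1]
back_to_unit = (as_model_range + 1.0) / 2.0    # [-1, 1] -> [0, 1], as used when decoding

print(as_model_range)   # [-1.          0.00392157  1.        ]
print(back_to_unit)     # [0.         0.5019608   1.        ]
```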
4.3 The Hires-Fix Pass
The hires-fix pass is implemented in the [sample_hr_pass](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1505-L1617) method:
```python
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
    if shared.state.interrupted:
        return samples

    self.is_hr_pass = True
    target_width = self.hr_upscale_to_x
    target_height = self.hr_upscale_to_y

    def save_intermediate(image, index):
        """Save the image as it exists before the hires fix"""
        if not self.save_samples() or not opts.save_images_before_highres_fix:
            return

        if not isinstance(image, Image.Image):
            image = sd_samplers.sample_to_image(image, index, approximation=0)

        info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
        images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

    # Create the hires-fix sampler
    img2img_sampler_name = self.hr_sampler_name or self.sampler_name
    self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

    if self.latent_scale_mode is not None:
        # Latent-space upscaling
        for i in range(samples.shape[0]):
            save_intermediate(samples, i)

        samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f),
                                                  mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

        # Set up image conditioning
        if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
            image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
        else:
            image_conditioning = self.txt2img_image_conditioning(samples)
    else:
        # Image-space upscaling of the decoded samples
        lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

        batch_images = []
        for i, x_sample in enumerate(lowres_samples):
            x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
            x_sample = x_sample.astype(np.uint8)
            image = Image.fromarray(x_sample)

            save_intermediate(image, i)

            # Upscale the image
            image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            batch_images.append(image)

        decoded_samples = torch.from_numpy(np.array(batch_images))
        decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

        samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
        image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

    shared.state.nextjob()

    # Crop the samples to the exact target size
    samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2,
                      self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

    # Create the noise
    self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds,
                            subseed_strength=self.subseed_strength,
                            seed_resize_from_h=self.seed_resize_from_h,
                            seed_resize_from_w=self.seed_resize_from_w)
    noise = self.rng.next()

    # Activate hires-fix extra networks
    if not self.disable_extra_networks:
        with devices.autocast():
            extra_networks.activate(self, self.hr_extra_network_data)

    # Compute hires-fix conditioning
    with devices.autocast():
        self.calculate_hr_conds()

    # Apply Token Merging with the hires ratio
    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

    # Run before-hires script hooks
    if self.scripts is not None:
        self.scripts.before_hr(self)

        self.scripts.process_before_every_sampling(
            p=self,
            x=samples,
            noise=noise,
            c=self.hr_c,
            uc=self.hr_uc,
        )

    # Run the hires-fix sampling
    samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc,
                                          steps=self.hr_second_pass_steps or self.steps,
                                          image_conditioning=image_conditioning)

    # Restore the Token Merging setting
    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

    self.sampler = None
    devices.torch_gc()

    # Decode the final samples
    decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

    self.is_hr_pass = False
    return decoded_samples
```
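Latent-space upscaling is a plain torch.nn.functional.interpolate call on the latent tensor. A standalone demonstration of the shape change, assuming Stable Diffusion's usual latent layout (4 channels, downscale factor opt_f = 8):

```python
import torch
import torch.nn.functional as F

opt_f = 8                                                  # latent downscale factor
latents = torch.randn(2, 4, 512 // opt_f, 512 // opt_f)    # 64x64 latents for 512x512 images

upscaled = F.interpolate(latents, size=(1024 // opt_f, 1024 // opt_f),
                         mode="bilinear", antialias=True)
print(tuple(latents.shape), "->", tuple(upscaled.shape))   # (2, 4, 64, 64) -> (2, 4, 128, 128)
```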
5. The Image-to-Image Flow in Detail
5.1 Initialization
Image-to-image initialization is implemented in the [StableDiffusionProcessingImg2Img.init](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1650-L1769) method:
```python
def init(self, all_prompts, all_seeds, all_subseeds):
    # Record the denoising strength
    self.extra_generation_params["Denoising strength"] = self.denoising_strength

    # image_cfg_scale only applies to instruct-pix2pix ("edit") models
    self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

    # Create the sampler
    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
    crop_region = None

    image_mask = self.image_mask

    # Handle the mask
    if image_mask is not None:
        # Convert an RGBA mask into a binary mask
        image_mask = create_binary_mask(image_mask, round=self.mask_round)

        # Invert the mask
        if self.inpainting_mask_invert:
            image_mask = ImageOps.invert(image_mask)
            self.extra_generation_params["Mask mode"] = "Inpaint not masked"

        # Apply mask blur
        if self.mask_blur_x > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
            image_mask = Image.fromarray(np_mask)

        if self.mask_blur_y > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
            image_mask = Image.fromarray(np_mask)

        # Whole-picture inpainting vs. "only masked" inpainting
        if self.inpaint_full_res:
            self.mask_for_overlay = image_mask
            mask = image_mask.convert('L')
            crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding)
            if crop_region:
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region
                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
                self.extra_generation_params["Inpaint area"] = "Only masked"
                self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
            else:
                # The mask is blank, so fall back to plain img2img
                crop_region = None
                image_mask = None
                self.mask_for_overlay = None
                self.inpaint_full_res = False
                message = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
                model_hijack.comments.append(message)
                logging.info(message)
        else:
            image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
            np_mask = np.array(image_mask)
            np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
            self.mask_for_overlay = Image.fromarray(np_mask)

        self.overlay_images = []
    # The latent mask defaults to the image mask
    latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

    # Set up color correction
    add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
    if add_color_corrections:
        self.color_corrections = []

    imgs = []
    for img in self.init_images:
        # Save the initial image
        if opts.save_init_img:
            self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
            images.save_image(img, path=opts.outdir_init_images, basename=None,
                              forced_filename=self.init_img_hash, save_to_dirs=False,
                              existing_info=img.info)

        # Flatten the image onto the background color
        image = images.flatten(img, opts.img2img_background_color)

        # Resize the image
        if crop_region is None and self.resize_mode != 3:
            image = images.resize_image(self.resize_mode, image, self.width, self.height)

        # Apply the mask
        if image_mask is not None:
            if self.mask_for_overlay.size != (image.width, image.height):
                self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height)

            image_masked = Image.new('RGBa', (image.width, image.height))
            image_masked.paste(image.convert("RGBA").convert("RGBa"),
                               mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
            self.overlay_images.append(image_masked.convert('RGBA'))

        # Crop-region handling
        if crop_region is not None:
            image = image.crop(crop_region)
            image = images.resize_image(2, image, self.width, self.height)

        # Fill the masked area
        if image_mask is not None:
            if self.inpainting_fill != 1:
                image = masking.fill(image, latent_mask)

                if self.inpainting_fill == 0:
                    self.extra_generation_params["Masked content"] = 'fill'

        # Color correction
        if add_color_corrections:
            self.color_corrections.append(setup_color_correction(image))

        # Convert to tensor layout
        image = np.array(image).astype(np.float32) / 255.0
        image = np.moveaxis(image, 2, 0)

        imgs.append(image)

    # Build the image batch
    if len(imgs) == 1:
        batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
        if self.overlay_images is not None:
            self.overlay_images = self.overlay_images * self.batch_size

        if self.color_corrections is not None and len(self.color_corrections) == 1:
            self.color_corrections = self.color_corrections * self.batch_size
    elif len(imgs) <= self.batch_size:
        self.batch_size = len(imgs)
        batch_images = np.array(imgs)
    else:
        raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

    image = torch.from_numpy(batch_images)
    image = image.to(shared.device, dtype=devices.dtype_vae)

    # Encode into latent space
    if opts.sd_vae_encode_method != 'Full':
        self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

    self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
    devices.torch_gc()

    # Resize mode 3 upscales in latent space
    if self.resize_mode == 3:
        self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

    # Build the latent-space mask
    if image_mask is not None:
        init_mask = latent_mask
        latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
        latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
        latmask = latmask[0]
        if self.mask_round:
            latmask = np.around(latmask)
        latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))

        self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
        self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)

        # Fill the masked content
        if self.inpainting_fill == 2:
            self.init_latent = self.init_latent * self.mask + create_random_tensors(
                self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]
            ) * self.nmask
            self.extra_generation_params["Masked content"] = 'latent noise'
        elif self.inpainting_fill == 3:
            self.init_latent = self.init_latent * self.mask
            self.extra_generation_params["Masked content"] = 'latent nothing'

    # Set up image conditioning
    self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
```
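The mask blur above is separable: one horizontal Gaussian pass and one vertical pass, each with a kernel size derived from its blur radius. A standalone demonstration of the horizontal pass (cv2 is the opencv-python package):

```python
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 255                        # hard-edged square mask

mask_blur_x = 4
kernel_size = 2 * int(2.5 * mask_blur_x + 0.5) + 1               # same formula as above -> 21
blurred = cv2.GaussianBlur(mask, (kernel_size, 1), mask_blur_x)  # horizontal-only blur

print(kernel_size)                    # 21
print(mask[32, 14], blurred[32, 14])  # 0, then a soft intermediate value near the edge
```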
5.2 Sampling
Image-to-image sampling is implemented in the [StableDiffusionProcessingImg2Img.sample](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1771-L1791) method:
```python
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    # Generate the noise
    x = self.rng.next()

    # Apply the initial noise multiplier
    if self.initial_noise_multiplier != 1.0:
        self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
        x *= self.initial_noise_multiplier

    # Run before-sampling script hooks
    if self.scripts is not None:
        self.scripts.process_before_every_sampling(
            p=self,
            x=self.init_latent,
            noise=x,
            c=conditioning,
            uc=unconditional_conditioning
        )

    # Run image-to-image sampling
    samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning,
                                          image_conditioning=self.image_conditioning)

    # Blend with the mask
    if self.mask is not None:
        blended_samples = samples * self.nmask + self.init_latent * self.mask

        if self.scripts is not None:
            mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
            self.scripts.on_mask_blend(self, mba)
            blended_samples = mba.blended_latent

        samples = blended_samples

    # Free memory
    del x
    devices.torch_gc()

    return samples
```
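The final blend keeps the original latent where mask is 1 and the newly sampled latent where nmask (its complement) is 1. A tiny tensor demonstration of the formula:

```python
import torch

init_latent = torch.zeros(1, 4, 2, 2)        # stands in for the original content
samples = torch.ones(1, 4, 2, 2)             # stands in for the newly sampled content

nmask = torch.tensor([[0., 1.], [1., 0.]])   # 1 = region to inpaint
mask = 1.0 - nmask                           # 1 = region to preserve

blended = samples * nmask + init_latent * mask
print(blended[0, 0])                         # tensor([[0., 1.], [1., 0.]])
```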
6. Key Technical Points
6.1 Conditioning
Stable Diffusion uses text conditioning to steer the image generation process. The WebUI prepares the conditioning as follows:
```python
def setup_conds(self):
    # Wrap the positive and negative prompts
    prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
    negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width,
                                                    height=self.height, is_negative_prompt=True)

    # Look up the sampler configuration
    sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
    total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
    self.step_multiplier = total_steps // self.steps
    self.firstpass_steps = total_steps

    # Compute the unconditional and conditional embeddings
    self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning,
                                          negative_prompts, total_steps, [self.cached_uc],
                                          self.extra_network_data)
    self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning,
                                         prompts, total_steps, [self.cached_c],
                                         self.extra_network_data)
```
6.2 The Caching Mechanism
To improve performance, the WebUI caches conditioning results:
```python
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
    """
    Store results in a cache and return the cached result when the same
    parameters have already been used.
    """
    if shared.opts.use_old_scheduling:
        old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
        new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
        if old_schedules != new_schedules:
            self.extra_generation_params["Old prompt editing timelines"] = True

    # Build the cache key
    cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)

    # Check the caches
    for cache in caches:
        if cache[0] is not None and cached_params == cache[0]:
            return cache[1]

    cache = caches[0]

    # Compute the conditioning and cache the result
    with devices.autocast():
        cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

    cache[0] = cached_params
    return cache[1]
```
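Each cache is a two-element list of [key, value]: the key is a tuple of everything that can influence the result, and a hit requires exact equality. The same idea in isolation:

```python
def expensive_conditioning(prompt, steps):
    print(f"computing conditioning for {prompt!r} at {steps} steps")
    return f"cond({prompt}, {steps})"

cache = [None, None]                          # [cached_params, cached_result]

def get_with_caching(prompt, steps):
    key = (prompt, steps)
    if cache[0] == key:
        return cache[1]                       # hit: skip the expensive call
    cache[1] = expensive_conditioning(prompt, steps)
    cache[0] = key
    return cache[1]

get_with_caching("a cat", 20)                 # computes
get_with_caching("a cat", 20)                 # served from the cache
```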
6.3 Memory Management
The WebUI implements several memory-management strategies to fit different hardware configurations:
```python
def close(self):
    """Release the resources held during processing"""
    self.sampler = None
    self.c = None
    self.uc = None

    if not opts.persistent_cond_cache:
        StableDiffusionProcessing.cached_c = [None, None]
        StableDiffusionProcessing.cached_uc = [None, None]

    # In low-VRAM mode, move everything to the CPU
    if lowvram.is_enabled(shared.sd_model):
        lowvram.send_everything_to_cpu()

    # Free GPU memory
    devices.torch_gc()
```
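devices.torch_gc appears throughout the pipeline. Conceptually it pairs Python garbage collection with releasing the CUDA allocator's cached memory; a simplified sketch of such a helper (the real implementation also covers other backends such as MPS):

```python
import gc
import torch

def torch_gc():
    # Collect unreachable Python objects first so their tensors can be freed,
    # then hand cached GPU memory back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

torch_gc()
```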
7. Performance Optimization Strategies
7.1 Token Merging
Token Merging is an effective performance optimization: it merges redundant tokens inside the model's attention layers, reducing computation at a small cost in detail. The WebUI applies it with a configurable ratio:
```python
# Apply Token Merging
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

# Restore afterwards (a ratio of 0 disables merging)
sd_models.apply_token_merging(p.sd_model, 0)
```
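Under the hood, apply_token_merging builds on the tomesd package, which patches the model's attention blocks so that similar tokens are merged before attention is computed. A hedged usage sketch of that library (this assumes tomesd is installed and that model is a loaded Stable Diffusion model):

```python
import tomesd

# ratio is the fraction of tokens eligible for merging; higher is faster
# but can cost detail. `model` is assumed to be a loaded SD model.
tomesd.apply_patch(model, ratio=0.5)
# ... run inference ...
tomesd.remove_patch(model)   # restore the original attention modules
```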
7.2 Batch Processing
The WebUI improves throughput through batching:
```python
# Set the job count
if state.job_count == -1:
    state.job_count = p.n_iter

# Iterate over batches
for n in range(p.n_iter):
    # Slice out this batch's parameters
    p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
    p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
    p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
    p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
```
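The flat all_* lists hold n_iter * batch_size entries, and each iteration consumes one contiguous slice. A quick standalone check of the slicing arithmetic:

```python
n_iter, batch_size = 2, 3
all_seeds = [100, 101, 102, 103, 104, 105]   # n_iter * batch_size entries

for n in range(n_iter):
    seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
    print(n, seeds)
# 0 [100, 101, 102]
# 1 [103, 104, 105]
```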
8. Error Handling and Recovery
The WebUI has a thorough error-handling mechanism:
```python
try:
    # Main processing logic
    res = process_images_inner(p)
except Exception as e:
    # Report the error
    errors.report(f"Error processing images: {e}", exc_info=True)
finally:
    # Clean up resources and restore state
    sd_models.apply_token_merging(p.sd_model, 0)

    if p.override_settings_restore_afterwards:
        for k, v in stored_opts.items():
            setattr(opts, k, v)
```
Summary
This deep dive into the WebUI's core image generation flow reveals a number of careful design decisions:
- Modular design: base classes and subclasses organize the code, improving maintainability and extensibility
- Performance optimization: multiple strategies such as caching, Token Merging, and batching
- Flexible extension: the script system provides rich extension points
- Error handling: thorough error handling and recovery mechanisms
- Resource management: effective management of memory and compute resources
Understanding these core flows not only helps in using the WebUI well, it also lays the groundwork for extension development and performance tuning. As the technology advances, the generation pipeline keeps evolving, and more optimizations and features are likely to arrive.
References
- Stable Diffusion WebUI GitHub repository: https://github.com/AUTOMATIC1111/stable-diffusion-webui
- Karras et al., Elucidating the Design Space of Diffusion-Based Generative Models: https://arxiv.org/abs/2206.00364
- High-Resolution Image Synthesis with Latent Diffusion Models (the Stable Diffusion paper): https://arxiv.org/abs/2112.10752
- WebUI wiki documentation: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki