Source Code Walkthrough of the Core Image Generation Pipeline

Abstract

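The core function of Stable Diffusion WebUI is generating high-quality images, and behind it lies a fairly involved processing pipeline. This article analyzes WebUI's core image generation flow in depth, following the whole path from user interface interaction through model inference to the final image output.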

Keywords: Stable Diffusion, WebUI, image generation, source code analysis, deep learning

1. Introduction

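Stable Diffusion WebUI is one of the most popular AI image generation tools, and its core function is generating images from user-supplied parameters. This involves several cooperating components, including model loading, parameter handling, sampler scheduling, and VAE encoding/decoding. Understanding the core generation flow is valuable for performance tuning, debugging, and extension development.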

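This article walks through the complete generation flow at the source code level, to help developers and advanced users understand and use WebUI more effectively.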

2. Core Processing Class Structure

2.1 The StableDiffusionProcessing Base Class

[StableDiffusionProcessing](file:///E:/project/stable-diffusion-webui/modules/processing.py#L79-L163) is the base class for all image generation processing classes and defines their shared attributes and methods:

@dataclass(repr=False)
class StableDiffusionProcessing:
    sd_model: object = None
    outpath_samples: str = None
    outpath_grids: str = None
    prompt: str = ""
    prompt_for_display: str = None
    negative_prompt: str = ""
    styles: list[str] = None
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0
    # ... more attributes

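The class uses Python's dataclass decorator, which generates __init__ and __eq__ automatically. Passing repr=False tells it not to generate __repr__. The class holds the basic parameters every generation needs, such as prompts, seeds, image size, and sampler settings.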
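
The sketch below shows what the dataclass decorator provides here. MiniProcessing is a hypothetical, trimmed-down stand-in, not the WebUI class; it only reuses a few field names from the excerpt. The __post_init__ that normalizes styles is an assumption added for illustration. It shows one common way to handle a None default for a list field so that instances never share a single mutable list.

```python
from dataclasses import dataclass, fields


@dataclass(repr=False)
class MiniProcessing:
    """Hypothetical stand-in for StableDiffusionProcessing (illustration only)."""
    prompt: str = ""
    negative_prompt: str = ""
    styles: list[str] = None   # None default, as in the excerpt above
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0

    def __post_init__(self):
        # Assumption for this sketch: replace the None default with a fresh
        # list per instance, so instances never share one mutable list.
        if self.styles is None:
            self.styles = []


# The generated __init__ accepts any subset of fields as keyword arguments.
p = MiniProcessing(prompt="a cat", seed=42)
print(p.prompt, p.seed, p.styles)
print([f.name for f in fields(p)])
# repr=False: no generated __repr__, so the default object repr is used.
print(repr(p))
```

The generated __eq__ compares field values, so two instances built with the same arguments compare equal.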

2.2 The Text-to-Image Processing Class

[StableDiffusionProcessingTxt2Img](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1241-L1313) inherits from [StableDiffusionProcessing](file:///E:/project/stable-diffusion-webui/modules/processing.py#L79-L163) and handles text-to-image generation:

@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    enable_hr: bool = False
    denoising_strength: float = 0.75
    firstphase_width: int = 0
    firstphase_height: int = 0
    hr_scale: float = 2.0
    # ... hires fix parameters

2.3 The Image-to-Image Processing Class

[StableDiffusionProcessingImg2Img](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1623-L1648) also inherits from [StableDiffusionProcessing](file:///E:/project/stable-diffusion-webui/modules/processing.py#L79-L163) and handles image-to-image transformation:

@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    init_images: list = None
    resize_mode: int = 0
    denoising_strength: float = 0.75
    image_cfg_scale: float = None
    mask: Any = None
    # ... image-processing parameters
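
In all three excerpts, every field has a default value. The miniature hierarchy below shows why this is required. Base, Txt2Img and Img2Img are hypothetical names, not the WebUI classes. With dataclass inheritance, subclass fields are appended after the base fields, and a field without a default may not follow one that has a default.

```python
from dataclasses import dataclass, fields


# Hypothetical miniature hierarchy mirroring the three excerpts above.
@dataclass(repr=False)
class Base:
    prompt: str = ""
    seed: int = -1


@dataclass(repr=False)
class Txt2Img(Base):
    enable_hr: bool = False
    denoising_strength: float = 0.75
    hr_scale: float = 2.0


@dataclass(repr=False)
class Img2Img(Base):
    init_images: list = None
    resize_mode: int = 0
    denoising_strength: float = 0.75


# Inherited fields come first in the generated __init__, then the subclass's own.
print([f.name for f in fields(Txt2Img)])
# -> ['prompt', 'seed', 'enable_hr', 'denoising_strength', 'hr_scale']

# Every base field has a default, so a subclass field without one is
# rejected with TypeError when the class is created.
try:
    @dataclass
    class Broken(Base):
        required_field: int
    broken_rejected = False
except TypeError:
    broken_rejected = True
print(broken_rejected)  # -> True
```

Python 3.10+ also offers kw_only=True as a way around this ordering rule. The excerpts simply give every field a default.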

3. Main Image Generation Flow

3.1 The process_images Function

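The [process_images](file:///E:/project/stable-diffusion-webui/modules/processing.py#L874-L923) function is the entry point for image generation and coordinates the whole run. It applies per-request setting overrides (including checkpoint and VAE switches), calls process_images_inner, and restores the original settings in a finally block: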

def process_images(p: StableDiffusionProcessing) -> Processed:
    if p.scripts is not None:
        p.scripts.before_process(p)

    # Snapshot the current values of settings that will be overridden
    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}

    try:
        # Checkpoint override: if none is given or it cannot be found, drop it and reload the checkpoint from settings
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        # Apply the override settings
        for k, v in p.override_settings.items():
            opts.set(k, v, is_api=True, run_callbacks=False)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        # Apply the Token Merging optimization
        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        # Fix an invalid sampler and scheduler
        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)

        # Run the actual image generation
        with profiling.Profiler():
            res = process_images_inner(p)

    finally:
        # Reset Token Merging and, if requested, restore the original settings
        sd_models.apply_token_merging(p.sd_model, 0)

        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res
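
The save/apply/restore pattern around process_images_inner can be sketched as a context manager over a plain dict. This is an illustration under two assumptions: opts is modelled as a dict rather than WebUI's options object, and override_settings is a hypothetical helper. Like the stored_opts comprehension, it snapshots only keys that already exist. As a result, a key that was absent before the override is not removed afterwards.

```python
from contextlib import contextmanager


@contextmanager
def override_settings(opts: dict, overrides: dict, restore: bool = True):
    """Hypothetical helper: temporarily apply overrides to a plain settings dict."""
    # Snapshot only keys that already exist (mirrors the "if k in opts.data" filter).
    stored = {k: opts[k] for k in overrides if k in opts}
    try:
        for k, v in overrides.items():
            opts[k] = v
        yield opts
    finally:
        # Runs even if generation raises, like the finally block above.
        if restore:
            for k, v in stored.items():
                opts[k] = v


opts = {"clip_skip": 1, "vae": "Automatic"}
try:
    with override_settings(opts, {"clip_skip": 2, "new_key": True}):
        seen_inside = dict(opts)
        raise RuntimeError("simulated failure during generation")
except RuntimeError:
    pass
print(seen_inside)  # overrides visible inside the block
print(opts)         # clip_skip restored; new_key left in place
```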

3.2 The process_images_inner Function

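[process_images_inner](file:///E:/project/stable-diffusion-webui/modules/processing.py#L925-L1083) contains the core generation logic. As its docstring notes, it is the main loop shared by txt2img and img2img: it calls p.init once and p.sample once per batch: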

def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    # Free cached GPU memory and resolve the seeds (-1 becomes a random seed)
    devices.torch_gc()
    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    # Face restoration and tiling options (fall back to the global settings)
    if p.restore_faces is None:
        p.restore_faces = opts.face_restoration

    if p.tiling is None:
        p.tiling = opts.tiling

    # Let models that require it adjust the requested width and height
    if hasattr(shared.sd_model, 'fix_dimensions'):
        p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)

    # Record model and VAE information
    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    # Enable circular (wrap-around) padding when tiling is on
    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # Fill missing fields from settings and set up the prompts
    p.fill_fields_from_opts()
    p.setup_prompts()

    # Build the per-image seed list
    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    # Reload textual inversion embeddings
    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    # Scripts: process hook
    if p.scripts is not None:
        p.scripts.process(p)

    # Initialize output containers
    infotexts = []
    output_images = []
    
    # Main generation loop
    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # Load the Approx NN live-preview model before sampling starts
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            # Apply the selected UNet implementation
            sd_unet.apply_unet()

        # Set the job count (one job per iteration)
        if state.job_count == -1:
            state.job_count = p.n_iter

        # Iterate over batches (n_iter iterations of batch_size images each)
        for n in range(p.n_iter):
            p.iteration = n

            # Handle skip and interrupt requests
            if state.skipped:
                state.skipped = False

            if state.interrupted or state.stopping_generation:
                break

            # Reload model weights
            sd_models.reload_model_weights()

            # Slice out the prompts and seeds for the current batch
            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            # Set up the noise generator for latents of shape (C, H // opt_f, W // opt_f)
            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), 
                                 p.seeds, subseeds=p.subseeds, 
                                 subseed_strength=p.subseed_strength, 
                                 seed_resize_from_h=p.seed_resize_from_h, 
                                 seed_resize_from_w=p.seed_resize_from_w)

            # Scripts: before_process_batch hook
            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # Parse extra-network syntax in the prompts (e.g. <lora:...>)
            p.parse_extra_network_prompts()

            # Activate extra networks
            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            # Scripts: process_batch hook
            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # Compute conditioning (p.c) and unconditional conditioning (p.uc)
            p.setup_conds()

            # Merge extra generation parameters from the model hijack
            p.extra_generation_params.update(model_hijack.extra_generation_params)

            # Write params.txt (first batch only)
            if n == 0 and not cmd_opts.no_prompt_history:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [])
                    file.write(processed.infotext(p, 0))

            # Attach model hijack comments
            for comment in model_hijack.comments:
                p.comment(comment)

            # Update the job status text
            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            # Apply the alpha schedule override
            sd_models.apply_alpha_schedule_override(p.sd_model, p)

            # Sample latents (without autocast if the UNet needs upcasting)
            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, 
                                       seeds=p.seeds, subseeds=p.subseeds, 
                                       subseed_strength=p.subseed_strength, prompts=p.prompts)

            # Scripts: post_sample hook (may replace the samples)
            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = ps.samples

            # Decode latents with the VAE unless they are already decoded
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                devices.test_for_nans(samples_ddim, "unet")
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, 
                                                    target_device=devices.cpu, 
                                                    check_for_nans=True)

            # Stack into one tensor and map values from [-1, 1] to [0, 1]
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            # Free memory
            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            state.nextjob()

            # Scripts: postprocess_batch and postprocess_batch_list hooks
            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            # Helper that builds the infotext for the image at a given index
            def infotext(index=0, use_main_prompt=False):
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, 
                                      use_main_prompt=use_main_prompt, index=index, 
                                      all_negative_prompts=p.negative_prompts)

            # Whether samples should be saved to disk
            save_samples = p.save_samples()

            # Process each image in the batch
            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                # Face restoration (optionally saving the image before restoration)
                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", 
                                        p.seeds[i], p.prompts[i], opts.samples_format, 
                                        info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                # Scripts: postprocess_image hook
                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                mask_for_overlay = getattr(p, "mask_for_overlay", None)

                # Choose the inpainting overlay image (only if overlay_inpaint is enabled)
                if not shared.opts.overlay_inpaint:
                    overlay_image = None
                elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
                    overlay_image = p.overlay_images[i]
                else:
                    overlay_image = None

                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                # Color correction (optionally saving the image before correction)
                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
                        images.save_image(image_without_cc, p.outpath_samples, "", 
                                         p.seeds[i], p.prompts[i], opts.samples_format, 
                                         info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                # Composite the overlay and paste the result back into place
                image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)

                # Scripts: postprocess_image_after_composite hook
                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image_after_composite(p, pp)
                    image = pp.image

                # Save the image
                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], 
                                     p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                # Record the infotext and embed it as PNG info
                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                # Optionally save or return the mask and the mask composite
                if mask_for_overlay is not None:
                    if opts.return_mask or opts.save_mask:
                        image_mask = mask_for_overlay.convert('RGB')
                        if save_samples and opts.save_mask:
                            images.save_image(image_mask, p.outpath_samples, "", 
                                             p.seeds[i], p.prompts[i], opts.samples_format, 
                                             info=infotext(i), p=p, suffix="-mask")
                        if opts.return_mask:
                            output_images.append(image_mask)

                    if opts.return_mask_composite or opts.save_mask_composite:
                        image_mask_composite = Image.composite(
                            original_denoised_image.convert('RGBA').convert('RGBa'),
                            Image.new('RGBa', image.size),
                            images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')
                        ).convert('RGBA')
                        
                        if save_samples and opts.save_mask_composite:
                            images.save_image(image_mask_composite, p.outpath_samples, "", 
                                             p.seeds[i], p.prompts[i], opts.samples_format, 
                                             info=infotext(i), p=p, suffix="-mask-composite")
                        if opts.return_mask_composite:
                            output_images.append(image_mask_composite)

            # Free memory
            del x_samples_ddim
            devices.torch_gc()

        # Build the grid image if requested
        if not infotexts:
            infotexts.append(Processed(p, []).infotext(p, 0))

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], 
                                 p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), 
                                 short_filename=not opts.grid_extended_filename, p=p, grid=True)

    # Deactivate extra networks
    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    # Build the result object
    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotexts[0],
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    # Scripts: final postprocess hook
    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
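
Two small pieces of arithmetic in the loop above are easy to check in isolation. The first is the seed list built from a single base seed; the second is the per-iteration slicing of the flattened prompt and seed lists. The sketch reproduces both using hypothetical helper names. It assumes, as the slicing implies, that the flattened lists hold batch_size * n_iter entries.

```python
def assign_seeds(seed: int, n_images: int, subseed_strength: float) -> list[int]:
    # Same expression as p.all_seeds in process_images_inner: consecutive seeds,
    # unless a variation subseed is active, in which case all images share the seed.
    return [int(seed) + (x if subseed_strength == 0 else 0) for x in range(n_images)]


def batch_slices(items: list, batch_size: int, n_iter: int) -> list[list]:
    # Same slicing as p.prompts = p.all_prompts[n * batch_size:(n + 1) * batch_size]
    return [items[n * batch_size:(n + 1) * batch_size] for n in range(n_iter)]


batch_size, n_iter = 3, 2
seeds = assign_seeds(1000, batch_size * n_iter, subseed_strength=0)
print(seeds)                                        # [1000, 1001, 1002, 1003, 1004, 1005]
print(batch_slices(seeds, batch_size, n_iter))      # [[1000, 1001, 1002], [1003, 1004, 1005]]
print(assign_seeds(1000, 3, subseed_strength=0.5))  # [1000, 1000, 1000]
```

With a nonzero subseed_strength, the images in a run share one main seed. Any variation between them presumably comes from p.all_subseeds, whose construction the excerpt omits.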

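At the end of the batch loop, each image goes through a fixed conversion. The decoder output is mapped from [-1, 1] to [0, 1] and clamped, the channel axis is moved last, and the values are scaled to 0-255 and cast to uint8. The sketch below traces these steps on a tiny input. It uses plain nested lists instead of torch or NumPy so the arithmetic stays visible. decoded_to_uint8_hwc is a hypothetical name.

```python
def decoded_to_uint8_hwc(chw: list[list[list[float]]]) -> list[list[list[int]]]:
    """Hypothetical helper: CHW floats (roughly [-1, 1]) -> HWC ints in 0..255."""
    channels, height, width = len(chw), len(chw[0]), len(chw[0][0])
    hwc = [[[0] * channels for _ in range(width)] for _ in range(height)]
    for c in range(channels):
        for y in range(height):
            for x in range(width):
                # torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
                v = min(max((chw[c][y][x] + 1.0) / 2.0, 0.0), 1.0)
                # 255. * v then .astype(np.uint8) truncates toward zero; writing to
                # hwc[y][x][c] plays the role of np.moveaxis(..., 0, 2)
                hwc[y][x][c] = int(255.0 * v)
    return hwc


# One row, two pixels, three channels; includes out-of-range values to show clamping.
chw = [
    [[-1.0, 0.0]],   # R
    [[1.0, 2.0]],    # G (2.0 is clamped to 1.0)
    [[-3.0, 0.5]],   # B (-3.0 is clamped to 0.0)
]
print(decoded_to_uint8_hwc(chw))  # [[[0, 255, 0], [127, 255, 191]]]
```
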
4. Text-to-Image Flow in Detail

4.1 Initialization

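Text-to-image initialization is implemented in the [StableDiffusionProcessingTxt2Img.init](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1315-L1420) method. Almost all of its work applies only when hires fix (enable_hr) is enabled: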

def init(self, all_prompts, all_seeds, all_subseeds):
    if self.enable_hr:
        # Hires fix parameters
        self.extra_generation_params["Denoising strength"] = self.denoising_strength

        # Hires fix checkpoint (if different from the main one)
        if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
            self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

            if self.hr_checkpoint_info is None:
                raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

            self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

        # Hires sampler (recorded only if it differs from the main sampler)
        if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
            self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

        # Hires prompts (written to the infotext only when they differ from the main prompts)
        def get_hr_prompt(p, index, prompt_text, **kwargs):
            hr_prompt = p.all_hr_prompts[index]
            return hr_prompt if hr_prompt != prompt_text else None

        def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):
            hr_negative_prompt = p.all_hr_negative_prompts[index]
            return hr_negative_prompt if hr_negative_prompt != negative_prompt else None

        self.extra_generation_params["Hires prompt"] = get_hr_prompt
        self.extra_generation_params["Hires negative prompt"] = get_hr_negative_prompt

        # Other hires fix parameters
        self.extra_generation_params["Hires schedule type"] = None

        if self.hr_scheduler is None:
            self.hr_scheduler = self.scheduler

        # Latent upscale mode (None when hr_upscaler is not a latent upscaler)
        self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
        
        # Compute the target resolution
        self.calculate_target_resolution()

        # Update the progress bar totals (each iteration now has two passes)
        if not state.processing_has_refined_job_count:
            if state.job_count == -1:
                state.job_count = self.n_iter
            if getattr(self, 'txt2img_upscale', False):
                total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
            else:
                total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
            shared.total_tqdm.updateTotal(total_steps)
            state.job_count = state.job_count * 2
            state.processing_has_refined_job_count = True

        # Record hires steps and upscaler
        if self.hr_second_pass_steps:
            self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

        if self.hr_upscaler is not None:
            self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
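
The progress-bar arithmetic in init() can be checked with a small standalone function. hires_progress_totals is a hypothetical helper that only reproduces the expressions above. The reading of txt2img_upscale, where only the hires pass runs because a first-pass image already exists, is an assumption inferred from the firstpass_image branch in sample.

```python
def hires_progress_totals(steps: int, hr_second_pass_steps: int, job_count: int,
                          n_iter: int, txt2img_upscale: bool = False) -> tuple[int, int]:
    """Hypothetical helper reproducing the progress-bar arithmetic in init()."""
    if job_count == -1:
        job_count = n_iter
    # hr_second_pass_steps == 0 means "reuse the first-pass step count".
    second_pass = hr_second_pass_steps or steps
    if txt2img_upscale:
        # Assumed: only the hires pass runs, since the first-pass image exists.
        total_steps = second_pass * job_count
    else:
        total_steps = (steps + second_pass) * job_count
    # Each iteration now counts as two jobs: first pass + hires pass.
    return total_steps, job_count * 2


print(hires_progress_totals(20, 0, -1, n_iter=1))   # (40, 2)
print(hires_progress_totals(20, 10, -1, n_iter=2))  # (60, 4)
```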

4.2 Sampling

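Text-to-image sampling is implemented in the [StableDiffusionProcessingTxt2Img.sample](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1422-L1503) method. It runs the first pass and, when hires fix is enabled, hands the result to the second pass: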

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    # Create the sampler
    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

    if self.firstpass_image is not None and self.enable_hr:
        # A first-pass image already exists: skip the first pass and prepare it for hires fix
        if self.latent_scale_mode is None:
            # Pixel-space upscaler: use the image directly, scaled to [-1, 1]
            image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
            image = np.moveaxis(image, 2, 0)

            samples = None
            decoded_samples = torch.asarray(np.expand_dims(image, 0))
        else:
            # Latent upscaler: encode the image (scaled to [0, 1]) into latent space
            image = np.array(self.firstpass_image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            image = torch.from_numpy(np.expand_dims(image, axis=0))
            image = image.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

            samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
            decoded_samples = None
            devices.torch_gc()
    else:
        # Normal generation: sample the first pass from noise
        x = self.rng.next()
        if self.scripts is not None:
            self.scripts.process_before_every_sampling(
                p=self,
                x=x,
                noise=x,
                c=conditioning,
                uc=unconditional_conditioning
            )

        # Run sampling
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, 
                                     image_conditioning=self.txt2img_image_conditioning(x))
        del x

        # If hires fix is disabled, return the samples directly
        if not self.enable_hr:
            return samples

        devices.torch_gc()

        # Decode the samples only if a pixel-space upscaler will be used
        if self.latent_scale_mode is None:
            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, 
                                                            target_device=devices.cpu, 
                                                            check_for_nans=True)).to(dtype=torch.float32)
        else:
            decoded_samples = None

    # Load the hires fix checkpoint (without writing it to the config)
    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=self.hr_checkpoint_info)

    # Run the hires fix pass
    return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
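
Once the firstpass_image branch and the checkpoint reload are set aside, the control flow of sample() reduces to a small skeleton. In the sketch below, the function and callback names are hypothetical, and plain strings stand in for tensors. The point is which step runs when: the early return without hires fix, and decoding only for pixel-space upscalers.

```python
from typing import Any, Callable, Optional


def txt2img_sample(noise: Any, *, enable_hr: bool, latent_scale_mode: Optional[dict],
                   first_pass: Callable[[Any], Any], decode: Callable[[Any], Any],
                   hr_pass: Callable[[Any, Optional[Any]], Any]) -> Any:
    """Hypothetical skeleton of the branching in StableDiffusionProcessingTxt2Img.sample."""
    samples = first_pass(noise)
    if not enable_hr:
        # Plain txt2img: return latents; process_images_inner decodes them later.
        return samples
    # A pixel-space upscaler needs decoded images; a latent upscaler
    # interpolates the latent samples directly, so nothing is decoded here.
    decoded = decode(samples) if latent_scale_mode is None else None
    return hr_pass(samples, decoded)


calls: list[str] = []


def fake_first_pass(noise):
    calls.append("first_pass")
    return "latents"


def fake_decode(samples):
    calls.append("decode")
    return "pixels"


def fake_hr_pass(samples, decoded):
    calls.append("hr_pass")
    return (samples, decoded)


result = txt2img_sample("noise", enable_hr=True, latent_scale_mode=None,
                        first_pass=fake_first_pass, decode=fake_decode, hr_pass=fake_hr_pass)
print(result)  # ('latents', 'pixels')
print(calls)   # ['first_pass', 'decode', 'hr_pass']
```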

4.3 The Hires Fix Pass

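The hires fix pass is implemented in the [sample_hr_pass](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1505-L1617) method. It upscales the first-pass result, either in latent space or in pixel space, and then runs an img2img-style second sampling pass on it: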

def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
    if shared.state.interrupted:
        return samples

    self.is_hr_pass = True
    target_width = self.hr_upscale_to_x
    target_height = self.hr_upscale_to_y

    def save_intermediate(image, index):
        """Save the image as it was before hires fix."""
        if not self.save_samples() or not opts.save_images_before_highres_fix:
            return

        if not isinstance(image, Image.Image):
            image = sd_samplers.sample_to_image(image, index, approximation=0)

        info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], 
                              iteration=self.iteration, position_in_batch=index)
        images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], 
                         opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

    # Create the hires sampler
    img2img_sampler_name = self.hr_sampler_name or self.sampler_name
    self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

    if self.latent_scale_mode is not None:
        # Latent upscaling: interpolate the latents directly
        for i in range(samples.shape[0]):
            save_intermediate(samples, i)

        samples = torch.nn.functional.interpolate(samples, 
                                                size=(target_height // opt_f, target_width // opt_f), 
                                                mode=self.latent_scale_mode["mode"], 
                                                antialias=self.latent_scale_mode["antialias"])

        # Image conditioning
        if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
            image_conditioning = self.img2img_image_conditioning(
                decode_first_stage(self.sd_model, samples), samples)
        else:
            image_conditioning = self.txt2img_image_conditioning(samples)
    else:
        # Pixel-space upscaling: upscale the decoded images, then re-encode them
        lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

        batch_images = []
        for i, x_sample in enumerate(lowres_samples):
            x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
            x_sample = x_sample.astype(np.uint8)
            image = Image.fromarray(x_sample)

            save_intermediate(image, i)

            # Upscale the image
            image = images.resize_image(0, image, target_width, target_height, 
                                      upscaler_name=self.hr_upscaler)
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            batch_images.append(image)

        decoded_samples = torch.from_numpy(np.array(batch_images))
        decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
        samples = images_tensor_to_samples(decoded_samples, 
                                          approximation_indexes.get(opts.sd_vae_encode_method))

        image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

    shared.state.nextjob()

    # Crop the latents to the target size
    samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, 
                     self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

    # Create noise for the second pass
    self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, 
                           subseed_strength=self.subseed_strength, 
                           seed_resize_from_h=self.seed_resize_from_h, 
                           seed_resize_from_w=self.seed_resize_from_w)
    noise = self.rng.next()

    # Activate extra networks for the hires prompts
    if not self.disable_extra_networks:
        with devices.autocast():
            extra_networks.activate(self, self.hr_extra_network_data)

    # Compute hires conditioning
    with devices.autocast():
        self.calculate_hr_conds()

    # Apply the hires Token Merging ratio
    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

    # Scripts: before_hr and process_before_every_sampling hooks
    if self.scripts is not None:
        self.scripts.before_hr(self)
        self.scripts.process_before_every_sampling(
            p=self,
            x=samples,
            noise=noise,
            c=self.hr_c,
            uc=self.hr_uc,
        )

    # Run hires sampling (img2img on the upscaled latents)
    samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, 
                                         steps=self.hr_second_pass_steps or self.steps, 
                                         image_conditioning=image_conditioning)

    # Restore the regular Token Merging ratio
    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

    self.sampler = None
    devices.torch_gc()

    # Decode the final samples
    decoded_samples = decode_latent_batch(self.sd_model, samples, 
                                         target_device=devices.cpu, 
                                         check_for_nans=True)

    self.is_hr_pass = False
    return decoded_samples
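
In the latent branch, torch.nn.functional.interpolate resizes the latents to (target_height // opt_f, target_width // opt_f), using the mode stored in latent_scale_mode. In "nearest" mode, each output cell copies the input cell at floor(out_index * in_size / out_size). The pure-Python sketch below reproduces that index rule on a single 2-D channel. Two things are assumptions: the helper name, and opt_f = 8 (the usual SD latent downsampling factor). Torch computes the scale in floating point, so treat this as an approximation of its behaviour, not a bit-exact copy.

```python
def nearest_resize(grid: list[list[float]], out_h: int, out_w: int) -> list[list[float]]:
    """Hypothetical helper: nearest-neighbour resize of one 2-D channel."""
    in_h, in_w = len(grid), len(grid[0])
    return [
        [grid[(y * in_h) // out_h][(x * in_w) // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]


opt_f = 8                        # assumed latent downsampling factor
target_width, target_height = 32, 32
latent = [[1.0, 2.0],
          [3.0, 4.0]]            # a 16x16-pixel image would have a 2x2 latent
up = nearest_resize(latent, target_height // opt_f, target_width // opt_f)
for row in up:
    print(row)
# [1.0, 1.0, 2.0, 2.0]
# [1.0, 1.0, 2.0, 2.0]
# [3.0, 3.0, 4.0, 4.0]
# [3.0, 3.0, 4.0, 4.0]
```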

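After upscaling, sample_hr_pass trims truncate_y latent rows and truncate_x latent columns, splitting each between the two opposite edges. This excerpt does not show where those values are computed (presumably in calculate_target_resolution). The hypothetical helper below only reproduces the slice arithmetic.

```python
def truncate_bounds(size: int, truncate: int) -> tuple[int, int]:
    """Slice bounds of samples[..., t // 2 : size - (t + 1) // 2] (hypothetical helper)."""
    start = truncate // 2
    end = size - (truncate + 1) // 2
    return start, end


# t // 2 + (t + 1) // 2 == t for any t >= 0, so exactly `truncate` rows or
# columns are removed; when t is odd, the extra one comes off the bottom/right edge.
print(truncate_bounds(100, 3))   # (1, 98): 97 rows remain
print(truncate_bounds(128, 0))   # (0, 128): nothing removed
```
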
5. Image-to-Image Flow in Detail

5.1 Initialization

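Image-to-image initialization is implemented in the [StableDiffusionProcessingImg2Img.init](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1650-L1769) method. It prepares the mask, the init images, and the initial latent: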

def init(self, all_prompts, all_seeds, all_subseeds):
    # Record the denoising strength
    self.extra_generation_params["Denoising strength"] = self.denoising_strength

    # Image CFG scale (kept only for models whose cond_stage_key is "edit")
    self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

    # Create the sampler
    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
    crop_region = None

    image_mask = self.image_mask

    # Process the mask
    if image_mask is not None:
        # Convert the (possibly RGBA) mask into a binary mask
        image_mask = create_binary_mask(image_mask, round=self.mask_round)

        # Invert the mask ("Inpaint not masked")
        if self.inpainting_mask_invert:
            image_mask = ImageOps.invert(image_mask)
            self.extra_generation_params["Mask mode"] = "Inpaint not masked"

        # Apply Gaussian blur to the mask, horizontally and then vertically
        if self.mask_blur_x > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
            image_mask = Image.fromarray(np_mask)

        if self.mask_blur_y > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
            image_mask = Image.fromarray(np_mask)

        # Inpaint area: "Only masked" (inpaint_full_res) crops around the mask; otherwise the whole picture is used
        if self.inpaint_full_res:
            self.mask_for_overlay = image_mask
            mask = image_mask.convert('L')
            crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding)
            if crop_region:
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, 
                                                       mask.width, mask.height)
                x1, y1, x2, y2 = crop_region
                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
                self.extra_generation_params["Inpaint area"] = "Only masked"
                self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
            else:
                # The mask is blank: fall back to plain img2img
                crop_region = None
                image_mask = None
                self.mask_for_overlay = None
                self.inpaint_full_res = False
                massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
                model_hijack.comments.append(massage)
                logging.info(massage)
        else:
            image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
            np_mask = np.array(image_mask)
            np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
            self.mask_for_overlay = Image.fromarray(np_mask)

        self.overlay_images = []

    # Latent mask (defaults to the image mask)
    latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

    # Set up color correction
    add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
    if add_color_corrections:
        self.color_corrections = []
        
    imgs = []
    for img in self.init_images:
        # Optionally save the initial image
        if opts.save_init_img:
            self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
            images.save_image(img, path=opts.outdir_init_images, basename=None, 
                             forced_filename=self.init_img_hash, save_to_dirs=False, 
                             existing_info=img.info)

        # Flatten transparency onto the background color
        image = images.flatten(img, opts.img2img_background_color)

        # Resize the image
        if crop_region is None and self.resize_mode != 3:
            image = images.resize_image(self.resize_mode, image, self.width, self.height)

        # Build the masked overlay image
        if image_mask is not None:
            if self.mask_for_overlay.size != (image.width, image.height):
                self.mask_for_overlay = images.resize_image(self.resize_mode, 
                                                          self.mask_for_overlay, 
                                                          image.width, image.height)
            image_masked = Image.new('RGBa', (image.width, image.height))
            image_masked.paste(image.convert("RGBA").convert("RGBa"), 
                              mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

            self.overlay_images.append(image_masked.convert('RGBA'))

        # Crop to the masked region
        if crop_region is not None:
            image = image.crop(crop_region)
            image = images.resize_image(2, image, self.width, self.height)

        # Fill the masked area
        if image_mask is not None:
            if self.inpainting_fill != 1:
                image = masking.fill(image, latent_mask)

                if self.inpainting_fill == 0:
                    self.extra_generation_params["Masked content"] = 'fill'

        # Color correction
        if add_color_corrections:
            self.color_corrections.append(setup_color_correction(image))

        # Convert to a float32 CHW array in [0, 1]
        image = np.array(image).astype(np.float32) / 255.0
        image = np.moveaxis(image, 2, 0)

        imgs.append(image)

    # Assemble the batch
    if len(imgs) == 1:
        batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
        if self.overlay_images is not None:
            self.overlay_images = self.overlay_images * self.batch_size

        if self.color_corrections is not None and len(self.color_corrections) == 1:
            self.color_corrections = self.color_corrections * self.batch_size

    elif len(imgs) <= self.batch_size:
        self.batch_size = len(imgs)
        batch_images = np.array(imgs)
    else:
        raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

    image = torch.from_numpy(batch_images)
    image = image.to(shared.device, dtype=devices.dtype_vae)

    # Encode into latent space
    if opts.sd_vae_encode_method != 'Full':
        self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

    self.init_latent = images_tensor_to_samples(image, 
                                               approximation_indexes.get(opts.sd_vae_encode_method), 
                                               self.sd_model)
    devices.torch_gc()

    # Resize mode 3: resize in latent space
    if self.resize_mode == 3:
        self.init_latent = torch.nn.functional.interpolate(self.init_latent, 
                                                         size=(self.height // opt_f, self.width // opt_f), 
                                                         mode="bilinear")

    # Build the latent masks
    if image_mask is not None:
        init_mask = latent_mask
        latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], 
                                                  self.init_latent.shape[2]))
        latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
        latmask = latmask[0]
        if self.mask_round:
            latmask = np.around(latmask)
        latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))

        self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
        self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)

        # Fill the masked content
        if self.inpainting_fill == 2:
            self.init_latent = self.init_latent * self.mask + create_random_tensors(
                self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]
            ) * self.nmask
            self.extra_generation_params["Masked content"] = 'latent noise'

        elif self.inpainting_fill == 3:
            self.init_latent = self.init_latent * self.mask
            self.extra_generation_params["Masked content"] = 'latent nothing'

    # Set up image conditioning
    self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, 
                                                            image_mask, self.mask_round)
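
The latent-mask step above can be illustrated with a small NumPy sketch. It assumes a VAE downsampling factor of 8 (the opt_f in the code) and a 4-channel latent, and it substitutes simple average pooling for the PIL resize the original uses. build_latent_masks and fill_masked_latent are hypothetical helper names, not WebUI functions; modes 2 and 3 mirror the "latent noise" and "latent nothing" branches.

```python
import numpy as np

OPT_F = 8  # assumed VAE downsampling factor (opt_f in the WebUI code)


def downsample_mask(mask: np.ndarray, factor: int = OPT_F) -> np.ndarray:
    """Average-pool an (H, W) mask; H and W must be divisible by factor.

    Stand-in for the PIL resize used by WebUI.
    """
    h, w = mask.shape
    return mask.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))


def build_latent_masks(pixel_mask: np.ndarray, channels: int = 4, round_mask: bool = True):
    """Return (mask, nmask): mask is 1 where the original latent is kept, nmask is 1 where it is repainted."""
    latmask = downsample_mask(pixel_mask.astype(np.float32) / 255.0)
    if round_mask:
        latmask = np.around(latmask)  # hard 0/1 edges, like mask_round
    latmask = np.tile(latmask[None], (channels, 1, 1))  # one copy per latent channel
    return 1.0 - latmask, latmask


def fill_masked_latent(init_latent, mask, nmask, mode, rng):
    """mode 2 ~ 'latent noise', mode 3 ~ 'latent nothing'; other modes return the latent unchanged."""
    if mode == 2:
        return init_latent * mask + rng.standard_normal(init_latent.shape) * nmask
    if mode == 3:
        return init_latent * mask
    return init_latent


# Usage: a 64x64 mask whose right half (255) should be repainted.
pixel_mask = np.zeros((64, 64), dtype=np.uint8)
pixel_mask[:, 32:] = 255
mask, nmask = build_latent_masks(pixel_mask)
init_latent = np.ones((1, 4, 8, 8), dtype=np.float32)
filled = fill_masked_latent(init_latent, mask, nmask, 3, np.random.default_rng(0))
```

Because mask and nmask are complementary, every latent position is either preserved or regenerated; rounding keeps that split strictly binary.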

5.2 Sampling Process

The img2img sampling step is implemented in the [StableDiffusionProcessingImg2Img.sample](file:///E:/project/stable-diffusion-webui/modules/processing.py#L1771-L1791) method:

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    # Generate noise
    x = self.rng.next()

    # Apply the noise multiplier
    if self.initial_noise_multiplier != 1.0:
        self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
        x *= self.initial_noise_multiplier

    # Run the scripts' pre-sampling hooks
    if self.scripts is not None:
        self.scripts.process_before_every_sampling(
            p=self,
            x=self.init_latent,
            noise=x,
            c=conditioning,
            uc=unconditional_conditioning
        )
    
    # Run img2img sampling
    samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, 
                                         unconditional_conditioning, 
                                         image_conditioning=self.image_conditioning)

    # Blend with the original latent using the mask
    if self.mask is not None:
        blended_samples = samples * self.nmask + self.init_latent * self.mask

        if self.scripts is not None:
            mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, 
                                       self.mask, blended_samples)
            self.scripts.on_mask_blend(self, mba)
            blended_samples = mba.blended_latent

        samples = blended_samples

    # Free memory
    del x
    devices.torch_gc()

    return samples
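
A NumPy sketch of the two numeric steps in sample(): scaling the initial noise and blending the sampled latent back into the original outside the mask. It assumes NumPy's default_rng as a stand-in for WebUI's self.rng; initial_noise and blend_with_init are hypothetical names, and the sampler itself is replaced by a constant array.

```python
import numpy as np


def initial_noise(shape, seed, initial_noise_multiplier=1.0):
    """Seeded Gaussian noise, scaled only when the multiplier differs from 1.0."""
    x = np.random.default_rng(seed).standard_normal(shape)
    if initial_noise_multiplier != 1.0:
        x *= initial_noise_multiplier
    return x


def blend_with_init(samples, init_latent, mask, nmask):
    """Keep sampler output where nmask == 1 and restore the original latent where mask == 1."""
    return samples * nmask + init_latent * mask


# Usage: repaint only the right half of an 8x8, 4-channel latent.
nmask = np.zeros((4, 8, 8))
nmask[:, :, 4:] = 1.0
mask = 1.0 - nmask
init_latent = np.zeros((1, 4, 8, 8))
samples = np.full((1, 4, 8, 8), 5.0)  # stand-in for the sampler's output
blended = blend_with_init(samples, init_latent, mask, nmask)
```

Blending in latent space keeps the unmasked region identical to the encoded input; in WebUI the on_mask_blend script hook can still replace blended_latent before it is used.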

6. Key Technical Points

6.1 Conditioning

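Stable Diffusion uses text conditioning to guide image generation. In setup_conds, WebUI wraps the positive and negative prompts, determines how many sampler steps the prompt schedule must cover, and encodes both prompts through a cache: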

def setup_conds(self):
    # Wrap the positive and negative prompts
    prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
    negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, 
                                                  height=self.height, is_negative_prompt=True)

    # Look up the sampler configuration
    sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
    total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
    self.step_multiplier = total_steps // self.steps
    self.firstpass_steps = total_steps

    # Get the unconditional and conditional encodings
    self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, 
                                         negative_prompts, total_steps, [self.cached_uc], 
                                         self.extra_network_data)
    self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, 
                                        prompts, total_steps, [self.cached_c], 
                                        self.extra_network_data)
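
Conditioning depends on total_steps because of prompt editing: syntax such as [fat:thin:0.5] switches text partway through sampling, so the encoder needs the step count to know where each segment ends. The sketch below is a heavily simplified, hypothetical stand-in for prompt_parser's schedule builder (one edit, no nesting, no alternation), and the rule that second-order samplers double total_steps is an assumption about how sampler configs report their step count.

```python
import re

# Matches a single [from:to:when] prompt-editing span (no nesting).
EDIT = re.compile(r"\[([^\[\]:]*):([^\[\]:]*):([0-9.]+)\]")


def prompt_schedule(prompt: str, steps: int):
    """Return [(end_step, text), ...] for a prompt with at most one [from:to:when] edit."""
    m = EDIT.search(prompt)
    if m is None:
        return [(steps, prompt)]
    before, after, when = m.group(1), m.group(2), float(m.group(3))
    switch = int(when * steps) if when < 1 else int(when)  # fraction or absolute step
    head, tail = prompt[:m.start()], prompt[m.end():]
    if switch <= 0:
        return [(steps, head + after + tail)]
    if switch >= steps:
        return [(steps, head + before + tail)]
    return [(switch, head + before + tail), (steps, head + after + tail)]


def total_steps(steps: int, second_order: bool) -> int:
    """Assumed rule: second-order samplers schedule twice as many steps."""
    return steps * 2 if second_order else steps


steps = 20
sched_steps = total_steps(steps, second_order=True)
step_multiplier = sched_steps // steps
schedule = prompt_schedule("a [fat:thin:0.5] cat", sched_steps)
# schedule == [(20, "a fat cat"), (40, "a thin cat")]
```

Each schedule entry is encoded once, and the sampler picks the entry whose end_step has not yet been reached; this is why the cache key in the next section includes the step count.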

6.2 Caching Mechanism

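To avoid re-encoding identical prompts, WebUI caches conditioning: if the prompts, step count, extra network data and scheduling mode match a previous call, the cached result is returned instead of running the text encoder again: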

def get_conds_with_caching(self, function, required_prompts, steps, caches, 
                          extra_network_data, hires_steps=None):
    """
    Cache results: if the same parameters were used before, return the cached result.
    """
    if shared.opts.use_old_scheduling:
        old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(
            required_prompts, steps, hires_steps, False)
        new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(
            required_prompts, steps, hires_steps, True)
        if old_schedules != new_schedules:
            self.extra_generation_params["Old prompt editing timelines"] = True

    # Build the cache key
    cached_params = self.cached_params(required_prompts, steps, extra_network_data, 
                                      hires_steps, shared.opts.use_old_scheduling)

    # Check the caches
    for cache in caches:
        if cache[0] is not None and cached_params == cache[0]:
            return cache[1]

    cache = caches[0]

    # Compute the conditioning and cache the result
    with devices.autocast():
        cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, 
                           shared.opts.use_old_scheduling)

    cache[0] = cached_params
    return cache[1]
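
Each cache is a two-slot list, [key, value], that is checked before calling the expensive encoder and overwritten on a miss. A minimal sketch of the same pattern follows; the tuple key is a simplified, hypothetical stand-in for cached_params (which also folds in extra network data and the scheduling mode), and encode is a dummy encoder that records real calls.

```python
from typing import Any, Callable, List


def get_with_caching(function: Callable[[tuple, int], Any], prompts: tuple, steps: int,
                     caches: List[list], extra_key: Any = None) -> Any:
    """Return a cached result when the key matches any slot; otherwise compute into caches[0]."""
    key = (prompts, steps, extra_key)  # simplified stand-in for cached_params(...)
    for cache in caches:
        if cache[0] is not None and cache[0] == key:
            return cache[1]
    cache = caches[0]
    cache[1] = function(prompts, steps)
    cache[0] = key
    return cache[1]


calls = []


def encode(prompts, steps):
    """Hypothetical encoder that records each real (uncached) call."""
    calls.append(prompts)
    return [f"cond({p})@{steps}" for p in prompts]


slot = [None, None]  # [key, value], like StableDiffusionProcessing.cached_c
first = get_with_caching(encode, ("a cat",), 20, [slot])
second = get_with_caching(encode, ("a cat",), 20, [slot])  # hit: no new call
third = get_with_caching(encode, ("a dog",), 20, [slot])   # miss: overwrites the slot
```

Because each slot holds only one entry, alternating between two different prompts re-encodes every time; the saving comes from repeated runs with unchanged prompts and settings.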

6.3 Memory Management

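WebUI uses several memory-management strategies to accommodate different hardware: dropping references when a job is closed, offloading models to the CPU in low-VRAM mode, and explicitly freeing GPU memory: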

def close(self):
    """Release resources held during processing"""
    self.sampler = None
    self.c = None
    self.uc = None
    if not opts.persistent_cond_cache:
        StableDiffusionProcessing.cached_c = [None, None]
        StableDiffusionProcessing.cached_uc = [None, None]

# In low-VRAM mode, move everything to the CPU
if lowvram.is_enabled(shared.sd_model):
    lowvram.send_everything_to_cpu()

# Free GPU memory
devices.torch_gc()
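
Because close() resets cached_c and cached_uc on StableDiffusionProcessing itself rather than on the instance, conditioning cached by one job stays visible to the next until that reset happens. A minimal sketch of this class-attribute pattern, using a hypothetical Processing class and a constructor flag in place of opts.persistent_cond_cache:

```python
class Processing:
    """Hypothetical stand-in for StableDiffusionProcessing's cache handling."""

    # Class attributes: every instance sees the same lists until close() rebinds them.
    cached_c = [None, None]
    cached_uc = [None, None]

    def __init__(self, persistent_cond_cache: bool):
        self.persistent_cond_cache = persistent_cond_cache  # stands in for opts.persistent_cond_cache
        self.c = None
        self.uc = None

    def close(self):
        self.c = None
        self.uc = None
        if not self.persistent_cond_cache:
            # Rebind the shared lists so the old conditioning can be garbage-collected.
            Processing.cached_c = [None, None]
            Processing.cached_uc = [None, None]


p1 = Processing(persistent_cond_cache=True)
p1.cached_c[0], p1.cached_c[1] = "key", "cond"  # mutates the shared class-level list
p1.close()
survived = Processing(persistent_cond_cache=True).cached_c[1]  # "cond": kept across jobs

p2 = Processing(persistent_cond_cache=False)
p2.close()
cleared = Processing.cached_c  # [None, None]
```

Keeping the cache persistent saves re-encoding between jobs at the cost of holding the conditioning tensors in memory; disabling it frees them as soon as the job closes.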

7. Performance Optimization Strategies

7.1 Token Merging

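Token Merging (ToMe) is an effective performance optimization: it merges redundant tokens inside the UNet's transformer blocks so attention processes fewer tokens, trading some fine detail for speed. WebUI applies it at the configured ratio and resets it to 0 afterwards: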

# Apply Token Merging
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

# Reset once hires fix has finished
sd_models.apply_token_merging(p.sd_model, 0)
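
To show what "merging tokens" means, here is a simplified bipartite soft matching in NumPy: tokens are split alternately into two sets, each token in the first set finds its most similar partner (cosine similarity) in the second, and the r most similar pairs are averaged together. This is not the tomesd implementation that WebUI's apply_token_merging builds on, which among other differences partitions tokens spatially and unmerges them after attention; merge_tokens is a hypothetical name, and mapping the ratio to r = int(N * ratio) is an assumption.

```python
import numpy as np


def merge_tokens(x: np.ndarray, r: int) -> np.ndarray:
    """Merge r tokens of an (N, D) array with simplified bipartite soft matching."""
    a, b = x[0::2], x[1::2]  # alternate tokens into source set A and destination set B
    r = min(r, len(a))
    if r <= 0:
        return x.copy()
    an = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-8)
    bn = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-8)
    scores = an @ bn.T                   # cosine similarity, shape (|A|, |B|)
    best_b = scores.argmax(axis=-1)      # most similar B token for each A token
    order = np.argsort(-scores.max(axis=-1))
    merged, kept = order[:r], np.sort(order[r:])
    b = b.astype(np.float64)             # copy so the input is not modified
    counts = np.ones(len(b))
    for i in merged:                     # fold each merged A token into its B partner
        j = best_b[i]
        b[j] = (b[j] * counts[j] + a[i]) / (counts[j] + 1)  # running mean
        counts[j] += 1
    return np.concatenate([a[kept], b], axis=0)


ratio = 0.25
tokens = np.random.default_rng(0).standard_normal((16, 8))
merged_tokens = merge_tokens(tokens, int(len(tokens) * ratio))  # 16 -> 12 tokens
```

Since attention cost grows quadratically with token count, removing even a quarter of the tokens noticeably reduces work, which is why a higher ratio is faster but blurs more detail.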

7.2 Batching Optimization

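WebUI improves throughput by batching: a run is split into n_iter iterations, each generating batch_size images, with prompts and seeds sliced per iteration: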

# Set the job count
if state.job_count == -1:
    state.job_count = p.n_iter

# Iterate over batches
for n in range(p.n_iter):
    # Slice out this batch's prompts and seeds
    p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
    p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
    p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
    p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
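
The slicing arithmetic can be checked in isolation. This sketch builds a flat seed list and cuts it into per-iteration batches the same way the loop above slices all_seeds. The rule that each image receives seed + index when subseed_strength is 0 is stated here as an assumption, and both helper names are hypothetical.

```python
def assign_seeds(seed: int, count: int, subseed_strength: float = 0.0) -> list:
    """Assumed rule: consecutive seeds per image unless subseed variation is in use."""
    return [seed + (i if subseed_strength == 0 else 0) for i in range(count)]


def split_into_batches(items: list, n_iter: int, batch_size: int) -> list:
    """Slice a flat per-image list into n_iter consecutive batches, as the loop above does."""
    return [items[n * batch_size:(n + 1) * batch_size] for n in range(n_iter)]


n_iter, batch_size = 2, 3
all_seeds = assign_seeds(100, n_iter * batch_size)
batches = split_into_batches(all_seeds, n_iter, batch_size)
# batches == [[100, 101, 102], [103, 104, 105]]
```

Since state.job_count is set to n_iter, progress is tracked per iteration, and each iteration produces batch_size images in a single sampler call.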

8. Error Handling and Recovery

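WebUI wraps the main processing call so that, whether or not an error occurs, Token Merging is reset and any overridden settings are restored: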

try:
    # Main processing logic
    res = process_images_inner(p)
except Exception as e:
    # Report the error, then re-raise so the caller never sees an unassigned res
    errors.report(f"Error processing images: {e}", exc_info=True)
    raise
finally:
    # Clean up and restore state whether or not processing succeeded
    sd_models.apply_token_merging(p.sd_model, 0)

    if p.override_settings_restore_afterwards:
        for k, v in stored_opts.items():
            setattr(opts, k, v)
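
The key property of the finally clause is that temporary setting overrides cannot leak into later jobs, even when generation fails. A minimal sketch of that pattern, with a SimpleNamespace standing in for WebUI's opts (the two option names are sample keys only) and a hypothetical run_with_overrides helper:

```python
from types import SimpleNamespace

# Hypothetical options object; the two names are sample settings only.
opts = SimpleNamespace(CLIP_stop_at_last_layers=1, eta_noise_seed_delta=0)


def run_with_overrides(overrides: dict, work, restore_afterwards: bool = True):
    """Apply temporary setting overrides, run work(), and restore the old values in finally."""
    stored = {k: getattr(opts, k) for k in overrides}
    try:
        for k, v in overrides.items():
            setattr(opts, k, v)
        return work()
    finally:
        if restore_afterwards:
            for k, v in stored.items():
                setattr(opts, k, v)


seen = run_with_overrides({"CLIP_stop_at_last_layers": 2}, lambda: opts.CLIP_stop_at_last_layers)
try:
    run_with_overrides({"eta_noise_seed_delta": 31337}, lambda: 1 / 0)
except ZeroDivisionError:
    pass  # the error still propagates, but the setting has already been restored
```

Storing the original values before applying any override is what makes the restore safe even if the failure happens midway through applying them.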

Summary

This in-depth look at the core image-generation pipeline of Stable Diffusion WebUI highlights several strengths of its design:

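  1. Modular design: code is organized into a base class and derived classes, which improves maintainability and extensibility
  2. Performance optimization: several strategies are implemented, such as conditioning caching, Token Merging and batching
  3. Flexible extension: the script system provides many extension hooks, such as process_before_every_sampling and on_mask_blend
  4. Error handling: cleanup and settings restoration run even when processing fails
  5. Resource management: references are released and GPU memory is freed after each job, with CPU offloading in low-VRAM mode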

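Understanding these core flows not only helps in using WebUI more effectively but also provides an important foundation for extension development and performance tuning. The generation pipeline continues to evolve, and future versions may introduce further optimizations and features.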
