Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance
Github

摘要

近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导(CFG)。这些技术在无条件生成或诸如图像恢复等各种下游任务中往往并不适用。在本文中,我们提出了一种新颖的采样引导方法,称为Perturbed-Attention Guidance(PAG),它能在无条件和条件设置下提高扩散样本的质量,并且无需额外的训练或集成外部模块。PAG 旨在通过去噪过程逐步增强样本的结构。它通过用单位矩阵替换 UNet 中的self-attention map来生成结构退化的中间样本,这是考虑到自注意力机制捕捉结构信息的能力,并引导去噪过程远离这些退化样本。在 ADM 和 Stable Diffusion 中,PAG 在条件甚至无条件场景下都显著提高了样本质量。此外,在诸如空提示的 ControlNet 以及图像修复(如修补和去模糊)等现有引导(如 CG 或 CFG)无法充分利用的各种下游任务中,PAG 也显著提高了基线性能。
在这里插入图片描述
研究表明,在diffusion U-Net的self-attention 模块中,query-key 主要影响structure ,values主要影响appearance。
在这里插入图片描述
如果直接扰动Vt 的话,会导致 out-of-distribution (OOD),因此选择使用单位矩阵替换query-key 部分。
在这里插入图片描述
在这里插入图片描述
那么具体扰动Unet的哪一部分呢?作者使用了5k个样本,在PAG guidance scale s = 2.5 and DDIM 25 step的条件下,表现最好的是mid-block “m0”
在这里插入图片描述

代码

Diffusers 已经支持PAG用在多种任务中,并且可以和ControlNet、 IP-Adapter 一起使用。

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    "~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",
    enable_pag=True,  ##add
    pag_applied_layers=["mid"], ##add
    torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()

prompt = "an insect robot preparing a delicious meal, anime style"
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(
    prompt=prompt,
    num_inference_steps=25,
    guidance_scale=7.0,
    generator=generator,
    pag_scale=2.5,
).images

images[0].save("pag.jpg")

PAG代码细节

如果同时使用PAG和CFG,那么输入到Unet中prompt_embeds定义如下,也就是[uncond,cond,cond]

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):

        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

PAGCFGIdentitySelfAttnProcessor2_0计算,其中[uncond,cond]正常计算SA,第二个cond则计算PSA。

class PAGCFGIdentitySelfAttnProcessor2_0:
    r"""
    Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    PAG reference: https://arxiv.org/abs/2403.17377
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # chunk
        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

        # original path
        batch_size, sequence_length, _ = hidden_states_org.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states_org)
        key = attn.to_k(hidden_states_org)
        value = attn.to_v(hidden_states_org)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states_org = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states_org = hidden_states_org.to(query.dtype)

        # linear proj
        hidden_states_org = attn.to_out[0](hidden_states_org)
        # dropout
        hidden_states_org = attn.to_out[1](hidden_states_org)

        if input_ndim == 4:
            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # perturbed path (identity attention)
        batch_size, sequence_length, _ = hidden_states_ptb.shape

        if attn.group_norm is not None:
            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

        value = attn.to_v(hidden_states_ptb)
        hidden_states_ptb = value
        hidden_states_ptb = hidden_states_ptb.to(query.dtype)

        # linear proj
        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
        # dropout
        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

        if input_ndim == 4:
            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # cat
        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

经过Unet后,noise_pred的计算方法。

    def _apply_perturbed_attention_guidance(
        self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
    ):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.
            return_pred_text (bool): Whether to return the text noise prediction.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
            perturbed attention guidance and the text noise prediction.
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        if return_pred_text:
            return noise_pred, noise_pred_text
        return noise_pred
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值