Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models

This work studies the role of cross-attention during inference in text-conditional diffusion models. The authors find that the cross-attention output converges to a fixed point within the first few inference steps. The convergence point naturally divides the whole inference process into two stages:

  • An initial semantics-planning stage, during which the model relies on cross-attention to plan text-oriented visual semantics, followed by
  • A fidelity-improving stage, during which the model tries to generate images from the previously planned semantics.

The authors find that ignoring the text condition during the fidelity-improving stage not only reduces computational complexity but also preserves model performance. This yields a simple, training-free method for efficient generation, called TGATE, which caches the cross-attention output once it converges and keeps it fixed for the remaining inference steps. An empirical study on the MS-COCO validation set confirms its effectiveness.
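
To make the caching idea concrete, here is a minimal, hypothetical sketch of temporally gated cross-attention. Names such as `CachedCrossAttention` and the `gate_step` threshold are illustrative assumptions, not the authors' released implementation:

```python
import torch
import torch.nn as nn


class CachedCrossAttention(nn.Module):
    """Wrap a cross-attention module and freeze its output after `gate_step`.

    Illustrative sketch of the TGATE idea: once the denoising step index
    passes `gate_step`, the cached output is reused instead of recomputing
    attention against the text embedding.
    """

    def __init__(self, attn: nn.Module, gate_step: int = 10):
        super().__init__()
        self.attn = attn          # the original cross-attention block
        self.gate_step = gate_step
        self.cache = None         # cached output C_c^t at the gate step

    def forward(self, hidden_states, text_emb, step: int):
        if step < self.gate_step or self.cache is None:
            out = self.attn(hidden_states, text_emb)
            if step == self.gate_step - 1:
                self.cache = out.detach()  # cache once the output has converged
            return out
        # Fidelity-improving phase: skip the text branch and reuse the cache.
        return self.cache
```

Because the latent resolution is fixed across denoising steps, the cached tensor has the same shape as a freshly computed output, so reusing it requires no other changes to the UNet.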

1. Introduction

Several studies have highlighted the importance of cross-attention for spatial control (Prompt-to-Prompt, Attend-and-Excite, BoxDiff), but few, if any, have examined its role from a temporal perspective across the denoising process.

This raises a new question: "Is cross-attention necessary at every step during the inference of text-to-image diffusion models?"

To answer this, we study the impact of cross-attention at each inference step on the quality of the generated images. Our findings highlight two counter-intuitive points:

  • Cross-attention outputs converge to a fixed point within the first few steps. (This convergence point divides the denoising process of the diffusion model into two stages; a simple probe of this convergence is sketched after this list.)

    • In the initial stage, the model relies on cross-attention to plan text-oriented visual semantics; we denote it as the semantics-planning stage.
    • In the subsequent stage, the model learns to generate images from the previous semantic planning; we refer to it as the fidelity-improving stage.
  • Cross-attention is redundant in the fidelity-improving stage.

    • In the semantics-planning stage, cross-attention plays a crucial role in creating meaningful semantics. In the later stage, however, cross-attention converges and has a smaller impact on the generation process.
    • In fact, bypassing cross-attention during the fidelity-improving stage can potentially reduce the computational cost while maintaining image generation quality.
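
One simple way to probe this convergence empirically is to track how much the cross-attention output changes between consecutive steps. The sketch below assumes the per-step outputs have already been collected (for example, via forward hooks); the helper name is ours:

```python
import torch


def cross_attention_drift(outputs):
    """Given per-step cross-attention outputs [C^1, C^2, ...], return the
    relative L2 change between consecutive steps.

    A rapid drop toward zero after the first few steps indicates that the
    output has converged to a fixed point.
    """
    drift = []
    for prev, curr in zip(outputs[:-1], outputs[1:]):
        num = torch.norm(curr - prev)
        den = torch.norm(prev) + 1e-8  # avoid division by zero
        drift.append((num / den).item())
    return drift
```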

This matters because the scaled dot-product in cross-attention is a quadratic-complexity operation. With the ever-increasing resolutions and token lengths in modern models, cross-attention inevitably incurs expensive computation and becomes a significant source of latency in applications such as mobile devices.
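
A back-of-the-envelope illustration of how the scaled dot-product cost grows with the number of image tokens (the 77 text tokens and head dimension of 64 are assumed purely for illustration, not figures from the paper):

```python
def cross_attention_macs(n_img_tokens, n_text_tokens, d):
    """Rough MAC count of one cross-attention scaled dot-product:
    Q @ K^T costs n_img * n_text * d MACs, and scores @ V costs the same again."""
    return 2 * n_img_tokens * n_text_tokens * d


# Doubling the latent resolution quadruples the number of image tokens,
# so the cost grows quadratically with resolution.
for latent_size in (32, 64, 128):
    n_img = latent_size * latent_size
    print(f"{latent_size}x{latent_size} latent: "
          f"{cross_attention_macs(n_img, 77, 64) / 1e9:.2f} GMACs per head")
```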

This finding motivates a re-evaluation of the role of cross-attention and inspires a simple, effective, training-free method, namely temporally gating the cross-attention (TGATE), to improve efficiency while maintaining the generation quality of off-the-shelf diffusion models.

💡 Note that:

  • TGATE does not cause performance degradation, **because the cross-attention outputs are convergent and redundant.** In fact, a slight improvement in the Fréchet Inception Distance (FID) over the baseline is observed.
  • TGATE can reduce 65T Multiply-Accumulate Operations (MACs) per image and eliminate 0.5B parameters in the fidelity-improving stage, resulting in a latency reduction of around 50% compared with the baseline model (SDXL), without any training cost.

2. Temporal Analysis of Cross-Attention

2.1 Cross-Attention.

The cross-attention in the UNet is mathematically defined as:

$$\mathbf{C}_c^t=\text{Softmax}\left(\frac{Q_z^t\cdot K_c}{\sqrt{d}}\right)\cdot V_c$$

where $Q_z^t$ is the projection of the noisy latent $z_t$, $K_c$ and $V_c$ are the projections of the text embedding $c$, and $d$ is the feature dimension of $Q_z^t$ and $K_c$.
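
For reference, a minimal single-head PyTorch rendering of this definition (multi-head splitting, masking, and dropout are omitted; `W_q`, `W_k`, `W_v` stand for the learned projection matrices):

```python
import torch
import torch.nn.functional as F


def cross_attention(z_t, c, W_q, W_k, W_v):
    """Compute C_c^t = Softmax(Q_z^t · K_c^T / sqrt(d)) · V_c.

    z_t : (B, N_img, D_z)  noisy latent features at step t
    c   : (B, N_txt, D_c)  text embeddings
    W_q : (D_z, d), W_k : (D_c, d), W_v : (D_c, d) projection matrices
    """
    Q = z_t @ W_q                                 # (B, N_img, d)
    K = c @ W_k                                   # (B, N_txt, d)
    V = c @ W_v                                   # (B, N_txt, d)
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d ** 0.5   # (B, N_img, N_txt)
    return F.softmax(scores, dim=-1) @ V          # (B, N_img, d)
```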

### Cross-Attention Mechanism: Required GPU Memory Size

The GPU memory required by a cross-attention mechanism depends on several factors, including the dimensions of the queries, keys, and values, the batch size, the source and target sequence lengths, and whether optimizations such as quantization or specialized attention mechanisms (e.g., shifted sparse attention) are applied.

For a standard implementation without optimization, assuming each element is stored as float32 (4 bytes):

- \(Q\) is the query matrix with shape \([B, T_q, d_k]\),
- \(K\) is the key matrix with shape \([B, T_v, d_k]\),
- \(V\) is the value matrix with shape \([B, T_v, d_v]\),

where \(B\) is the batch size, \(T_q\) and \(T_v\) are the query and key/value token counts, and \(d_k\) and \(d_v\) are the key and value depths. The total immediate memory consumption can be roughly estimated from the sizes of these matrices plus some overhead from the intermediate attention scores produced during the scaled dot-product.

When advanced methods are applied, such as position interpolation to extend the context length efficiently or reduced-precision weights (e.g., 4-bit QLoRA), the overall GPU RAM demand can decrease significantly thanks to more compact representations and potentially smaller effective parameter counts. Batching multiple inference requests together also improves efficiency per unit of allocated VRAM through better resource utilization over time.

In practice, when deploying large models under constrained hardware, developers may further optimize their designs based on how different architectural choices affect the computational cost of both the forward pass and, where applicable, the backward pass. This includes the choice between fully connected and convolution-based components: some structures offer better parallel processing and lower latency, possibly at the cost of a larger absolute memory footprint than alternatives that emphasize locality over long distances in the input.

```python
def estimate_cross_attention_memory(batch_size, seq_len_query, seq_len_value,
                                    dim_key, dim_val):
    """
    Estimate the approximate GPU memory usage for one layer of cross-attention.

    Args:
        batch_size (int): Batch size used in computation.
        seq_len_query (int): Length of the query sequence.
        seq_len_value (int): Length of the key/value sequence.
        dim_key (int): Dimensionality of key vectors.
        dim_val (int): Dimensionality of value vectors.

    Returns:
        int: Estimated memory requirement in bytes.
    """
    # Estimation assuming single-precision floating-point numbers (float32).
    mem_per_element = 4
    q_matrix_mem = batch_size * seq_len_query * dim_key * mem_per_element
    k_matrix_mem = batch_size * seq_len_value * dim_key * mem_per_element
    v_matrix_mem = batch_size * seq_len_value * dim_val * mem_per_element
    attn_scores_mem = batch_size * seq_len_query * seq_len_value * mem_per_element
    output_mem = batch_size * seq_len_query * dim_val * mem_per_element

    # Rough estimate: sum of all parts.
    return sum([q_matrix_mem, k_matrix_mem, v_matrix_mem,
                attn_scores_mem, output_mem])


# Example usage
estimated_memory_usage = estimate_cross_attention_memory(
    batch_size=32,
    seq_len_query=512,
    seq_len_value=768,
    dim_key=64,
    dim_val=64)

print(f"Estimated memory usage: {estimated_memory_usage / (1024 ** 2)} MB")
```