segment-anything-fast pytorch官方给出的优化版本

部署运行你感兴趣的模型镜像

sam opt by torch

非常非常非常强烈建议大家去读原文

https://pytorch.org/blog/accelerating-generative-ai/
https://github.com/pytorch-labs/segment-anything-fast
在这里插入图片描述

Optimizations

Baseline

原生SAM使用 float32 以及 batch size = 1。使用 PyTorch Profiler对内核进行跟踪,发现了两个可以进一步优化的地方:

  • long calls to aten::index
    由于tensor的index索引而产生的底层调用(例如[ ]),虽然aten::index本身的耗时很短,但是它会同时启动两个内核,从而导致 CUDA流同步(cudaStreamSynchronize)阻塞,CPU会等待GPU完成计算和流同步,导致空闲时间增加。
  • matrix multiplication on GPU
    在transformer-based 模型中常见的大矩阵相乘的时间损耗

Bfloat16 Half precision (+GPU syncs and batching)

  1. 为了减少大矩阵相乘的耗时,一个常见的做法就是使用半精度类型,使用 bfloat16
  2. 查找源码中存在GPU同步的地方,并试图将其移除
-   q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
-   k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
********************************************************************************
+   q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
+   k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
-   point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
-   point_embedding[labels == -1] = 0.0
-   point_embedding[labels == -1] += self.not_a_point_embed.weight
-   point_embedding[labels == 0] += self.point_embeddings[0].weight
-   point_embedding[labels == 1] += self.point_embeddings[1].weight
********************************************************************************
+   point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+   point_embedding = torch.where((labels == -1).unsqueeze(-1).expand_as(point_embedding),
                                    torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
                                    point_embedding)
+   point_embedding = torch.where((labels == 0).unsqueeze(-1).expand_as(point_embedding),
                                    point_embedding + self.point_embeddings[0].weight,
                                    point_embedding)
+   point_embedding = torch.where((labels == 1).unsqueeze(-1).expand_as(point_embedding),
                                    point_embedding + self.point_embeddings[1].weight,
                                    point_embedding)

torch.where() << tensor

进一步对耗时和资源进行跟踪,同时更改batch_size为8来更直观的进行对比
在这里插入图片描述

在查看每个内核所花费的时间时,现在 SAM 的大部分 GPU 时间都花费在逐元素内核(elementwise kernels)和 softmax 操作上。矩阵乘法的相对开销小了很多。
当前的优化速度已经3x了
在这里插入图片描述

Torch.compile (+graph breaks and CUDA graphs)

PyTorch 对 torch.compile 做了以下优化:

  1. Fusing together sequences of operations such as nn.LayerNorm or nn.GELU into a single GPU kernel that is called and
  2. Epilogues: fusing operations that immediately follow matrix multiplication kernels to reduce the number of GPU kernel calls.

通过这些优化可以减少GPU 全局内存往返次数(roundtrips),从而加快推理速度。
advance:

  • 采用max-autotune模式,以获取最快的优化速度(编译时间最长,试图提供最快代码)
  • 通过设置 TORCH_LOGS=‘graph_breaks,recompiles’ 来手动验证是否出现图中断或重新编译
  • 通过zeros-padding来确保静态输入尺寸,从而避免重编译
predictor.model.image_encoder = torch.compile(
                                predictor.model.image_encoder,
                                mode=use_compile)

在这里插入图片描述

观测到softmax和GMM(General Matrix Multiply)占据了大量的计算时间以及现有的速度提升:
在这里插入图片描述
在这里插入图片描述

SDPA: scaled_dot_product_attention

接下来将解决注意力机制在transformer中的常见问题:计算损耗随序列长度呈二次方增长。
PyTorch 的 scaled_dot_product_attention 操作建立在 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力原理之上,从而可以显著加快 GPU 注意力的速度。

-   attn = (q * self.scale) @ k.transpose(-2, -1)
-   if self.use_rel_pos:
-       attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
-   attn = attn.softmax(dim=-1)
-   x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
********************************************************************************
+   attn_bias = None              
+   if self.use_rel_pos:
+       attn_bias = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+   q = q.view(B, self.num_heads, H * W, -1)
+   k = k.view(B, self.num_heads, H * W, -1)
+   v = v.view(B, self.num_heads, H * W, -1)
+   attn_bias = attn_bias.view(B, self.num_heads, attn_bias.size(-1), attn_bias.size(-1))
+   x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+   x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
-   _, _, _, c_per_head = q.shape
-   attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
-   attn = attn / math.sqrt(c_per_head)
-   attn = torch.softmax(attn, dim=-1)
-   # Get output
-   out = attn @ v
********************************************************************************
+   out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

现在可以看到内存高效的注意力内核占用了 GPU 上大量的计算时间:
在这里插入图片描述

使用 PyTorch 的原生 scaled_dot_product_attention,可以显著增加批处理大小。下图为批大小为 32 及以上的变化。
在这里插入图片描述

Triton: Custom SDPA for fused relative positional encoding

在 image encoderk中可以明显看到明显的内存峰值:
在这里插入图片描述

进一步,峰值的出现具体发生在add_decomposed_rel_pos中:

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
***---->    ).view(B, q_h * q_w, k_h * k_w)<------***

这里的attn变量是2个小尺寸的tensor加法: r e l h ∈ ( B , q h , q w , k h , 1 ) , r e l w ∈ ( B , q h , q w , 1 , k w ) rel_h \in (B, q_h, q_w, k_h, 1), rel_w \in (B, q_h, q_w, 1, k_w) relh(B,qh,qw,kh,1),relw(B,qh,qw,1,kw)
如果我们不分配这么大的attn张量,而是将两个较小的 r e l h rel_h relh r e l w rel_w relw张量传入SDPA,并根据需要构建attn,预计会获得显著的性能提升。
但这并不是一个简单的修改;SDPA核心经过高度优化,是用CUDA编写的。但可以求助于Triton,他们有一个易于理解和使用的教程,介绍了FlashAttention实现。经过一番深入研究并与xFormer的Daniel Haziza密切合作,我们找到了一个输入形状的案例,相对来说比较容易实现核心的融合版本。
在这里插入图片描述
在这里插入图片描述

NT: NestedTensor and batching predict_torch

在 mask prediction pipeline中,我们观测到对于每张图像,我们有一个associated size, coords, 和 fg_labels Tensor。这些张量的batch_size大小都不同。每张图像本身的大小也不同。这种数据表示看起来像是Jagged Data (不规则数据)。使用PyTorch最近发布的NestedTensor,可以将batch coords 和 fg_labels Tensors修改为1个NestedTensor。这可以给prompt encoder 和 mask decoder带来很大提升:

torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

int8: quantization and approximating matmul

在从fp32到bfloat16的基础上再进一步,引入int8量化。
在量化方法方面,我们关注于Dynamic quantization (动态量化)
减少精度可以立即带来峰值内存节省,但要实现推理加速,我们必须充分利用int8通过SAM操作。这需要构建一个高效的int8@int8矩阵乘法核心,以及转换逻辑,从高精度到低精度(量化),以及从低精度到高精度的逆向转换(反量化)。利用torch.compile的功能,我们可以将这些量化和反量化例程编译和融合成高效的单个核心和矩阵乘法的最后
虽然在推理时对模型进行量化常常会导致一些精度下降,但SAM在低精度推理时表现尤为稳健,几乎没有精度损失。
在这里插入图片描述

sparse: Semi-structured (2:4) sparsity

矩阵乘法仍然是我们的瓶颈。我们可以转向模型加速手册中另一种经典方法来近似矩阵乘法:稀疏化。通过稀疏化我们的矩阵(即将某些值置零),理论上可以使用更少的位来存储权重和激活张量。我们决定在张量中将哪些权重设为零的过程称为剪枝。剪枝的理念是权重张量中的小权重对层的净输出贡献较小,通常是权重与激活的乘积。剪枝掉小权重可以潜在地减小模型大小而几乎不损失精度。

剪枝的方法多种多样,从完全非结构化到高度结构化。虽然非结构化剪枝理论上对精度的影响最小,但 GPU 在进行大型密集矩阵乘法方面尽管非常高效,然而在稀疏情况下可能还会遭受显着的性能下降。PyTorch 最近支持的一种剪枝方法旨在寻求平衡,称为半结构化(或 2:4)稀疏性。这种稀疏存储将原始张量减少了 50%,同时产生密集张量输出。参见下图的说明。
在这里插入图片描述

为了使用这种稀疏存储格式和相关的快速内核,接下来要做的是剪枝权重。本文在 2:4 的稀疏度下选择最小的两个权重进行剪枝,将权重从默认的 PyTorch(“strided”)布局更改为这种新的半结构化稀疏布局很容易。要实现 apply_sparse (model),只需要 32 行 Python 代码:


import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Sparsity helper functions
def apply_fake_sparsity(model):
    """
    This function simulates 2:4 sparsity on all linear layers in a model.
    It uses the torch.ao.pruning flow.
    """
    # torch.ao.pruning flow
    from torch.ao.pruning import WeightNormSparsifier
    sparse_config = []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})

    sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                      sparse_block_shape=(1,4),
                                      zeros_per_block=2)
    sparsifier.prepare(model, sparse_config)
    sparsifier.step()

    sparsifier.step()
    sparsifier.squash_mask()


def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

在这里插入图片描述

Conclusion

总结一下,我们很高兴地宣布我们迄今为止实现的最快的Segment Anything。我们使用了许多新发布的功能,将Meta原始的SAM纯粹重写为PyTorch,并没有损失准确性:

  • Torch.compile PyTorch的原生JIT编译器,提供快速、自动融合PyTorch操作的功能
  • GPU量化加速采用降低精度操作的模型
  • Scaled Dot Product Attention (SDPA) Attention的新型、内存高效实现
  • Semi-Structured (2:4) Sparsity 使用更少的位数存储权重和激活的模型加速
  • Nested Tensor 高度优化的、用于非均匀批处理和图像尺寸的不规则数组处理
  • Triton kernels Triton的自定义GPU操作,通过Triton轻松构建和优化

您可能感兴趣的与本文相关的镜像

PyTorch 2.6

PyTorch 2.6

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值