Vision Transformer with Super Token Sampling

Vision Transformer with Super Token Sampling
具有超级令牌采样的 Vision Transformer
论文地址

1. 简介

现有的ViT模型在浅层网络中存在一下问题:

  1. 高冗余性:浅层特征通常以局部信息为主,直接应用在全局自注意力机制会导致大量的冗余计算。
  2. 长程依赖建模不足:为了减少计算开销,许多方法采用局部自注意力或早起卷积模块,这牺牲了对长程依赖关系的捕获能力。

论文中提出了一种新的诗句Transformer模型–Super Token Vision Transformer(SViT),通过引入超令牌(super tokens)来解决这一问题。下面是文章的几个关键点:
3. 提出Super Token Attention(STA):一种高效的全局上下文建模机制,通过超令牌采样和稀疏关联学习降低计算复杂度 。STA将全局自注意力分解为稀疏关联矩阵和低维注意力矩阵的乘积,从而显著提高效率。
4. 构建SViT架构:SViT架构的每个阶段的核心模块是Super Token Transformer(STT)块,包含三个关键组件:

  • Convolutional Position Embedding(PCE):为每个token添加位置信息。
  • Super Token Attention(STA):捕获全局依赖关系。
  • Convolutional Feed-Forward Network(ConvFFN):增强局部特征表示。

2. 网络结构

下面是该网络的总体结构图:
请添加图片描述
我们可以发现,通过骨干网络后,每次都是将特征图通过STT Block进行处理,一共进行了四次,最终通过 1 × 1 1 \times 1 1×1的卷积核Avg Pool恢复图像的,最后进行分类。下面进行详细解读:

1. Overall

对于给定的输入图像,首先将其输入到一个由四个 3 × 3 3 \times 3 3×3卷积组成的stem模块中,步长分别为2、1、2、1来提取局部特征,对于每个STT模块之间,使用步长为2的 3 × 3 3 \times 3 3×3卷积来减少标记数量。在STT中,给定输入标记张量 X i n X_{in} Xin,首先使用CPE,也就是 3 × 3 3 \times 3 3×3的深度卷积为所有标记添加位置信息,接着使用STA,然后使用卷积FFN来增强局部表示,有两个 1 × 1 1 \times 1 1×1卷积、一个 3 × 3 3 \times 3 3×3深度卷积和GELU组成。下面是STA:

2. Super Token Attention

STA模块包含三个过程:STS(采样)、MHSA(多头自注意力)、TU(上采样),首先通过STS算法将tokens聚合成super tokens,然后是在super tokens space中执行MHSA以建模全局依赖关系,最后通过TU算法将super tokens映射会vision token。

2.1 Super Token Sampling

在STS中,论文中说到“我们调整了SSN中的基于k-means的超级像素算法,将其从像素空间应用到标记空间”。SSN是什么我也不知道,但是我们知道STS它是基于标记空间来做处理的。首先给定一个视觉标记 X ∈ R N × C X \in R^{N \times C} XRN×C(其中 N = H × W N = H \times W N=H×W是标记数量),每个表 X i ∈ R 1 × C X_i \in R^{1 \times C} XiR1×C被嘉定属于m个超级标记 S ∈ R m × C S \in R^{m \times C} SRm×C中的一个,因此需要计算 X − S X-S XS的关联图 Q ∈ R N × m Q \in R^{N \times m} QRN×m。首先,这里会通过对规则网格区域内的标记进行平均(也就是平均池化)来进行超级标记 S 0 S_0 S0。如果网格大小为 h × w h \times w h×w,则超级标记的数量为: m = H h × W w m = \frac{H}{h} \times \frac{W}{w} m=hH×wW Q t = S o f t m a x ( X S t − 1 T d ) Q_t = Softmax(\frac{XS_{t-1}^T}{\sqrt d}) Qt=Softmax(d XSt1T),d是通道数C,超级标记被更新为标记的加权和: S ^ = ( Q t ^ T X ) \hat{S}=(\hat{Q_t}^TX) S^=(Qt^TX),其中 Q t ^ \hat{Q_t} Qt^是列归一化的 Q T Q_T QT。简单来说,就是我们将特征图进行分块,对每一块进行超级标记,也就是展平为一个向量,通过关系图 Q t Q_t Qt来更新超级标记。

  1. 什么是标记空间?
    在ViT中,标记空间(Token Space)是指输入图像经过分割和嵌入后形成的离散化表示空间。标记是输入图像的离散化单元,通常通过对图像进行分块并嵌入得到。每个标记可以看作是一个高维向量,表示图像某个区域的特征。标记空间就是有所有标记组成的集合,通常表示为一个矩阵,在这个空间中,模型通过自注意力机制或其他操作来建模标记之间的关系,从而捕获图像的局部和全局信息。标记空间和像素空间是不一样的,像素空间是输入图像的原始像素值。
  2. 什么是视觉标记(Visual Tokens)?
    在ViT及其相关模型中,视觉标记是输入图像的特征表示,它们通常是输入图像的特征表示。其实就是将输入图像划分为固定大小的非重叠区域,每个区域被展平并线性嵌入为一个向量,这些向量就是视觉标记,视觉标记不是特征图,它可以看作是特征图的一种离散化表示。

2.2 Self-Attention for Super Tokens

这里就是transformer计算注意力图的过程。略

2.3 Token Upsampling

这一步主要是将通过自注意力过程的的tokens映射会vision tokens,并将其添加到原始的X中,这里使用关联映射Q从super tokensS中上采样tokens: T U ( A t t n ( S ) ) = Q A t t n ( S ) TU(Attn(S)) = QAttn(S) TU(Attn(S))=QAttn(S)

3. 实验

我将该论文提出的模块加到了UNetFormer这个结构的解码器阶段,结果如下:
请添加图片描述
光看结果的话,分割效果是很差的,再来看一下各项指标:依次是更改的、SFANet、UNetFormer
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

### Token Merging Technique for Faster Vision Transformer (ViT) In the context of enhancing computational efficiency within Vision Transformers, token merging techniques play a crucial role by reducing the number of tokens processed through subsequent layers. This reduction not only decreases memory usage but also accelerates computation times significantly. #### Reducing Computational Complexity Token merging strategies aim to decrease the sequence length while preserving essential information from earlier stages of processing. One common approach involves hierarchical aggregation where adjacent patches are merged into larger ones at deeper network levels[^1]. By doing so, each layer processes fewer yet more informative tokens compared to its predecessors. For instance, consider an initial image divided into smaller non-overlapping regions or patches; these form individual input tokens fed into early transformer blocks. As data propagates forward through multiple self-attention operations across several encoder stacks, certain methods dynamically combine spatially close elements based on learned representations rather than fixed rules[^3]. This dynamic combination allows models like MPViT to achieve better performance with reduced resource consumption since it effectively captures both local details and global contexts without overwhelming system resources during training or inference phases. #### Implementation Example Below is a simplified Python implementation demonstrating how one might implement such a mechanism using PyTorch: ```python import torch from torch import nn class TokenMerging(nn.Module): def __init__(self, dim_in, dim_out=None): super().__init__() if dim_out is None: dim_out = dim_in * 2 self.linear_proj = nn.Linear(dim_in, dim_out) def forward(self, x): # Input shape: B × L × C B, L, C = x.shape H = W = int(L ** 0.5) # Assuming square-shaped feature maps # Reshape to batched images img = x.view(B, H, W, C).permute(0, 3, 1, 2) # B×C×H×W format # Apply average pooling as simple example of downsampling operation pooled_img = nn.functional.avg_pool2d(img, kernel_size=2, stride=2) # Flatten back to sequence representation after downscaling new_seq_len = (H // 2) * (W // 2) out = pooled_img.permute(0, 2, 3, 1).view(B, new_seq_len, -1) return self.linear_proj(out) # Usage demonstration if __name__ == "__main__": sample_input = torch.randn((8, 64, 768)) # Batch size 8, Sequence Length 64, Feature Dim 768 merger = TokenMerging(768) output = merger(sample_input) print(f"Output Shape After Token Merge: {output.shape}") ``` The provided code snippet illustrates a basic method for implementing token merging via averaging followed by linear projection. More sophisticated approaches may involve learnable parameters that adaptively determine which tokens should be combined together depending upon their semantic relevance identified throughout previous transformations applied over them.
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值