即插即用&无需微调 | ViT自适应性token采样技术(ECCV2022)

作者 | 努力努力再努力的 编辑 | FightinCV

点击下方卡片,关注“自动驾驶之心”公众号

ADAS巨卷干货,即可获取

点击进入→自动驾驶之心【全栈算法】技术交流群

1. 论文和代码地址

1ae66b5e8fe82b5ad2545748853f34ae.png

论文题目:ATS:Adaptive Token Sampling For Efficient Vision Transformers

论文地址:https://doi.org/10.48550/arXiv.2111.15667[1]

代码地址:https://github.com/adaptivetokensampling/ATS[2]

2. 动机

基于CNNs的图像分类任务取得了巨大的进展,视觉Transformer的出现为图像分类任务提供了新的强有力的模型,传统的vision transformer的计算成本高、参数量大,模型不适合部署到边缘设备,虽然可以通过减少网络中的token数量来实现GFLOPs数量的减少,但是没有办法针对不同的输入图片设置最佳的tokens。

为了解决这个问题,DynamicViT提出了训练一个token打分神经网络来预测哪个tokens是多余的。但打分网络额外引入了计算开销,需要和视觉transformer一起进行训练,需要修改损失函数,增加一个额外了损失项和超参数。应用到不同的设备上的时候,模型需要重新训练。

在分类过程中,不是所有图像信息都是必要的,图像中的部分像素是多余的或者不相干的,是否相关的判断取决于图片本身,本文提出了基于自注意矩阵的AST模块,对token进行打分,以最小的信息损失去掉输入中的冗余信息,解决了DynamicViT引入额外开销的限制,并且不需要预训练就可以达到降低视觉Transformer计算成本,减少模型参数的效果。

3. 方法

模型准确性与输入patch数量相关,传统CNN使用池化操作,导致网络的空间分辨率逐渐下降,会导致模型的准确度下降,这种,静态采样会导致忽略重要信息或者信息冗余。

所以本文提出了在不同的stage自适应token数量的方法,以期实现不忽略重要信息并且不浪费计算资源的目标。

ATS模块是一个无参数的可微分模块,其中:K是token采样的最大数量;K'是采样tokens,具有动态自适应的能力。

3.1 token打分

标准自注意力层的注意力矩阵A计算公式,其中Q是queries,K是keys,V是values,d表示:

a6b26b55f487d21fe307919005bdae41.png

由于Softmax函数,注意力矩阵A的每行的和为1。输出token是通过注意力矩阵权重对V进行加权:

d997302238832034495c6c1a953cee9b.png

A中的注意力权重表示了所有tokens对输出t3oken的贡献度,A1,j表示输入token j 对输出分类token 的重要性,用A中第一行的权重作为注意力矩阵A剪枝的显著性评分。

输出token由A和V决定,所以将V的范数考虑到token j 的显著性打分计算公式中,如下:

d33346adeead445ad0e9798c17a4eabf.png

对于多头注意力层,对每个头的打分进行计算,然后加总。

3.2token采样

对每个token有了打分后,需要从注意力矩阵中删除对应的token。若采用直接删除打分低的tokens没有得到很好的结果。

这可能是因为early stage的特征并没有明显的区别,需要对比所有的token,通过softmax函数降低相似的token的权重,但是相似的token中有一个是有用的,若全部丢弃则会忽略信息。

本文提出根据分数进行抽样的方法:相似token采样比例与该相似token占总token的比例相等,并且,早期要比后期选择更多的tokens。

使用逆变换抽样,基于tokens对应的分数S对tokens进行抽样,分数S是归一化的,可以解释为概率,所以累计概率分布函数CDF是:

60dedab74bc92fdf72ef73177af627db.png

值得注意的是,因为j=1表示分类token,因为分类token需要被保留,所以j从2开始,通过累计概率分布函数CDF的反函数得到抽样函数:

ed6938b8bfb50ac1d46662f9d00d3f38.png

选择tokens之后,重新调整注意力矩阵A,得到

a5c4df6a5daf21f964f31a614d6c6bd7.png

使用调整后的注意力矩阵计算输出:

b9ad6440246e73070756af9a762e4148.png

3.3模型架构

6c2e5cd054e68dff93eeee92a9a31702.png

ATS被整合到视觉Transformer块的自注意层中,首先利用自注意层中分类token的自注意权重来计算,然后对分数使用逆变换来选择一个tokens的子集,最后对输出tokens软降采样,以最小的信息损失从输出标记中去除冗余信息

4.实验

4.1SOTA结果

deb2b5badb2a310c8cbeec17fa4ca3ae.png

融合ATS模块自适应模型与在ImageNet-1K 上进行图像分类。我们将ATS模块合并到DeiT-S 模型的第3到第11阶段。将ATS模块集成到CvT- 13 和CvT-21 的第三阶段的第1到第9块中。可以看出ATS模块减少了所有视觉transformer模型的GFLOPs,而没有向主干模型添加任何额外的参数。将加入ATS的模型与DynamicViT和HVT进行比较,后者为模型增加了额外的参数,我们的方法在精度和GFLOPs之间实现了更好的权衡。

45d2b150a6c88875663b3ab2ffd4163c.png

如图1所示,该方法显著降低了不同尺寸的视觉transformer的gflop,而没有明显的精度损失。

4.2模型参数消融实验

eca894d4e4937fe218e41cc019b77089.png

图(a)显示了采用逆变换采样方法对输入标记进行软降采样。评估了模型在选择显著性得分最高的前K标记s时的性能,逆变换抽样方法优于Top-K选择。图(b)显示了微调能提高精度。图(c)显示了单阶段和多阶段将ATS块集成到视觉变压器模型中的效果。在单阶段模型中,我们将ATS模块集成到DeiT-S的第三阶段中。在多阶段模型中,我们将我们的ATS模块集成到DeiT-S的第3-11阶段中。多阶段DeiT-S+ATS的性能优于单阶段DeiT-S+ATS。这是因为多阶段DeiT-S+ATS模型可以通过在早期阶段丢弃更少的token来逐渐减少GFLOPs,而单阶段DeiT-S+ATS模型必须在早期阶段丢弃更多的tolen来达到相同的GFLOPs水平。

4.3显著性评分方法消融实验

b27c9d40fdc99a75a14a7e38a435eab9.png

使用分类token的注意权重在较低的FLOPs状态下表现得更好。这可能是由于分类token的注意权重是选择候选token的一个更强的信号,分类token稍后将用于预测模型最后阶段的分类概率。因此,其对应的注意权重显示了哪些token对输出分类token的影响更大 只显示相加权重之后注意力最高的token对分类token不一定有用。随机选择token,不是选择分类token,使用注意权重来进行分数分配,这种方法相对很差。

等式加入V的二范式进行实验,发现效果打给提高了0.2%

4.4自适应采样

70f90d17811f5fdc2847b4cabac71bfe.png

每个ATS阶段的采样token数量的直方图,所有阶段,选定的token的数量和所有图像都是不一样的。

4.5可视化

a17c8d1c4c9b1028a4a5c95b206249e7.png

为了更好地理解ATS模块的操作方式,在图3中可视化了token采样过程(逆变换采样)。在DeiT-S网络的第3到第11阶段整合了ATS模块。在每个阶段被丢弃的token被表示为输入图像上的一个掩码。DeiT-S+ATS模型逐渐去除不相关的标记,并对对模型预测更重要的标记进行采样。我们的方法将与目标对象相关的token标识为信息最丰富的token。

5. 总结

本文提出了一个新的无参数模型ATS,AST在ViT的stage重选择信息丰富的最独特的标记,以便为每个图像使用尽可能多的标记,但不超过必要的标记。ATS模块可以直接引入到训练好的视觉transformer中,不需要引入额外的参数,也可以不用额外的训练和微调。在ImageNet-1K图像识别数据集上评估了我们的方法,并将我们的ATS模块合并到三个不同的最先进的视觉Transformer中。结果表明,ATS模块将计算成本(GFLOPs)降低了27%到37%,而精度下降可以忽略不计。

参考资料

[1]

https://doi.org/10.48550/arXiv.2111.15667: https://doi.org/10.48550/arXiv.2111.15667

[2]

https://github.com/adaptivetokensampling/ATS: https://github.com/adaptivetokensampling/ATS

往期回顾

DETR系列大盘点 | 端到端Transformer目标检测算法汇总!

12d1043ebfa1fe488cfccc3785b6fcdb.png

自动驾驶之心】全栈技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦目标检测、语义分割、全景分割、实例分割、关键点检测、车道线、目标跟踪、3D目标检测、BEV感知、多传感器融合、SLAM、光流估计、深度估计、轨迹预测、高精地图、规划控制、模型部署落地、自动驾驶仿真测试、硬件配置、AI求职交流等方向;

d737cdaa9ad5d9026a3cdab2dea4a8a4.jpeg

添加汽车人助理微信邀请入群

备注:学校/公司+方向+昵称

自动驾驶之心【知识星球】

想要了解更多自动驾驶感知(分类、检测、分割、关键点、车道线、3D目标检测、多传感器融合、目标跟踪、光流估计、轨迹预测)、自动驾驶定位建图(SLAM、高精地图)、自动驾驶规划控制、领域技术方案、AI模型部署落地实战、行业动态、岗位发布,欢迎扫描下方二维码,加入自动驾驶之心知识星球(三天内无条件退款),日常分享论文+代码,这里汇聚行业和学术界大佬,前沿技术方向尽在掌握中,期待交流!

b24f7f7c97ebce01f5cf809172948a71.jpeg

### ViT(Vision Transformer)即插即用特征增强模块概述 视觉Transformer (ViT) 是一种基于Transformer架构的模型,在计算机视觉领域取得了显著成果。为了进一步提升其性能,许多研究人员提出了多种即插即用的特征增强模块。这些模块可以轻松集成到现有的ViT框架中,从而提高模型的表现能力。 #### 完全注意力网络 (FANs) 完全注意力网络 (Fully Attentional Networks, FANs)[^1] 提出了通过强化自注意力机制来改善ViT的学习能力和鲁棒性。具体来说,FANs引入了一种新的注意力通道模块,该模块增强了自注意力在学习鲁棒特征表示方面的作用。这种方法不仅提升了模型对噪声数据的容忍度,还在面对不同类型的图像损坏时表现得更加稳定。 以下是实现FANs的一个简单代码片段: ```python import torch.nn as nn class FullyAttentionModule(nn.Module): def __init__(self, dim): super(FullyAttentionModule, self).__init__() self.attention = nn.MultiheadAttention(dim, num_heads=8) def forward(self, x): attn_output, _ = self.attention(x, x, x) return attn_output + x # Residual connection ``` #### 协调注意力 (CA - Coordinate Attention) 协调注意力(CA)[^2]是一种轻量级的注意力机制,旨在捕捉空间维度上的依赖关系。它通过对输入特征图的高度和宽度分别应用一维卷积操作,提取坐标级别的上下文信息。此方法能够有效减少计算开销并保持较高的精度增益。 下面展示如何将CA应用于PyTorch中的ViT结构: ```python import torch from torch import nn class CoordAtt(nn.Module): def __init__(self, inp, oup, reduction=32): super(CoordAtt, self).__init__() self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) self.pool_w = nn.AdaptiveAvgPool2d((1, None)) temp_c = max(oup//reduction, 4) self.conv1 = nn.Conv2d(inp, temp_c, kernel_size=1, stride=1, padding=0) self.bn1 = nn.BatchNorm2d(temp_c) self.act = nn.ReLU() self.conv_h = nn.Conv2d(temp_c, oup, kernel_size=1, stride=1, padding=0) self.conv_w = nn.Conv2d(temp_c, oup, kernel_size=1, stride=1, padding=0) def forward(self, x): identity = x n,c,h,w = x.size() x_h = self.pool_h(x) x_w = self.pool_w(x).permute(0, 1, 3, 2) y = torch.cat([x_h, x_w], dim=2) y = self.conv1(y) y = self.bn1(y) y = self.act(y) x_h, x_w = torch.split(y, [h, w], dim=2) x_w = x_w.permute(0, 1, 3, 2) a_h = self.conv_h(x_h).sigmoid() a_w = self.conv_w(x_w).sigmoid() out = identity * a_w * a_h return out ``` 以上两种技术只是众多可用于ViT即插即用模块的一部分。每种模块都有各自的特点以及适用场景,因此可以根据实际需求选择合适的方案进行实验验证。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值