paper:MetaFormer Is Actually What You Need for Vision
official implementation:https://github.com/sail-sg/poolformer
third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/metaformer.py
背景
这篇文章讨论了Transformer在计算机视觉任务中的应用,并提出了一种新的框架MetaFormer。研究者注意到,尽管Transformer的attention机制通常被认为是其优越性能的关键,但最近的研究表明,使用其他简单的token mixer(如spatial MLP)也能取得良好效果。因此,文章探讨了Transformer的通用架构MetaFormer,而不仅仅是特定的token mixer。
出发点
文章的出发点是验证一个假设,即Transformer模型的成功主要归因于其通用架构MetaFormer,而不是特定的token mixer(如attention)。为验证这一假设,研究者将Transformer中的attention模块替换为一个简单的空间池化操作,形成一个新的模型——PoolFormer。
创新点
- MetaFormer框架:提出并验证了MetaFormer这一通用架构,证明其比特定的token mixer(如attention)更为关键。
- 简单的token mixer:通过使用简单的空间池化操作,研究者验证了即使没有复杂的token mixer,MetaFormer架构也能实现优越的性能。
- 性能验证:实验结果显示,PoolFormer在ImageNet-1K分类任务中超过了精调的Vision Transformer和MLP-like基线模型,参数和计算量都显著减少。
总体而言,文章通过引入和验证MetaFormer这一新的通用架构,对Transformer和MLP-like模型在计算机视觉任务中的成功原因进行了深入探讨,并为未来的研究提供了新的方向。
方法介绍
MetaFormer
MetaFormer的整体结构如图1所示,输入 \(I\) 首先通过input embedding的处理,比如ViTs中的patch embedding
其中 \(X\in\mathbb{R}^{N\times C}\) 表示序列长度为 \(N\) 嵌入维度为 \(C\) 的embedding token,然后输入到重复的MetaFormer block中,每个block包含两个residual sub-blocks。具体来说,第一个sub-block包含一个token mixer用来进行token之间的信息交互,表示如下
其中 \(Norm(\cdot)\) 表示规范化比如Layer Normalization或Batch Normalization。\(TokenMixer(\cdot)\) 表示一个用于混合token信息的模块,比如Transformer中的各种attention机制或MLP类模型中的spatial MLP。注意token mixer的主要作用是用于混合token信息,尽管有些token mixer同时也混合通道信息例如attention。
第二个sub-block主要由一个具有非线性激活的两层MLP组成
其中 \(W_1\in \mathbb{R}^{C\times rC}\) 和 \(W_2\in \mathbb{R}^{rC\times C}\) 是MLP的可学习参数,\(r\) 是expansion ratio,\(\sigma(\cdot)\) 是非线性激活函数例如ReLU或GELU。
MetaFormer描述了一种通用的体系结构,通过指定token mixer的具体设计,可以立即获得不同的模型。例如如果将token mixer指定为attention或spatial MLP,就可以得到Transformer或MLP类的模型。
PoolFormer
作者认为Transformer或MLP类模型的成功主要是因为MetaFormer这种通用的架构,为了证明这点,作者故意使用了一个非常简单的操作即池化来作为token mixer,它没有可学习的参数,只是通过平均聚合每个位置周围token的特征来得到该位置的输出,对于输入 \(T^{C\times H\times W}\),池化可以表示为
其中 \(K\) 是pooling size,注意由于MetaFormer block本身有一个residual connection,这里减去了自身。
PoolFormer的整体结构如图2所示,这里和CNN一样采用了一种金字塔结构。具体来说一共有4个stage,每个stage的分辨率减半,一共有两组embedding size,对于小尺寸模型四个stage的embedding size分别为64,128,320,512,对于大尺寸模型为96,192,384,768。假设模型一共有 \(L\) 个block,四个stage的block数量分别为L/6, L/6, L/2, L/6。不同PoolFormer变种的具体配置如下
此外,作者对Layer Normalization进行了修改,和原始的LN沿通道维度计算均值和方差不同,这里修改后的LN即Modified Layer Normalization(MLN)沿通道和token维度计算均值和方差。
实验结果
在ImageNet上和其它结构的模型的性能对比如下表所示,可以看到,PoolFormer尽管结构简单也取得了极具竞争力或者更好的结果。
在COCO目标检测任务上,效果好好于ResNet
在ADE20K语义分割任务上效果也很好。
此外作者还进行了各个componets的消融实验,结果如下。首先是token mixer,可以看到即使换成identity mapping精度也达到了74.3%,表明更多起作用的是MetaFormer的整体结构,而不是token mixer的具体形式。本文提出的MLN比原始的LN效果要更好。如果将residual connection和channel mlp去掉,网络直接不收敛了,表明MetaFormer整体结构的重要性。
此外基于pooling的token mixer可以处理更长的序列,而基于attention的token mixer更适合捕获全局信息,因此作者又试验了一种混合结构,即早期stage采用pooling token mixer,后期stage采用attention token mixer,如最后一栏所示,可以看到这种混合结构是有助于提高模型性能的。
代码解析
这里以timm中的实现为例,输入大小为(1, 3, 224, 224),模型选择"poolformer_s24",模型配置如下
def poolformer_s24(pretrained=False, **kwargs) -> MetaFormer:
model_kwargs = dict(
depths=[4, 4, 12, 4],
dims=[64, 128, 320, 512],
downsample_norm=None,
mlp_act=nn.GELU,
mlp_bias=True,
norm_layers=GroupNorm1,
layer_scale_init_values=1e-5,
res_scale_init_values=None,
use_mlp_head=False,
**kwargs)
return _create_metaformer('poolformer_s24', pretrained=pretrained, **model_kwargs)
MetaFormerBlock的代码如下
class MetaFormerBlock(nn.Module):
"""
Implementation of one MetaFormer block.
"""
def __init__(
self,
dim,
token_mixer=Pooling,
mlp_act=StarReLU,
mlp_bias=False,
norm_layer=LayerNorm2d,
proj_drop=0.,
drop_path=0.,
use_nchw=True,
layer_scale_init_value=None,
res_scale_init_value=None,
**kwargs
):
super().__init__()
ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw)
rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw)
self.norm1 = norm_layer(dim)
self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **kwargs)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.layer_scale1 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
self.res_scale1 = rs_layer() if res_scale_init_value is not None else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
dim,
int(4 * dim),
act_layer=mlp_act,
bias=mlp_bias,
drop=proj_drop,
use_conv=use_nchw,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.layer_scale2 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
self.res_scale2 = rs_layer() if res_scale_init_value is not None else nn.Identity()
def forward(self, x):
x = self.res_scale1(x) + \
self.layer_scale1(
self.drop_path1(
self.token_mixer(self.norm1(x))
)
)
x = self.res_scale2(x) + \
self.layer_scale2(
self.drop_path2(
self.mlp(self.norm2(x))
)
)
return x
其中self.token_mixer为平均池化并且减去输入,代码如下
class Pooling(nn.Module):
"""
Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418
"""
def __init__(self, pool_size=3, **kwargs):
super().__init__()
self.pool = nn.AvgPool2d(
pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
def forward(self, x):
y = self.pool(x)
return y - x
Modified Layer Normalization可以通过将pytorch自带的GroupNorm的group number设置为1来实现,代码如下
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)