原文
论文链接:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
源码地址:https://github.com/microsoft/Swin-Transformer
引言
目前Transformer应用到图像领域主要有两大挑战:
- 视觉实体变化大,在不同场景下视觉Transformer性能未必很好。即目标尺寸多变。不像NLP任务中token大小基本相同,目标检测中的目标尺寸不一,用单层级的模型很难有好的效果。
- 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量较大。即图片的高分辨率。尤其是在分割任务中,高分辨率会使得计算复杂度呈现输入图片大小的二次方增长,这显然是不能接受的。
针对上述两个问题,我们提出了一种包含滑窗操作,具有层级设计的Swin Transformer。
其中滑窗操作包括不重叠的local window,和重叠的cross-window。将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量。
在各大图像任务上,Swin Transformer都具有很好的性能。
整体架构
我们先看下Swin Transformer的整体架构
整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。
- 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块,并嵌入到Embedding。
- 在每个Stage里,由Patch Merging和多个Block组成。
- 其中Patch Merging模块主要在每个Stage一开始降低图片分辨率。
- 而Block具体结构如右图所示,主要是LayerNorm,MLP,Window Attention 和 Shifted Window Attention组成 (为了方便讲解,我会省略掉一些参数)
它首先通过像ViT一样的分片模块将输入的RGB图像分片成不重叠的像素块。每个像素块被视为一个“token”,其特征被设置为原始像素RGB值的串联。我们使用的像素块是4×4的size,所以其特征维度为4×4×3=48。在这个原始值特征上应用一个线性嵌入层,将其投影到任意维(表示为C)。
在stage1中,几个Swin Transformer blocks算子被应用于这些像素块上。这些 Transformer blocks保持了H4×W4的tokens数量,并且伴随线性的嵌入层。
stage2中,为了产生一个层次化的表示,由于像素块的合并使得tokens的数量减少了。第一次patch merging layer合并了2×2领域内的像素块,并且使用一个线性层在4C的特征上进行合并。这个操作减少了2×2=4倍的tokens,并设置输出的维度为2C。这里的Transformer blocks应用于特征变换后,tokens的数量变为H8×W8。这第一个像素块融合和特征变换被称为stage2。这种操作进行叠加产生了stage3、stage4,如图所示,tokens的数量分别为: H 16 \cfrac H{16} 16H× H 16 \cfrac H{16} 16H、 H 32 \cfrac H{32} 32H× H 32 \cfrac H{32} 32H。这些阶段共同产生一个层次表示,具有与典型卷积网络相同的特征图分辨率,如VGG [51] and ResNet [29]。结果表明,该体系结构可以很方便地取代现有方法中的backbone,用于各种视觉任务。
class SwinTransformer(nn.Module):
def __init__(...):
super().__init__()
...
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(...)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
其中有几个地方处理方法与ViT不同:
- ViT在输入会给embedding进行位置编码。而Swin-T这里则是作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码
- ViT会单独加上一个可学习参数,作为分类的token。而Swin-T则是直接做平均,输出分类,有点类似CNN最后的全局平均池化层
接下来我们看下各个组件的构成
Patch Embedding
在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。
具体做法是对原始图片裁成一个个 patch_size * patch_size的窗口大小,然后进行嵌入。
这里可以通过二维卷积层,将stride,kernelsize设置为patch_size大小。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size) # -> (img_size, img_size)
patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
# 假设采取默认参数
x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4)
x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)
x = torch.transpose(x, 1, 2) # 把通道维放到最后 (N, 56*56, 96)
if self.norm is not None:
x = self.norm(x)
return x
Patch Merging
该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。
每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。
然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H %