目录
前言摘要
文章提出了一种新的ViT(Vision Transformer)作为计算机视觉任务的通用主干。而为了解决图像与NLP在数据规模和分辨率上存在的差异,设计了一种类似于ResNet等传统卷积网络类似的分层(Stage)结构,对于不同尺度的目标更具灵活性。同时引入滑窗(Shifted Window)来进行非重叠局部窗口的自注意力计算;滑窗也允许跨窗口patch的连接,在降低计算量的同时实现不同窗口区域内容的交互。
论文名称: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
论文地址: ICCV 2021 open access
代码地址: https://github.com/microsoft/Swin-Transformer
代码源自@太阳花的小绿豆,特别感谢!
导师代码地址: GitHub
导师代码讲解: https://www.bilibili.com/video/BV1yg411K7Yc
一、网络总体结构及代码框架
网络的总体结构如Fig.1所示,首先通过tokenization方法(由Patch Paritition
和Linear Embedding
组成)将输入图像生成token;而后通过四个不同的stage来构建尺度不同的特征图针对下游任务,每个stage中包含W-MSA
模块,将特征图划分成了多个不相交的窗体(Window),且MSA注意力交互只在每个窗体(Window)内进行。相对于ViT对全局进行Multi-Head Self-Attention能够减少计算量,尤其是在浅层特征图分辨率很大的时候。然而W-MSA阻碍不同窗口之间的信息传递,所以文章也提出了SW-MSA
模块,通过此方法能够实现跨窗口的信息交互;同时在不同stage间作者提出了Patch Mergring
下采样方法实现对token的下采样。
- Patch_Embed:三通道彩色图像在输入网络前需要token化。类似于ViT,对于第一个stage前的Patch Paritition和Linear Embedding采用Patch_Embed方法统一实现,具体是通过一个卷积层并展平(flatten)完成。
- Patch_Mergring:由于网络是类似于深度卷积的层次化stage结构,在每一个stage输出后需要对特征进行下采样,尺度减半通道数翻倍;步骤是对token重构的特征进行隔行采样并拼接得到4块patch,而后通过LN层和线性映射压缩通道数。
- Swin Transformer Block:滑窗注意力W-MSA及SW-MSA构成位置,这两个结构是先后成对使用的。基本结构和组成类似于ViT的transformer Block,包括MSA后的FNN当中的MLP和LN等;同时针对相应问题引入相对位置偏差和窗体分块掩码。
模型实现代码结构如Fig.2所示,下文会对具体模块和代码部分进行分析解读。
二、各部分方法&代码解析
1. Patch_Embed
文章结构图当中的Patch Paritition和Linear Embedding部分实际由PatchEmbed类来实现。初始化类时对patch_size尺寸、输入图像通道数in_channels和embed_dim维度进行定义。通过Conv卷积操作来实现嵌入。卷积核大小和步长都是patch_size,卷积核个数为embed_dim。输入维度为(B,C,H,W)
,输出维度为(B,H'*W',embed_dim)
。代码如下:
class PatchEmbed(nn.Module):
"""
2D Image to Patch Embedding
"""
def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size # 每块patch的尺寸,也是卷积核尺寸和步长
self.in_chans = in_c # 输入图像的维度,通道数
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
_, _, H, W = x.shape
# padding
# 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
if pad_input:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
0, self.patch_size[0] - H % self.patch_size[0],
0, 0))
# 下采样patch_size倍
x = self.proj(x)
_, _, H, W = x.shape
# flatten: [B, C, H, W] -> [B, C, HW]
# transpose: [B, C, HW] -> [B, HW, C]
x = x.flatten(2).transpose(1, 2) # 展平并交换维度
x = self.norm(x)
return x, H, W
2. Patch Merging
除第一stage以外,每个stage阶段在Swin Transformer Block前都需进行