A First Look at Swin-Transformer
Swin-Transformer is a Transformer backbone recently proposed by Microsoft for dense image prediction tasks, and it is currently sweeping the leaderboards across many computer vision benchmarks!

Introduction
Applying Transformers to the image domain currently faces two main challenges:
- Visual entities vary greatly in scale, and a plain vision Transformer (e.g. ViT) does not necessarily perform well across different scenes.
- Images are high-resolution with many pixels, and the global self-attention that Transformers rely on makes the computation very expensive.
To address these problems, the Swin-Transformer architecture was proposed. It uses a hierarchical design together with a window scheme: self-attention is computed separately inside non-overlapping local windows, while shifted windows provide connections across windows. Computing attention per window both brings in the locality prior of CNN convolutions and, at the same time, saves a large amount of computation (see the quick estimate below).
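To make the saving concrete, the paper gives the complexity of global multi-head self-attention over an h×w token grid with channel dimension C as 4hwC^2 + 2(hw)^2·C, versus 4hwC^2 + 2M^2·hwC for window attention with window size M, i.e. the term that is quadratic in hw becomes linear. A quick back-of-the-envelope comparison (the concrete numbers below are illustrative, taken from the first stage of Swin-T):

# FLOP estimate: global self-attention (MSA) vs. window attention (W-MSA),
# using the complexity formulas from the Swin-Transformer paper.
h, w, C, M = 56, 56, 96, 7            # stage-1 token grid of Swin-T, window size 7

msa   = 4 * h * w * C**2 + 2 * (h * w)**2 * C    # global attention: quadratic in h*w
w_msa = 4 * h * w * C**2 + 2 * M**2 * h * w * C  # window attention: linear in h*w

print(f"MSA:   {msa:.3e} FLOPs")      # ~2.0e9
print(f"W-MSA: {w_msa:.3e} FLOPs")    # ~1.5e8, roughly 14x cheaper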
Swin-Transformer
Overall Architecture

The whole model adopts a hierarchical design (the mainstream approach) with 4 stages in total; each stage shrinks the resolution of the input feature map through Patch Merging, enlarging the receptive field layer by layer just as a CNN does (a shape walk-through follows the list below).
- At the input, a Patch Embedding / Patch Partition step cuts the image into small patches and embeds each patch into a vector.
- Each stage consists of a Patch Merging module followed by several Blocks (right side of the architecture figure).
- The Patch Merging module reduces the feature-map resolution at the beginning of each stage (from the second stage onward; the first stage works directly on the patch embeddings).
- A Block (shown on the right of the figure) is mainly composed of LN (LayerNorm), W-MSA (Window Attention), an MLP, and SW-MSA (Shifted Window Attention).
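For the Swin-T configuration used in the code below (img_size=224, patch_size=4, embed_dim=96), the token grid and channel width evolve across the stages as sketched here; the numbers follow directly from halving the resolution and doubling the channels at every Patch Merging step:

# Token-grid resolution and channel width at each stage of Swin-T
img_size, patch_size, embed_dim = 224, 4, 96
res = img_size // patch_size          # 56: grid after Patch Partition / Patch Embedding
dim = embed_dim
for stage in range(4):
    print(f"Stage {stage + 1}: {res}x{res} tokens, dim {dim}")
    if stage < 3:                     # Patch Merging between consecutive stages
        res //= 2                     # 2x downsampling of the token grid
        dim *= 2                      # channel dimension doubles
# Stage 1: 56x56 tokens, dim 96
# Stage 2: 28x28 tokens, dim 192
# Stage 3: 14x14 tokens, dim 384
# Stage 4: 7x7  tokens, dim 768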
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_class=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super(SwinTransformer, self).__init__()
        self.num_class = num_class
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding (optional, off by default)
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: drop-path rate increases linearly over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build the 4 stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        x = self.patch_embed(x)                # (B, num_patches, embed_dim)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:              # 4 hierarchical stages
            x = layer(x)
        x = self.norm(x)                       # (B, L, num_features)
        x = self.avgpool(x.transpose(1, 2))    # global average pooling over tokens
        x = torch.flatten(x, 1)                # (B, num_features)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x
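A quick sanity check of the forward pass (this assumes PatchEmbed, BasicLayer and PatchMerging are available, e.g. as defined in the official Swin-Transformer repository):

model = SwinTransformer(img_size=224, patch_size=4, embed_dim=96,
                        depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
x = torch.randn(2, 3, 224, 224)        # a batch of two RGB images
feats = model(x)
print(feats.shape)                     # torch.Size([2, 768]): pooled stage-4 features

Note that this version of forward only returns the pooled features; a classification head (e.g. nn.Linear(self.num_features, num_class)) would sit on top of them.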
Differences from ViT:
- ViT adds an absolute position encoding to the patch embeddings at the input (right after Patch Embedding). In Swin-Transformer this absolute position embedding is optional (self.ape=True/False); instead, Swin-Transformer applies a relative position encoding (a relative position bias) when computing attention, as sketched below.
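A minimal sketch of how such a relative position bias can be built, following the approach of the official implementation: a learnable table with one entry per possible relative offset (and per head) is indexed by the pairwise relative coordinates of the tokens in an M×M window, and the gathered bias is added to the attention logits. The window_size and num_heads values below are purely illustrative:

import torch
import torch.nn as nn

M, num_heads = 7, 3                    # illustrative: window size 7, 3 heads (Swin-T, stage 1)

# learnable bias table: (2M-1)*(2M-1) possible relative offsets, one value per head
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * M - 1) * (2 * M - 1), num_heads))

# pairwise relative coordinates of the M*M tokens inside one window
coords = torch.stack(torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij"))  # (2, M, M)
coords_flatten = torch.flatten(coords, 1)                                   # (2, M*M)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # (2, M*M, M*M)
relative_coords = relative_coords.permute(1, 2, 0).contiguous()             # (M*M, M*M, 2)
relative_coords[:, :, 0] += M - 1      # shift offsets so they start from 0
relative_coords[:, :, 1] += M - 1
relative_coords[:, :, 0] *= 2 * M - 1  # flatten the 2D offset into a single index
relative_position_index = relative_coords.sum(-1)                           # (M*M, M*M)

# inside window attention: look up the bias and add it to the attention logits
bias = relative_position_bias_table[relative_position_index.view(-1)].view(
    M * M, M * M, num_heads).permute(2, 0, 1).contiguous()                  # (num_heads, M*M, M*M)
# attn = attn + bias.unsqueeze(0)      # attn: (B * num_windows, num_heads, M*M, M*M)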
Patch Embedding (Patch Partition)
Before the image is fed into the Blocks, it has to be split into patches, and each patch embedded into a vector.
Concretely: the original image is cut into patches of size patch_size × patch_size, and each patch is projected to an embedding. This can be done with a single 2D convolution whose stride and kernel size are both set to patch_size; the number of output channels determines the embedding dimension. The H and W dimensions are then flattened into a single token dimension and transposed in front of the channel dimension, so the result has shape (B, num_patches, embed_dim), as sketched below.
class PatchEmbed
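Based on that description, a minimal PatchEmbed module might look like the following (a sketch assuming the common Conv2d-based implementation; attribute names such as num_patches and patches_resolution follow what the SwinTransformer code above expects, and the defaults are illustrative):

import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into patch_size x patch_size patches and embed each patch."""
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patches_resolution = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        # kernel_size == stride == patch_size: each patch is projected independently
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x):                     # x: (B, C, H, W)
        x = self.proj(x)                      # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2).transpose(1, 2)      # (B, num_patches, embed_dim)
        return self.norm(x)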

To summarize, this post introduces Swin-Transformer, Microsoft's Transformer variant designed for the image domain, which reduces computation and improves performance through shifted windows and a hierarchical structure. The core topics are the overall architecture, Patch Embedding, Patch Merging, Window Attention, relative position encoding and Shifted Window Attention, together with how the Transformer Block is applied and optimized.





