对Swin-Transformer的理解

本文介绍了微软提出的Swin-Transformer,一种针对图像领域设计的Transformer变种,通过滑动窗口和层次结构减少计算量,提升性能。核心内容包括整体结构、 PatchEmbedding、PatchMerging、WindowAttention、相对位置编码和ShiftedWindowAttention,以及TransformerBlock的应用和优化策略。


        Swin-Transformer是微软近期提出将Transformer用到密集图像预测任务中所提出的一种可作为骨干的Transformer骨干,目前正在各大CV领域疯狂屠榜!

在这里插入图片描述

Introduction

        目前Transformer应用到图像领域主要有两大挑战:

  • 视觉目标变化大,在不同的场景下视觉Transfrmer(如VIT)的性能未必好
  • 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量比较大
    针对上述问题 Swin-Transformer架构被提出,这种架构包含滑窗操作,具有层级设计,其中滑窗操作包括不重叠的local window,和重叠的cross-window。在窗口中计算各自的注意力,这样做的好处是既能引入CNN卷积操作的局限性,另一方面能节省计算量。
    Swin-Transformer

整体结构

在这里插入图片描述
        整个模型采取层次化的设计(主流做法),一共有4个Stage,每个Stage通过Patch Merging来缩小输入特征图的分辨率(这点和CNN一样通过逐层来扩大感受野)。

  • 在输入开始,做一个Patch Embedding/Patch Partition,将图片切成一个个小块,并嵌入到Embedding。
  • 在每个Stage中,由Patch Merging和多个Block(上图右侧) 组成。
  • 其中Patch Merging模块的作用是在每个Stage开头来降低图片分辨率。
  • Block 块如图右所示,主要由LN(LayerNnorm),W MSA(Window Attention),MLP,SW MSA(Shifted Window Attention)
class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_class=1000,
                 embed_dim = 96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio = 4, qkv_bias = True, qk_scale=None,
                 drop_rate = 0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape = False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super(SwinTransformer, self).__init__()
        self.num_class = num_class
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        #split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patchs_resolution
        self.patches_resloution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth 随机深度
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layer
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim = int(embed_dim * 2 ** i_layer),
                               input_resolution = (patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth = depths[i_layer],
                               num_heads = num_heads[i_layer],
                               window_size = window_size,
                               mlp_ratio = mlp_ratio,
                               qkv_bias = qkv_bias,qk_scale=qk_scale,
                               drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer = norm_layer,
                               downsample = PatchMerging if (i_layer < self.num_layers -1) else None,
                               use_checkpoint = use_checkpoint)
            self.layers.appens(layer)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.apply(self._init_weights)

    def _int_weights(self,m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x += self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x

        与VIT的差别之处:

  1. VIT在输入会给Embedding(Patch Embeding前)进行位置编码。而Swin-Transformer中的此处位置编码是个可选项(self.ape=True/False),Swin-Transformer的位置编码是在计算Attention的时候做了个相对位置编码
Patch Embedding(Partition)

        在将图片输入进Block之前,需要将图片切成一个个Patch,然后嵌入向量。
具体做法:将原始图片裁成一个个window_size * window_size的窗口大小,然后进行嵌入向量,这里的做法可以使用二维卷积层,将stride,kernel-size,设置为window_size大小。设定输出通道来确定嵌入向量的大小,然后将H,W维度展开,并移到第一维度

class PatchEmbed
评论 4
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值