SwinTransformer与Vit细节总结

本文介绍了VisionTransformer(ViT)的基本原理,包括如何将图像切分为patch并进行编码。接着,重点讨论了SwinTransformer,它结合了ResNet和Transformer的特点,通过patchmerging和window-basedmultiheadself-attention(W-MSA)以及shiftedW-MSA实现层次化和减少复杂度。SwinTransformer在保持Transformer的全局信息处理能力的同时,引入了局部连接性和多尺度特征,提高了模型性能。

建议通过标题来快速跳转

Vit (Vision Transformer)

Vit把图片打成了patch,然后过标准的TransformerEncoder,最后用CLS token来做分类。关于怎么打成patch再写一些介绍,假设一个图片是224×224×3,每个patch大小是16×16,那么就会有224×224/(16×16)=196的seq_length,每个patch的维度就是16×16×3=768,这个768再过一个Linear层,最终一个图就可以用196×768表示了,再补个cls token,就成了197×768

Vit的位置编码

作者在文中试了几种方式,发现这几种在分类上效果差不多

  • 1-dimensional positional embedding
  • 2-dimensional positional embedding
  • Relative positional embeddings

Vit少了Inductive bias

In CNNs, locality, two-dimensional neighborhood structure, and translation equivariance are baked into each layer throughout the whole model。卷积和FFN相比主要的优先就是局部连接权值共享

SwinTransformer

SwinTransformer可以看成是披着ResNet外壳的vision transformer,swin 就是两个关键词:patch + 多尺度。下面结合code来说一些重点的细节:

总览图

在这里插入图片描述
这里W-MSA缩写是window-multi head self attention,SW-MSA缩写是shifted window-multi head self attention。整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块,并嵌入到Embedding
  • 在每个Stage里,由Patch Merging和多个Block组成。其中Patch Merging模块主要在每个Stage一开始降低图片分辨率(把H×W×C转成(H/2)×(W/2)×(2C)),把进而形成层次化的设计,同时也能节省一定运算量。
  • 而Block具体结构如右图所示,主要是LayerNorm,MLP,window-multi head self attention和 shifted window-multi head self attention组成。所以一个Block里至少有两个MSA结构

结合代码实现看更多细节

Patch Embedding

在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。采用patch_size * patch_size的窗口大小,通过nn.Conv2d,将stride,kernelsize设置为patch_size大小,patch_size设置为4。值得注意的是SwinTransformer的patch_size×patch_size是4×4,而Vit的patch_size×patch_size是16×16,所以SwinTransformer的序列长度就会长很多,这对于Transformer是吃不消的,因此就有了W-MSA放在一个窗口内减少复杂度

class PatchEmbed(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({
     
     H}*{
     
     W}) doesn't match model ({
     
     self.img_size[0]}*{
     
     self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

Patch Merging

这一步用Yi Zhu老师的图最好了,Patch Merging模块主要在每个Stage一开始降低图片分辨率(把H×W×C转成(H/2)×(W/2)×(2C)),把进而形成层次化的设计,同时也能节省一定运算量。
在这里插入图片描述

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({
     
     H}*{
     
     W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

Window Attention

这部分关键点就两点:

  • 刚才提到的用了Window减少复杂度
  • 加了相对位置编码,把SoftMax(QKT/d)SoftMax(QK^T/\sqrt d)SoftMax(QKT/d )变成SoftMax((QKT+B)/d)SoftMax((QK^T+B)/\sqrt d)Sof
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值