Swin Transformer 结构&代码解析学习

本文深入解析Swin Transformer的结构与代码,包括Patch Embed、Patch Merging和Swin Transformer Block,重点介绍Window Multi-head Self Attention (W-MSA)和Shifted Window Multi-head Self Attention (SW-MSA),以及它们如何在降低计算量的同时增强计算机视觉任务的表现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在这里插入图片描述



前言摘要

  文章提出了一种新的ViT(Vision Transformer)作为计算机视觉任务的通用主干。而为了解决图像与NLP在数据规模和分辨率上存在的差异,设计了一种类似于ResNet等传统卷积网络类似的分层(Stage)结构,对于不同尺度的目标更具灵活性。同时引入滑窗(Shifted Window)来进行非重叠局部窗口的自注意力计算;滑窗也允许跨窗口patch的连接,在降低计算量的同时实现不同窗口区域内容的交互。

论文名称: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
论文地址: ICCV 2021 open access
代码地址: https://github.com/microsoft/Swin-Transformer

代码源自@太阳花的小绿豆,特别感谢!
导师代码地址: GitHub
导师代码讲解: https://www.bilibili.com/video/BV1yg411K7Yc


一、网络总体结构及代码框架

  网络的总体结构如Fig.1所示,首先通过tokenization方法(由Patch ParititionLinear Embedding组成)将输入图像生成token;而后通过四个不同的stage来构建尺度不同的特征图针对下游任务,每个stage中包含W-MSA模块,将特征图划分成了多个不相交的窗体(Window),且MSA注意力交互只在每个窗体(Window)内进行。相对于ViT对全局进行Multi-Head Self-Attention能够减少计算量,尤其是在浅层特征图分辨率很大的时候。然而W-MSA阻碍不同窗口之间的信息传递,所以文章也提出了SW-MSA模块,通过此方法能够实现跨窗口的信息交互;同时在不同stage间作者提出了Patch Mergring下采样方法实现对token的下采样。

Fig.1 Swin Transformer网络结构
  • Patch_Embed:三通道彩色图像在输入网络前需要token化。类似于ViT,对于第一个stage前的Patch Paritition和Linear Embedding采用Patch_Embed方法统一实现,具体是通过一个卷积层并展平(flatten)完成。
  • Patch_Mergring:由于网络是类似于深度卷积的层次化stage结构,在每一个stage输出后需要对特征进行下采样,尺度减半通道数翻倍;步骤是对token重构的特征进行隔行采样并拼接得到4块patch,而后通过LN层和线性映射压缩通道数。
  • Swin Transformer Block:滑窗注意力W-MSA及SW-MSA构成位置,这两个结构是先后成对使用的。基本结构和组成类似于ViT的transformer Block,包括MSA后的FNN当中的MLP和LN等;同时针对相应问题引入相对位置偏差和窗体分块掩码。

在这里插入图片描述

Fig.2 Swin Transformer模型代码调用结构

模型实现代码结构Fig.2所示,下文会对具体模块和代码部分进行分析解读。


二、各部分方法&代码解析

1. Patch_Embed

  文章结构图当中的Patch ParititionLinear Embedding部分实际由PatchEmbed类来实现。初始化类时对patch_size尺寸、输入图像通道数in_channels和embed_dim维度进行定义。通过Conv卷积操作来实现嵌入。卷积核大小和步长都是patch_size,卷积核个数为embed_dim。输入维度为(B,C,H,W),输出维度为(B,H'*W',embed_dim)。代码如下:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size) 
        self.patch_size = patch_size # 每块patch的尺寸,也是卷积核尺寸和步长
        self.in_chans = in_c # 输入图像的维度,通道数
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2) # 展平并交换维度
        x = self.norm(x) 
        return x, H, W

2. Patch Merging

在这里插入图片描述

Fig.3 Patch Merging实现步骤

  除第一stage以外,每个stage阶段在Swin Transformer Block前都需进行

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值