Day 1: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

The first paper is the recently very popular Swin Transformer. "Swin" stands for Shifted Windows, which is also the most important contribution of the whole paper.
To summarize briefly, the main contributions of this paper are:

  • It proposes the concept of shifted windows, which keeps self-attention restricted to local windows while still connecting the windows so that information can flow across the whole image.
  • Previous Transformers, including ViT and DeiT, have computational cost that grows quadratically with the input image size, which makes them impractical for high-resolution inputs and consumes a large amount of memory. This paper makes the computation grow linearly with image size, greatly improving practicality.
  • This paper may become a turning point in computer vision, and could replace CNNs as the new mainstream direction in the future.

Introduction

There are two main reasons why Transformers previously did not shine in the vision domain.

  1. Scale. Unlike in natural language processing, where a word token serves as the basic element, visual elements can vary greatly in scale; in object detection, for example, the objects to be detected can be large or small, with obvious differences.
  2. Compared with words in a text sentence, pixels in an image come at a much higher resolution. This becomes a problem in tasks such as semantic segmentation, especially since the computational complexity grows quadratically with image size.

Core Ideas and Contribution

  1. Proposed a shifted windowing scheme, which brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections.
  2. Achieved linear computational complexity with respect to image size.

Methods and Approaches

General Description

It constructs a hierarchical representation as follows:

  • As shown in the figure, we start from small-sized patches (the squares outlined in gray) and gradually merge neighboring patches in deeper Transformer layers, forming a hierarchical structure. This later allows us to use advanced techniques such as feature pyramids.
  • Linear computational complexity is achieved by computing self-attention only within the non-overlapping local windows (outlined in red) that partition the image.
  • The number of patches (gray boxes) in each window is fixed, so the complexity is linear in image size.

(Figure: hierarchical feature maps built by merging image patches, with self-attention computed within local windows.)

Shifted Windows

  • The most critical part of the Swin Transformer is shifting the windows between consecutive self-attention layers. The shifted windows bridge the windows of the preceding layer, providing connections among them that significantly enhance modeling power.
  • All query patches within the same window share the same key set, which makes memory access in hardware easier and more efficient. In contrast, earlier sliding-window methods suffer from latency problems on general hardware because different query pixels have different key sets. (My understanding of this part is limited, so the translation may not be accurate.)
  • In practice, a plain sliding-window based self-attention layer has difficulty achieving efficient memory access.

(Figure: illustration of the shifted window approach, where the window partition is shifted between consecutive layers.)

Detailed Description

Overall Structure

  • First, the input RGB image is split into non-overlapping patches by a patch splitting module, as in ViT (each patch is a red box in the figure). Each patch is treated as a "token", and its feature is the concatenation of the raw RGB pixel values. (I still need to check the code to see exactly how this is done.)
    • In this paper, the authors use a patch size of $4 \times 4$, so the feature dimension of each patch is $4 \times 4 \times 3 = 48$: each patch covers $4 \times 4$ pixels, and every pixel contributes 3 RGB channel values, hence the factor of 3.
  • A linear embedding layer is applied to this raw-valued feature to project it to an arbitrary dimension, denoted $C$.
  • Transformer blocks with modified self-attention computation (Swin Transformer blocks) are then applied on these patch tokens. These blocks keep the number of tokens at $\frac{H}{4} \times \frac{W}{4}$, and together with the linear embedding they are referred to as "Stage 1".
  • To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network goes deeper.
  • The first patch merging layer concatenates the features of each group of $2 \times 2$ neighboring patches (the groups of neighboring red boxes in the figure above) and applies a linear layer on the resulting $4C$-dimensional concatenated features. (My understanding: each token originally has feature dimension $C$; when $2 \times 2$ neighboring patches are merged, their features are concatenated, so the dimension of each merged token rises from $C$ to $4C$. Please check the code to confirm this.) This reduces the number of tokens by a factor of $2 \times 2 = 4$ (a $2\times$ downsampling of resolution), and the output dimension is set to $2C$. Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at $\frac{H}{8} \times \frac{W}{8}$. The first patch merging layer and this feature transformation together are referred to as "Stage 2".
  • The same procedure is repeated twice more to produce Stage 3 and Stage 4, with output resolutions of $\frac{H}{16} \times \frac{W}{16}$ and $\frac{H}{32} \times \frac{W}{32}$, respectively. (A minimal patch embedding / patch merging sketch is given after this list.)
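Below is a minimal PyTorch sketch of the two operations described above: patch splitting plus linear embedding, and patch merging. It is not the authors' official implementation; the class names (`PatchEmbed`, `PatchMerging`) and the choice of a strided convolution for the embedding are my own, but they follow the common way this is implemented.

```python
import torch
from torch import nn

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping 4x4 patches and linearly embed them.

    A stride-4 convolution with a 4x4 kernel is equivalent to flattening each
    4*4*3 = 48-value patch and applying a shared linear layer to dimension C.
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                      # x: (B, 3, H, W)
        x = self.proj(x)                       # (B, C, H/4, W/4)
        return x.flatten(2).transpose(1, 2)    # (B, H/4 * W/4, C)

class PatchMerging(nn.Module):
    """Concatenate each 2x2 group of neighboring tokens (C -> 4C),
    then project to 2C, halving the spatial resolution."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):                # x: (B, H*W, C)
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]               # top-left of each 2x2 group
        x1 = x[:, 1::2, 0::2, :]               # bottom-left
        x2 = x[:, 0::2, 1::2, :]               # top-right
        x3 = x[:, 1::2, 1::2, :]               # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)        # (B, H/2, W/2, 4C)
        x = x.view(B, -1, 4 * C)                       # (B, H/2 * W/2, 4C)
        return self.reduction(self.norm(x))            # (B, H/2 * W/2, 2C)

# Example: a 224x224 image gives 56*56 = 3136 tokens of dimension 96 after
# the embedding, and 28*28 = 784 tokens of dimension 192 after the first merge.
tokens = PatchEmbed()(torch.randn(1, 3, 224, 224))     # (1, 3136, 96)
merged = PatchMerging(96)(tokens, 56, 56)              # (1, 784, 192)
```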

Swin Transformer Block

  • The Swin Transformer block is built by replacing the standard multi-head self-attention (MSA) module in a Transformer block with a module based on shifted windows, keeping the other layers the same.
  • It consists of a shifted-window based MSA module, followed by a 2-layer MLP with a GELU non-linearity in between.
  • A LayerNorm (LN) layer is applied before each MSA module and each MLP, and a residual connection is applied after each module. (A minimal sketch of this block layout follows.)
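The following is a minimal PyTorch sketch of that block layout (pre-norm attention and MLP with residual connections). The `window_attn` argument is a placeholder for the window-based attention module described in the next sections; the names are my own and this is not the official implementation.

```python
import torch
from torch import nn

class Mlp(nn.Module):
    """2-layer MLP with a GELU non-linearity in between."""
    def __init__(self, dim, hidden_ratio=4):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * hidden_ratio)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * hidden_ratio, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class BlockSkeleton(nn.Module):
    """LN -> (S)W-MSA -> residual, then LN -> MLP -> residual.

    `window_attn` stands in for the window-based attention module
    (W-MSA in even blocks, SW-MSA in odd blocks)."""
    def __init__(self, dim, window_attn):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = window_attn
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim)

    def forward(self, x):                 # x: (B, num_tokens, dim)
        x = x + self.attn(self.norm1(x))  # pre-norm attention + residual
        x = x + self.mlp(self.norm2(x))   # pre-norm MLP + residual
        return x

# e.g. with an identity placeholder for the attention module:
blk = BlockSkeleton(dim=96, window_attn=nn.Identity())
out = blk(torch.randn(2, 3136, 96))       # shape preserved: (2, 3136, 96)
```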

(Figure: the Swin Transformer block architecture, showing two successive blocks with W-MSA and SW-MSA.)

Shifted Window based Self-Attention

  • Standard Transformers conduct global self-attention, where the relationships between a token and all other tokens are computed. This leads to quadratic complexity with respect to the number of tokens, making it unsuitable for many vision problems. (This is the main reason why previous Transformer models consumed so much memory.)

Self-attention in non-overlapped windows

  • They propose to compute self-attention within local windows. The windows are arranged to evenly partition the image in a non-overlapping manner. Supposing each window contains $M \times M$ patches, the computational complexities of a global MSA module and of a window-based one on an image of $h \times w$ patches are:

$$\begin{aligned}\Omega(\mathrm{MSA}) &= 4hwC^{2} + 2(hw)^{2}C \\ \Omega(\mathrm{W\text{-}MSA}) &= 4hwC^{2} + 2M^{2}hwC\end{aligned}$$

  • From the equations above, the former is quadratic in the patch number $hw$, while the latter is linear when $M$ is fixed (the authors use $M = 7$ by default). (A small numerical comparison is sketched below.)
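As a quick sanity check of these two formulas, here is a small Python sketch (my own; the values of $h$, $w$, $C$, $M$ are chosen to match a Stage-1 feature map of a $224 \times 224$ input):

```python
def msa_flops(h, w, C):
    """Global MSA cost: 4*h*w*C^2 + 2*(h*w)^2*C."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def wmsa_flops(h, w, C, M):
    """Window-based MSA cost: 4*h*w*C^2 + 2*M^2*h*w*C."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C

# Stage-1 sized feature map: 56 x 56 tokens, C = 96, window size M = 7.
h = w = 56
C, M = 96, 7
print(f"MSA:   {msa_flops(h, w, C):,}")      # ~2.0e9
print(f"W-MSA: {wmsa_flops(h, w, C, M):,}")  # ~1.5e8
# Doubling the resolution multiplies the second term of MSA by 16,
# while the whole W-MSA cost only grows by 4 (linear in the token count).
```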

Shifted window partitioning in successive blocks

  • The first module uses a regular window partitioning strategy that starts from the top-left pixel: an $8 \times 8$ feature map is evenly partitioned into $2 \times 2$ windows of size $4 \times 4$ ($M = 4$). The next module then adopts a windowing configuration that is shifted from that of the preceding layer, by displacing the windows by $\left(\left\lfloor\frac{M}{2}\right\rfloor, \left\lfloor\frac{M}{2}\right\rfloor\right)$ pixels from the regularly partitioned windows.
  • Swin Transformer blocks are computed as

$$\begin{aligned}
\hat{\mathbf{z}}^{l} &= \mathrm{W\text{-}MSA}\left(\mathrm{LN}\left(\mathbf{z}^{l-1}\right)\right) + \mathbf{z}^{l-1}, \\
\mathbf{z}^{l} &= \mathrm{MLP}\left(\mathrm{LN}\left(\hat{\mathbf{z}}^{l}\right)\right) + \hat{\mathbf{z}}^{l}, \\
\hat{\mathbf{z}}^{l+1} &= \mathrm{SW\text{-}MSA}\left(\mathrm{LN}\left(\mathbf{z}^{l}\right)\right) + \mathbf{z}^{l}, \\
\mathbf{z}^{l+1} &= \mathrm{MLP}\left(\mathrm{LN}\left(\hat{\mathbf{z}}^{l+1}\right)\right) + \hat{\mathbf{z}}^{l+1}
\end{aligned}$$

  • Issue: the shift produces more windows, from $\left\lceil\frac{h}{M}\right\rceil \times \left\lceil\frac{w}{M}\right\rceil$ to $\left(\left\lceil\frac{h}{M}\right\rceil + 1\right) \times \left(\left\lceil\frac{w}{M}\right\rceil + 1\right)$, and some of these windows are smaller than $M \times M$. (To make the feature map size $(h, w)$ divisible by the window size $(M, M)$, bottom-right padding is applied to the feature map if needed.) A window partition/reverse sketch is given below.
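For concreteness, here is a minimal sketch of how a feature map is cut into $M \times M$ windows and stitched back (the helper names `window_partition` / `window_reverse` follow the commonly used pattern, not necessarily the official code). The shifted configuration discussed next reuses exactly these helpers after a cyclic shift.

```python
import torch

def window_partition(x, M):
    """(B, H, W, C) -> (num_windows * B, M, M, C); assumes H and W are divisible by M."""
    B, H, W, C = x.shape
    x = x.view(B, H // M, M, W // M, M, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M, M, C)

def window_reverse(windows, M, H, W):
    """Inverse of window_partition: (num_windows * B, M, M, C) -> (B, H, W, C)."""
    B = windows.shape[0] // ((H // M) * (W // M))
    x = windows.view(B, H // M, W // M, M, M, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

# Regular partition of an 8x8 map with M = 4 gives 2x2 = 4 windows;
# after shifting by floor(M/2) = 2 it would naively become 3x3 = 9 windows.
x = torch.randn(1, 8, 8, 96)
wins = window_partition(x, 4)                 # (4, 4, 4, 96)
assert torch.equal(window_reverse(wins, 4, 8, 8), x)
```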

Efficient batch computation for shifted configuration

not sure how it works yet

Found a good explanation here:
https://blog.youkuaiyun.com/jackzhang11/article/details/116274498

  • cyclic-shifting toward the top-left direction
  • After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window
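Here is a minimal sketch of the cyclic shift and of one common way to build the attention mask. It follows the widely used `torch.roll`-based approach; the exact code in the official repository may differ in details, so treat it as an illustration rather than the authors' implementation.

```python
import torch

M, shift = 4, 2          # window size and shift = M // 2
H = W = 8                # toy feature-map size
x = torch.randn(1, H, W, 96)

# 1. Cyclic shift toward the top-left: rows/columns pushed off the top/left
#    edge wrap around to the bottom/right, so the window count stays
#    (H/M) x (W/M) instead of growing after the shift.
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))

# 2. Label each position of the shifted map with the sub-window it came from.
img_mask = torch.zeros(1, H, W, 1)
cnt = 0
for hs in (slice(0, -M), slice(-M, -shift), slice(-shift, None)):
    for ws in (slice(0, -M), slice(-M, -shift), slice(-shift, None)):
        img_mask[:, hs, ws, :] = cnt
        cnt += 1

# 3. Partition the label map into M x M windows (same reshape as
#    window_partition above) and forbid attention between different labels.
mask_windows = img_mask.view(1, H // M, M, W // M, M, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)   # (nW, M*M, M*M)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))

# attn_mask is added to Q K^T / sqrt(d) before the softmax in the shifted
# (SW-MSA) blocks; after attention, torch.roll with +shift undoes the shift.
```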

(Figure: illustration of the cyclic shift and masked self-attention used for efficient batch computation in the shifted window configuration.)

Relative Position Bias

  • When computing self-attention, a relative position bias $B \in \mathbb{R}^{M^{2} \times M^{2}}$ is included for each head when computing similarity:

    $$\operatorname{Attention}(Q, K, V) = \operatorname{SoftMax}\left(QK^{T}/\sqrt{d} + B\right)V$$

    where $Q, K, V \in \mathbb{R}^{M^{2} \times d}$ are the query, key and value matrices; $d$ is the query/key dimension, and $M^{2}$ is the number of patches in a window. Since the relative position along each axis lies in the range $[-M+1, M-1]$, a smaller-sized bias matrix $\hat{B} \in \mathbb{R}^{(2M-1) \times (2M-1)}$ is parameterized, and the values in $B$ are taken from $\hat{B}$.
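A minimal sketch of how the $M^2 \times M^2$ bias $B$ can be gathered from the smaller table $\hat{B}$ (pure index bookkeeping; variable names are my own, and in practice the table holds one column per attention head):

```python
import torch
from torch import nn

M = 7  # window size

# Learnable table B_hat with one entry per possible 2-D relative offset.
bias_table = nn.Parameter(torch.zeros((2 * M - 1) * (2 * M - 1)))

# Coordinates of the M*M positions inside a window.
coords = torch.stack(torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij"))
coords = coords.flatten(1)                              # (2, M*M)

# Pairwise relative offsets, each in [-(M-1), M-1].
rel = coords[:, :, None] - coords[:, None, :]           # (2, M*M, M*M)
rel = rel.permute(1, 2, 0)                              # (M*M, M*M, 2)

# Shift to non-negative values and flatten the two axes into one table index.
rel = rel + (M - 1)                                     # now in [0, 2M-2]
index = rel[:, :, 0] * (2 * M - 1) + rel[:, :, 1]       # (M*M, M*M)

# B has shape (M*M, M*M) and is added to Q K^T / sqrt(d) for each head.
B = bias_table[index]
print(B.shape)   # torch.Size([49, 49])
```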
