Swin Transformer—— 基于Transformer的图像识别模型

SwinTransformer通过引入移动窗口和层级结构,改进了ViT在图像识别中的性能,通过小尺寸补丁计算注意力实现精细特征提取和较低计算成本。它在多个视觉任务中表现出色,且被广泛应用于多模态模型的backbone。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

概述

Swin Transformer是微软研究院于2021年在ICCV上发表的一篇论文,因其在多个视觉任务中的出色表现而被评为当时的最佳论文。它引入了移动窗口的概念,提出了一种层级式的Vision Transformer,将Shifted Windows(移动窗口)作为其主要贡献。这个概念使得Swin Transformer可以像卷积神经网络一样进行分块,并进行层级式的特征提取,从而在特征表示中引入多尺度的概念。

在OpenAI发布的Sora中也出现了视频patches的概念,这进一步表明了Vision Transformer和Swin Transformer在引入patch概念方面的重要性。目前,许多多模态模型的backbone都采用了这两种模型,因此理解和应用它们的原理对于掌握和应用这些优秀的多模态模型非常必要。

在 Swin Transformer之前,基于Transformer的图像识别模型是视觉变换器(ViT)。它将图像视为由 16x16 个单词组成的句子,是自然语言处理中使用的变换器在图像识别中的首次应用。

本文指出了文本和图像之间的差异,并提出了 Swin Transformer,使 ViT 更适应图像领域。

文字和图像的两个区别如下。

  • 与文字符号不同,图像中的视觉元素在比例上差异很大
  • 图像中的像素比文件中的文字具有更高的分辨率(更多信息)。

为了消除这些差异

  • 计算不同贴片尺寸下的关注度
  • 用较小的补丁尺寸计算关注度。

下图说明了 ViT 和 Swin Transformer在这些方面的区别。

用较小的斑块尺寸计算注意力可以获得精细的特征,但计算成本较高。

这就是在 Swin 变换器中引入基于移位窗口的自注意的原因。多个补丁被合并到一个窗口中,注意力计算只在该窗口中进行,从而减少了计算量。

在下一节中,我们将了解斯温变换器的整体情况,然后了解一些更微小的细节,包括基于移位窗口的自我关注。

Swin Transformer

大画面

下面是 Swin Transformer的全貌。

首先,对输入图像进行补丁分割。

补丁分割:将 4x4 像素分割为一个补丁;由于 ViT将 16x16像素作为一个补丁,因此斯温变换器可以提取更精细的特征。

然后进行线性嵌入。

线性嵌入:将补丁(4x4x3ch)转换为 C 维标记,其中 C 取决于模型的大小。

对于从每个补丁中获得的标记,Swin Transformer Block 会计算关注度并进行特征提取。

Swin Transformer区块:用基于移位窗口的自保持(W-MSA 和 SW-MSA)取代常规变压器区块中使用的多头自保持(MSA)。以下章节将提供更多信息。下文将对它们进行更详细的介绍。其他配置与普通变压器几乎完全相同。

目前看到的线性嵌入和变换块部分被称为第 1 阶段;共有 1 到 4 个阶段,但每个阶段的补丁大小不同,因此可以在不同尺度上进行特征提取。不同大小的补丁是由补丁合并(Patch Merging)产生的,它将邻域中的补丁聚合在一起。

补丁合并:在每个阶段,相邻的(2 × 2)补丁(标记)合并在一起,形成一个标记。具体来说,合并 2 × 2 标记,并通过线性层将所得的 4C 维向量变为 2C 维。例如,在第 2 阶段,(H/4)×(W/4)×C 维度被简化为 (H/8)×(W/8)×2C 维度。

基于移动窗口的自我关注

从计算复杂度的角度解释了普通变压器和斯温变压器模块注意力计算的区别。

法线变换器计算所有标记之间的距离,其中 h 和 w 是图像中垂直和水平斑块的数量,计算量如下

另一方面,Swin 变换器只计算由多个补丁组成的窗口内的关注度:一个窗口包含 M x M 个补丁,基本固定为 M = 7。计算复杂度如下式所示。

在普通变换器中,计算复杂度的增加与补丁数 (hw) 的平方成正比。然而,由于 M = 7,影响很小,即使是补丁数 (hw) 的增加也保持在幂级数以内。这使得 Swin变换器可以计算小尺寸的贴片。

接下来介绍将图像划分为窗口的方法:窗口的排列方式是将图像平均划分为 M x M 个补丁。以这种方式排列的每个窗口都会计算注意力,因此即使是相邻的补丁,如果它们是不同的窗口,也不会计算注意力。为了解决窗口边界问题,在计算第一个注意力(W-MSA:基于窗口的多头自注意力)后,窗口会被移动,注意力会被再次计算(SW-.MSA:基于移动窗口的多头自注意)。

如下图所示,在原窗口分割的基础上移动([M/2], [M/2])个像素。


代码模型:

class PatchEmbed(nn.Module):
    
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # 如果输入图片的 H,W 不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions, (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW];  transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

移位配置的高效批量计算

SW-MSA 的窗口大小不同,窗口数量也会增加。因此,如果直接进行处理,就会出现计算量比 W-MSA 增加的问题。因此,在 SW-MSA 中,使用一种称为循环移动的方法进行伪操作,而不是实际改变窗口的排列。

如下图所示,整个图像向左上方移动,溢出区域插入空白区域 (循环移动 )。通过这种方法,它的计算方法与 W-MSA 窗口中的 Attention 计算方法相同。此外,由于窗口中可能包含不相邻的斑块,因此要对这些部分进行掩膜处理。在最终输出中,将执行循环移位的反向操作(反向循环移位),将补丁恢复到原始位置。


代码实现:

class PatchMerging(nn.Module):

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

结构变体

Swin 变压器有 T、S、B 和 L 四种尺寸,每级的尺寸(dim)、头(head)和块数各不相同,如下表所示。

试验

在 ImageNet-1K 的图像识别任务、COCO 的物体检测任务和 ADE20K 的语义分割任务中与其他模型进行了比较,结果都达到了最高准确率。(实验结果详见本文第四章表 1~表 3)。

在 SW-MSA 中进行的消融研究证实,在这两项任务中,引入 SW-MSA 的准确率都高于单独引入 W-MSA。

摘要

与在所有斑块之间计算注意力的 ViT 不同,注意力计算和斑块聚合可以在相邻斑块的窗口中重复进行,从而可以在不同尺度上提取特征。另一个优点是不在所有斑块之间计算注意力,从而降低了计算复杂度,并能从较小的斑块尺寸中提取特征。

### 使用Transformer模型进行图像识别的应用 #### 应用实例 在计算机视觉领域,Transformer模型已经广泛应用于多种任务中。对于图像分类而言,Transformer能够处理复杂的图像数据并将其映射到预定义的类别上[^2]。具体来说,在给定一幅图片的情况下,经过训练后的Transformer网络会预测该图所属的具体类目。 除了简单的分类外,目标检测也是另一个重要的应用场景。通过利用自注意力机制捕捉全局上下文信息的能力,使得基于Transformer的目标定位更加精准可靠;它不仅限于框选物体边界框位置坐标,还能提供更细致入微的空间分布情况说明。 此外,借助强大的生成对抗网络(GANs),结合编码器-解码器架构下的Transformers还可以创造出逼真的合成影像或是编辑已有素材,从而满足不同创意需求。 #### 实现方法概述 为了更好地适应二维空间内的像素排列特点以及保持局部结构不变性,研究者们提出了专门针对视觉任务优化过的变体——Vision Transformers (ViT)[^4] 和 Swin Transformers 。这类改进型框架通常采用分片(patch)策略将原始输入切分成若干子区域后再送入后续层间传递计算流程之中,以此达到降维增效的目的同时保留必要的语义关联度。 下面是一个简化版的Python代码片段展示如何构建一个基本的 Vision Transformer 模型来进行图像分类: ```python import torch.nn as nn from transformers import ViTModel, ViTConfig class ImageClassifier(nn.Module): def __init__(self, num_labels=10): super().__init__() configuration = ViTConfig(image_size=224, patch_size=16, num_channels=3, hidden_size=768) self.vit = ViTModel(configuration) self.classifier = nn.Linear(768, num_labels) def forward(self, pixel_values): outputs = self.vit(pixel_values=pixel_values).last_hidden_state[:, 0] logits = self.classifier(outputs) return logits ``` 此段程序创建了一个继承自 `nn.Module` 的新类 `ImageClassifier` ,其中包含了两个主要组件:一个是负责提取特征表示向量序列的 VIT 预训练骨干网路(`vit`);另一个则是用来完成最终决策判断工作的全连接线性变换层 (`classifier`). 当接收到一批次待测样本时,先调用前者获取对应每一幅画作的整体印象描述符,再经后者转换成概率得分矩阵形式输出供下游解析使用.
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

知来者逆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值