机器学习周报(12.9-12.15)

摘要

本篇博客介绍了采用类似于卷积核的移动窗口进行图像特征提取的Swin Transformer网络模型,这是一种基于Transformer的计算机视觉模型,通过引入类似CNN卷积核的Shifted Windows(移动窗口),其通过将 self-attention 计算限制在不重叠的局部窗口上,由于窗口大小固定,每个窗口内计算复杂度固定,计算复杂度跟图像大小呈线性关系,从而带来更高的效率。Swin Transformer Block的核心机制,包括W-MSA(窗口多头自注意力)和SW-MSA(移动窗口多头自注意力),这些机制从全局的自注意力变为了窗口内的自注意力,使得模型在减少计算量的同时捕捉多尺度特征。最后,通过使用预训练模型对花类数据集进行微调;通过预测代码和结果展示了模型对图像分类的高准确率。

Abstract

This blog post introduces the Swin Transformer network model, which adopts shifting windows similar to convolutional kernels for image feature extraction. This computer vision model, based on the Transformer architecture, introduces Shifted Windows akin to CNN convolutional kernels. By limiting self-attention computations to non-overlapping local windows, and with a fixed window size, the computational complexity within each window is constant, and the overall computational complexity is linear with respect to the image size, thus offering greater efficiency. The core mechanisms of the Swin Transformer Block include W-MSA (Window Multi-Head Self-Attention) and SW-MSA (Shifted Window Multi-Head Self-Attention), which transition from global self-attention to window-internal self-attention, allowing the model to capture multi-scale features while reducing computational load. Finally, the post demonstrates the model’s high accuracy in image classification by fine-tuning a pre-trained model on a flower dataset; the prediction code and results showcase the model’s high accuracy in image classification tasks.

1 Swin Transformer

Swin Transformer提出了类似CNN卷积核的Shifted Windows(移动窗口),以减少patch之间做自注意力的次数,达到节省时间的目的。

接下来我将参照Swin Transformer模型整体结构图来讲解该模型,网络结构图如下所示:

在这里插入图片描述

1.1 输入

以输入图像大小为 224x224 为例,将图像tensor 224x224x3 从1处传入模型作为开始。

1.2 Patch Partition

到达第2步Patch Partiton,因为是以Transformer为基础,如同ViT一样需要将图片打成patch传入模型。这里将图像大小 224x224 通过卷积核大小为 4x4,步伐为4的卷积将图像转换为 56x56 的大小,通道数从3变为3x4x4=48维。

1.3 Linear Embedding

接下来数据将传入Swin Transformer最重要的第3大部分,首先在第一个stage会经过全连接层,这里的通道数在论文中是设置好的,即C=96。如上图所示,56x56x48 的tensor经过Linear Embedding维度变化为56x56x96。

1.4 Patch Merging

这里先说一下Patch Merging,因为只要第一个stage先经过Linear Embedding,后3个stage都是Patch Merging。这里的Patch Merging类似于CNN中的池化,为了使窗口获得多尺度的特征,需要在每一个block之前进行下采样,这样在窗口大小不变的情况下图像变小、维度增加,可以获得更多的特征信息。如下图所示:

在这里插入图片描述

论文中提到将图像下采样2倍,在Patch Merging中会隔一个点采样。细心的人这时会发现,我们图像的维度从 4x4x1 变为了 2x2x4,但是在模型中是从变为,即大小缩小一半,维度扩大2倍。但是在上图的操作中维度扩大了4倍,论文中为了使其与CNN中池化保持一致,在Patch Merging之后会在再通过一次卷积使维度缩小一半,以达到大小缩小一半,维度扩大2倍的效果。

到这一步Swin Transformer已经拥有了感知多尺度特征的手段。

1.5 Swin Transformer Block

在这里插入图片描述
Swin Transformer中最重要的模块,这里是用到了Transformer中的编码器,但是做了一些改动,Transformer中是单纯的多头自注意力(MSA),而这里用到了W-MSA和SW-MSA。

  1. W-MSA
    引入Windows Multi-head Self-Attention(窗口多头自注意力)就是为了解决开篇提到的计算量问题。W-MSA会将图像划分为 MxM (在论文中M默认为7)的不重叠窗口,然后多头自注意力只在同一个窗口内做,这样就从全局的自注意力变为了窗口内的自注意力,大大减小了计算量

在这里插入图片描述
如上公式,若图像长宽为 224x224 、M=7,则从 22 4 4 224^4 2244降为 22 4 4 × 49 224^4×49 2244×49

  1. SW-MSA

    如果只进行W-MSA会发现一个问题,窗口之间是没用通信的,也就是窗口与窗口之间是独立计算的,就不能像卷积一样移动去感受不同像素范围。于是,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块让窗口移动起来。

在这里插入图片描述

例如上图所示,将图像分为了4个窗口,在W-MSA中4个窗口分别做MSA;然后在下一层SW-MSA左上角的窗口会向右和向下移动个步伐。如右图所示将图像分割为9个窗口,会发现上一层不属于同一个窗口的patch,经过移动后融合为一个窗口,窗口与窗口之间成功进行了信息交流。

如果9个窗口分别进行SW-MSA,那么从上一层4个相同的窗口变为9个不完全相同的窗口进行计算,不但不利于并行计算,而且窗口数量也增大两倍多,通过W-MSA积攒的计算优势一去不复返。

那么该如何解决呢?有人想到padding补0,这样虽然可以保证窗口大小相同进行并行计算,但是数量是增多的。论文中,作者提出通过掩码的方式进行自注意力的计算

在这里插入图片描述
A、B、C经过顺时针循环位移,将图像窗口补为大小相同的4个窗口。但是同一个窗口中,颜色不同的部分不能直接进行自注意力,需要在计算之后分别乘上对应的掩码。

在这里插入图片描述

掩码是加上一个很大的负值,在经过softmax之后该部分就会变为0。1的部分全属于同一窗口可直接进行自注意力计算。

1.6 代码

Swin Transformer网络模型如下:

""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030
Code/weights from https://github.com/microsoft/Swin-Transformer
"""
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional
 
 
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
 
 
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
 
    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)
 
 
def window_partition(x, window_size: int):
    """
    将feature map按照window_size划分成一个个没有重叠的window
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
 
 
def window_reverse(windows, window_size: int, H: int, W: int):
    """
    将一个个window还原成一个feature map
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
 
 
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
    def forward(self, x):
        _, _, H, W = x.shape
 
        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
 
        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
 
 
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
 
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)
 
    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
 
        x = x.view(B, H, W, C)
 
        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值