SwinTransformer学习记录(二)之SwinTransformer Block

本文详细介绍了SwinTransformer的关键组件,包括将图像分割为Patch的PatchEmbedding,用于降采样的PatchMerging,以及包含WindowAttention和ShiftedWindowAttention的SwinTransformerBlock。它强调了局部窗口注意力机制如何减少计算复杂度和相对位置编码的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在前面的章节中,我们学习了SwinTransformer的整体框架,其主要由Patch Merging模块与SwinTansformer Block模块组成,

Patch Embedding

在输入进Swin Transformer Block 前,需要将图片切成一个个 patch,然后嵌入向量。

具体做法是对原始图片裁成一个个 window_size * window_size 的窗口大小,然后进行嵌入。

这里可以通过二维卷积层,将 stridekernel_size 设置为 window_size 大小。设定输出通道来确定嵌入向量的大小。最后将 H,W 维度展开,并移动到第一维度。这里的window_size设置为4,具体过程如下,其他阶段的Patch也是如法炮制。

在这里插入图片描述

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
        return x

Patch Merging

该模块的作用是在每个stage 开始前做降采样,用于缩小分辨率,调整通道数。

在 CNN 中,则是在每个 Stage 开始前用stride=2的卷积/池化层来降低分辨率。
每次降采样是两倍,因此在行方向和列方向上,间隔 2 选取元素。

然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的 4 倍(因为 H,W 各缩小 2 倍),此时再通过一个全连接层再调整通道维度为原来的两倍。如下图所示:

在这里插入图片描述

class PatchMerging(nn.Module):
    """ Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彭祥.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值