【Python/Pytorch - 网络模型】-- 手把手搭建Swin-Transformer模型 - PatchEmbed 模块

在这里插入图片描述
文章目录

写在前面

根据博主文章,学习Transformer、Vit、SwinTransformer、SwinUnetr原理,原理博主文章已讲解较为详细,本文主要从代码角度学习各个模块的原理。

图解Vit 1:Vision Transformer——图像与Transformer基础 - 知乎
图解Vit 2:Vision Transformer——视觉问题中的注意力机制 - 知乎
图解Vit 3:Vision Transformer——ViT模型全流程拆解(Layer Normalization, Position Embedding) - 知乎
图解+代码 Swin Transformer 1: W-MSA和Patch Merging - 知乎
图解+代码 Swin Transformer 2: SW-MSA(Shifted Window Multi-head Self Attention) - 知乎

PatchEmbed 模块

Patch Embedding过程:一张[h, w, 3] img 按照window的大小,分成不同的patch,他们的维度变成[num_patch, Wh, Ww, 3]. 每一个patch将它的flatten成1维,然后过一个linear层,最后输出的就是一个token,所有token就是patch embedding。每一个token拉平就是embed_dim。

class PatchEmbed3D(nn.Module):
    """ Video to Patch Embedding.
    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # B C D Wh Ww
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)

        return x

第一点

import torch
import torch.nn as nn
import torch.nn.functional as F

# 输入数据:形状为 (batch_size, channels, depth, height, width)
x = torch.randn(1, 3, 16, 128, 128)  # 示例:1 个样本,3 个通道,16 帧,128×128 的图像

# 初始化 PatchEmbed3D
patch_embed = PatchEmbed3D(patch_size=(2, 4, 4), in_chans=3, embed_dim=96)

# 前向传播
output = patch_embed(x)

# 输出形状
print(output.shape)  # 示例输出:torch.Size([1, 96, 8, 32, 32])
  • 输出张量的形状为 (1, 96, 8, 32, 32),表示:
  • 输入数据被分割成 (2, 4, 4) 大小的 patch。
  • 每个 patch 被投影到 96 维的特征空间。
  • 时间维度从 16 变为 8(因为 16 / 2 = 8)。
  • 高度和宽度分别从 128 变为 32(因为 128 / 4 = 32)。

第二点

  • F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])):在宽度方向上进行填充。F.pad 的参数是一个元组,表示每个维度的填充量。这里只填充宽度的右侧。
  • F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])):在高度方向上进行填充。F.pad 的参数表示填充的顺序为:左侧、右侧、顶部、底部。这里只填充高度的底部。
  • F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])):在深度方向上进行填充。F.pad 的参数表示填充的顺序为:左侧、右侧、顶部、底部、前面、后面。这里只填充深度的后面。

第三点

  • D, Wh, Ww = x.size(2), x.size(3), x.size(4)
  • 这行代码获取了输入张量 x 的深度(D)、高度(Wh)和宽度(Ww)。
  • x = x.flatten(2).transpose(1, 2)
  • flatten(2):
  • 将从第 2 维(从 0 开始计数)开始的维度展平。具体来说,将 D, Wh, Ww 三个维度展平为一个维度。
  • 展平后的张量形状变为 (batch_size, embed_dim, D * Wh * Ww)。
  • transpose(1, 2):
  • 将第 1 维和第 2 维交换,使得特征维度(embed_dim)和展平后的空间维度(D * Wh * Ww)交换位置。
  • 交换后的张量形状变为 (batch_size, D * Wh * Ww, embed_dim)。
  • x = self.norm(x)
  • 归一化层(如 nn.LayerNorm)通常作用于最后一个维度(即特征维度)。通过上述的 transpose 操作,特征维度已经移到了最后一个维度,因此可以直接应用归一化。
  • x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
  • transpose(1, 2):
  • 将第 1 维和第 2 维再次交换,恢复特征维度和空间维度的原始顺序。
  • 交换后的张量形状变为 (batch_size, embed_dim, D * Wh * Ww)。
  • view(-1, self.embed_dim, D, Wh, Ww):
  • 将展平后的空间维度重新恢复为原始的 3D 形状 (D, Wh, Ww)。
  • 最终的张量形状为 (batch_size, embed_dim, D, Wh, Ww)。

论文下载

Attention Is All You Need
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
SwinUNETR-V2: Stronger Swin Transformers with Stagewise Convolutions for 3D Medical Image Segmentation

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

电科_银尘

你的鼓励将是我最大的创作动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值