文章目录
写在前面
根据博主文章,学习Transformer、Vit、SwinTransformer、SwinUnetr原理,原理博主文章已讲解较为详细,本文主要从代码角度学习各个模块的原理。
图解Vit 1:Vision Transformer——图像与Transformer基础 - 知乎
图解Vit 2:Vision Transformer——视觉问题中的注意力机制 - 知乎
图解Vit 3:Vision Transformer——ViT模型全流程拆解(Layer Normalization, Position Embedding) - 知乎
图解+代码 Swin Transformer 1: W-MSA和Patch Merging - 知乎
图解+代码 Swin Transformer 2: SW-MSA(Shifted Window Multi-head Self Attention) - 知乎
PatchEmbed 模块
Patch Embedding过程:一张[h, w, 3] img 按照window的大小,分成不同的patch,他们的维度变成[num_patch, Wh, Ww, 3]. 每一个patch将它的flatten成1维,然后过一个linear层,最后输出的就是一个token,所有token就是patch embedding。每一个token拉平就是embed_dim。
class PatchEmbed3D(nn.Module):
""" Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if W % self.patch_size[2] != 0:
x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
if H % self.patch_size[1] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
if D % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
x = self.proj(x) # B C D Wh Ww
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
return x
第一点
import torch
import torch.nn as nn
import torch.nn.functional as F
# 输入数据:形状为 (batch_size, channels, depth, height, width)
x = torch.randn(1, 3, 16, 128, 128) # 示例:1 个样本,3 个通道,16 帧,128×128 的图像
# 初始化 PatchEmbed3D
patch_embed = PatchEmbed3D(patch_size=(2, 4, 4), in_chans=3, embed_dim=96)
# 前向传播
output = patch_embed(x)
# 输出形状
print(output.shape) # 示例输出:torch.Size([1, 96, 8, 32, 32])
- 输出张量的形状为 (1, 96, 8, 32, 32),表示:
- 输入数据被分割成 (2, 4, 4) 大小的 patch。
- 每个 patch 被投影到 96 维的特征空间。
- 时间维度从 16 变为 8(因为 16 / 2 = 8)。
- 高度和宽度分别从 128 变为 32(因为 128 / 4 = 32)。
第二点
- F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])):在宽度方向上进行填充。F.pad 的参数是一个元组,表示每个维度的填充量。这里只填充宽度的右侧。
- F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])):在高度方向上进行填充。F.pad 的参数表示填充的顺序为:左侧、右侧、顶部、底部。这里只填充高度的底部。
- F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])):在深度方向上进行填充。F.pad 的参数表示填充的顺序为:左侧、右侧、顶部、底部、前面、后面。这里只填充深度的后面。
第三点
- D, Wh, Ww = x.size(2), x.size(3), x.size(4)
- 这行代码获取了输入张量 x 的深度(D)、高度(Wh)和宽度(Ww)。
- x = x.flatten(2).transpose(1, 2)
- flatten(2):
- 将从第 2 维(从 0 开始计数)开始的维度展平。具体来说,将 D, Wh, Ww 三个维度展平为一个维度。
- 展平后的张量形状变为 (batch_size, embed_dim, D * Wh * Ww)。
- transpose(1, 2):
- 将第 1 维和第 2 维交换,使得特征维度(embed_dim)和展平后的空间维度(D * Wh * Ww)交换位置。
- 交换后的张量形状变为 (batch_size, D * Wh * Ww, embed_dim)。
- x = self.norm(x)
- 归一化层(如 nn.LayerNorm)通常作用于最后一个维度(即特征维度)。通过上述的 transpose 操作,特征维度已经移到了最后一个维度,因此可以直接应用归一化。
- x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
- transpose(1, 2):
- 将第 1 维和第 2 维再次交换,恢复特征维度和空间维度的原始顺序。
- 交换后的张量形状变为 (batch_size, embed_dim, D * Wh * Ww)。
- view(-1, self.embed_dim, D, Wh, Ww):
- 将展平后的空间维度重新恢复为原始的 3D 形状 (D, Wh, Ww)。
- 最终的张量形状为 (batch_size, embed_dim, D, Wh, Ww)。
论文下载
Attention Is All You Need
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
SwinUNETR-V2: Stronger Swin Transformers with Stagewise Convolutions for 3D Medical Image Segmentation