【Python/Pytorch - 网络模型】-- 手把手搭建Swin-Transformer模型 - WindowAttention3D模型

在这里插入图片描述
文章目录

写在前面

根据博主文章,学习Transformer、Vit、SwinTransformer、SwinUnetr原理,原理博主文章已讲解较为详细,本文主要从代码角度学习各个模块的原理。

图解Vit 1:Vision Transformer——图像与Transformer基础 - 知乎
图解Vit 2:Vision Transformer——视觉问题中的注意力机制 - 知乎
图解Vit 3:Vision Transformer——ViT模型全流程拆解(Layer Normalization, Position Embedding) - 知乎
图解+代码 Swin Transformer 1: W-MSA和Patch Merging - 知乎
图解+代码 Swin Transformer 2: SW-MSA(Shifted Window Multi-head Self Attention) - 知乎

WindowAttention3D

class WindowAttention3D(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The temporal length, height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1

        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, N, N) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(
            N, N, -1)  # Wd*Wh*Ww,Wd*Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

相对位置编码理解

有关swin transformer相对位置编码的理解:_swin transformer 相对编码-优快云博客

 coords = torch.stack(torch.meshgrid(coords_h, coords_w))  # 2, Wh, Ww
  • torch.stack(…):torch.stack 将生成的两个二维数组堆叠成一个三维张量。堆叠的维度是新的维度(默认为第0维)。因此,coords 的形状为 (2, Wh, Ww),其中:
  • 第一个维度(大小为2)分别表示行坐标和列坐标。
  • 第二个维度(大小为Wh)表示网格的高度。
  • 第三个维度(大小为Ww)表示网格的宽度。
  • torch.flatten(coords, 1):torch.flatten 函数用于将张量的指定维度展平。这里的参数 1 表示从第1维开始展平(即保留第0维)。因此,coords 的形状 (2, Wh, Ww) 会被展平为 (2, Wh * Ww)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    1. coords_flatten[:, :, None] 和 coords_flatten[:, None, :]
  • coords_flatten是一个二维张量,形状为 (2, N),其中:
  • 2 表示两个坐标维度(例如,行和列)。
  • N 表示网格中的点数(网格的总像素或单元数)。
  • coords_flatten[:, :, None]:在第三个维度上增加一个新的维度,形状变为 (2, N, 1)。
  • coords_flatten[:, None, :]:在第二个维度上增加一个新的维度,形状变为 (2, 1, N)。
    1. 广播机制
  • 当执行 coords_flatten[:, :, None] - coords_flatten[:, None, :] 时,广播机制会自动扩展两个张量的维度以匹配对方的形状:
  • coords_flatten[:, :, None] 的形状为 (2, N, 1)。
  • coords_flatten[:, None, :] 的形状为 (2, 1, N)。
    通过广播,两个张量的形状会被扩展为 (2, N, N),然后逐元素相减。
  1. 结果解释
  • 最终得到的 relative_coords 是一个三维张量,形状为 (2, N, N),其中:
  • 第一个维度(2):表示两个坐标的差异(例如,行坐标差和列坐标差)。
  • 第二个和第三个维度(N, N):表示网格中每个点与其他所有点的相对位置。
    例如,假设网格中有 4 个点,relative_coords 的形状为 (2, 4, 4)。每个点 (i, j) 的相对坐标差表示点 i 和点 j 之间的位置差异。

self.window_size

  relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
  relative_coords[:, :, 1] += self.window_size[1] - 1
  relative_coords[:, :, 2] += self.window_size[2] - 1
  • 假设窗口大小为 (Wd, Wh, Ww),分别表示窗口在深度、高度和宽度方向的大小。
  • 例如,Wd = 5,则窗口在深度方向上的范围为 0 到 4。
  • relative_coords[:, :, 0] += self.window_size[0] - 1:
  • 对于深度维度(0),将相对坐标差加上 Wd - 1,以确保最小的相对坐标差值为 0。
  • 假设 Wd = 5,则 self.window_size[0] - 1 = 4。
  • 如果相对坐标差为 -2,加上 4 后变为 2。
  • 作用:
  • 确保每个相对坐标差在平移后为非负数,并且范围为 [0, 2 * Wd - 2]。
  • 例如,原始相对坐标差的范围为 [-4, 4](假设窗口大小为 5),平移后变为 [0, 8]。

relative_coords

relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

这两行代码对三维相对坐标差 relative_coords 的前两个维度应用了缩放因子,以便将三维相对坐标转换为一维索引,用于查询相对位置偏置表(relative_position_bias_table)。

  1. 内容分解
    假设相对坐标 relative_coords 的形状为 (N, N, 3),其中 N 是窗口中的元素数量,3 表示三维坐标(深度、高度、宽度)。
    第一行代码:
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
  • relative_coords[:, :, 0]:表示三维相对坐标的第一个维度(深度方向)。

  • 2 * self.window_size[1] - 1:计算窗口在高度方向的扩展范围。如果窗口的高度为 Wh,那么扩展后的高度范围为 2*Wh-1。

  • 2 * self.window_size[2] - 1:计算窗口在宽度方向的扩展范围。如果窗口的宽度为 Ww,那么扩展后的宽度范围为 2*Ww-1。

  • *=:将深度方向的坐标值乘以高度和宽度的扩展范围之积。这相当于为深度方向的坐标值赋予一个较大的权重,使其在后续的一维索引计算中占据更高的基数。
    第二行代码:
    relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)

  • relative_coords[:, :, 1]:表示三维相对坐标的第二个维度(高度方向)。

  • 2 * self.window_size[2] - 1:计算窗口在宽度方向的扩展范围。

  • *=:将高度方向的坐标值乘以宽度方向的扩展范围。这相当于为高度方向的坐标值赋予一个中等权重,使其在后续的一维索引计算中占据中等的基数。

  1. 目的
    通过这两行代码,我们将三维的相对坐标 (dx, dy, dz) 转换为一个唯一的一维索引,以便能够从相对位置偏置表(relative_position_bias_table)中查询对应的偏置值。这类似于将三维坐标 (x, y, z) 转换为一维索引 z + y * Ww + x * Wh * Ww,其中 Ww 和 Wh 是窗口的宽度和高度。

  2. 一维索引计算示例

假设窗口大小为 Wd = 3、Wh = 2、Ww = 2,则:
- 各维度的扩展范围为:

- 深度方向:2*3-1 = 5
- 高度方向:2*2-1 = 3
- 宽度方向:2*2-1 = 3
对于一个相对坐标 (2, 1, 1),按以下方式计算一维索引:
1.深度坐标 dx = 2- 权重为:(2*2-1)*(2*2-1) = 3 * 3 = 9
- 乘积:2 * 9 = 18
2.高度坐标 dy = 1- 权重为:2*2-1 = 3
- 乘积:1 * 3 = 3
3.宽度坐标 dz = 1- 权重为:1(无需缩放)
- 乘积:1 * 1 = 1
总的一维索引为:18 + 3 + 1 = 22

注意力分数计算

  1. 计算注意力分数
q = q * self.scale
attn = q @ k.transpose(-2, -1)
- q * self.scale:
  • 将查询矩阵 q 与缩放因子 self.scale 相乘。self.scale 通常设置为 head_dim ** -0.5,其中 head_dim 是每个注意力头的维度。这个缩放因子是为了稳定注意力分数的数值分布,避免梯度消失或爆炸。

  • k.transpose(-2, -1):

  • 对键矩阵 k 的最后两个维度进行转置。假设 k 的形状为 (B_, N, C),转置后形状变为 (B_, C, N)。

  • q @ k.transpose(-2, -1):

  • 计算查询矩阵 q 和转置后的键矩阵 k 的矩阵乘积,得到原始的注意力分数矩阵 attn。形状为 (B_, N, N),表示每个查询与每个键之间的相似度。

  1. 添加相对位置偏置
relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1)
  • self.relative_position_bias_table:

  • 一个存储相对位置偏置的张量,形状为 (2Wd-1)(2Wh-1)(2*Ww-1), num_heads),其中 Wd、Wh、Ww 是窗口在深度、高度和宽度方向的大小。

  • self.relative_position_index[:N, :N]:

  • 一个二维索引矩阵,形状为 (N, N),其中 N 是窗口中的元素数量。它用于从 self.relative_position_bias_table 中提取对应的偏置值。

  • .reshape(-1):

  • 将索引矩阵展平为一维数组,以便从 self.relative_position_bias_table 中提取偏置值。

  • .reshape(N, N, -1):

  • 将提取的偏置值重新 reshape 为 (N, N, num_heads),其中 num_heads 是注意力头的数量。

  1. 调整维度顺序
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  • permute(2, 0, 1):

  • 调整偏置张量的维度顺序,使其变为 (num_heads, N, N),以便与注意力分数矩阵的形状匹配。

  • contiguous():

  • 确保张量在内存中是连续的,便于后续操作。

  1. 合并注意力分数和相对位置偏置
attn = attn + relative_position_bias.unsqueeze(0)
  • relative_position_bias.unsqueeze(0):

  • 在第0维添加一个新的维度,将偏置张量的形状变为 (1, num_heads, N, N),以匹配注意力分数矩阵 attn 的形状 (B_, num_heads, N, N)。

  • attn + relative_position_bias.unsqueeze(0):

  • 将相对位置偏置加到注意力分数矩阵上,得到最终的注意力矩阵。这一步是为了引入位置信息,使注意力机制能够更好地捕捉不同位置之间的关系。

trunc_normal_(self.relative_position_bias_table, std=.02)

在 PyTorch 中,trunc_normal_ 是一个用于参数初始化的函数,它生成截断正态分布(truncated normal distribution)的随机数。以下是对代码的解释:

trunc_normal_(self.relative_position_bias_table, std=.02)
  • self.relative_position_bias_table:
  • std=.02:std 是截断正态分布的标准差。trunc_normal_ 函数会生成一个在均值附近截断的正态分布。默认情况下,截断范围是均值的两个标准差之外的值被重新采样。

assert

assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size"
assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size"
assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size"

assert 是一个 Python 关键字,用于对表达式进行断言。如果表达式的值为 False,程序会抛出一个 AssertionError,并显示错误信息。

cyclic shift

# cyclic shift
if any(i > 0 for i in shift_size):
    shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
    attn_mask = mask_matrix
else:
    shifted_x = x
    attn_mask = None

这段代码实现了循环平移(cyclic shift)操作,通常用于视觉网络(如 Swin Transformer)中增强模型的空间感受能力。以下是对代码的详细解释:

  1. 循环平移条件判断
if any(i > 0 for i in shift_size):
  • shift_size:一个包含三个整数的数组或元组,表示在深度、高度和宽度方向上的平移尺寸。
  • 条件:如果 shift_size 的任何维度大于 0,则执行循环平移操作。
  1. 循环平移操作
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
  • torch.roll:将张量 x 在指定维度上进行循环平移。
  • shifts:指定每个维度的平移量。例如,(-shift_size[0], -shift_size[1], -shift_size[2]) 表示在深度、高度和宽度方向上分别平移 -shift_size[0]、-shift_size[1] 和 -shift_size[2] 个单位。
  • dims:指定平移的维度。这里 dims=(1, 2, 3) 表示在第 1、2、3 维度上进行平移。
  • 负数是往右移,正数是往左移。
  1. 注意力掩码
attn_mask = mask_matrix
  • 如果执行了循环平移操作,则设置注意力掩码 attn_mask 为 mask_matrix。这个掩码用于在后续的注意力计算中屏蔽掉无效的位置。
  1. 不进行循环平移的情况
else:
    shifted_x = x
    attn_mask = None
  • 如果 shift_size 的所有维度都为 0,则直接使用原始输入张量 x,并且不使用注意力掩码(attn_mask = None)。
    关键点总结
  • 循环平移的作用:通过平移操作,打破局部特征的平移不变性,增强模型的空间感受能力。
  • torch.roll:用于实现张量的循环平移,shifts 参数控制平移量,dims 参数控制平移的维度。
  • attn_mask:用于在注意力计算中屏蔽无效区域,确保模型只关注有效的位置。

论文下载

Attention Is All You Need
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
SwinUNETR-V2: Stronger Swin Transformers with Stagewise Convolutions for 3D Medical Image Segmentation

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

电科_银尘

你的鼓励将是我最大的创作动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值