Swin Transformer之PatchMerging原理及源码

本文详细介绍了Transformer中Patch Merging层的工作原理和实现方式,该层用于下采样,减少分辨率,调整通道数,并在保持信息完整性的前提下节省计算量。与传统的池化操作不同,Patch Merging通过拼接相邻的patch来实现降采样,最终通过线性层减少通道数。源代码展示了如何在PyTorch中实现这一过程。

1.图示

 2.原理

Patch Merging层进行下采样。该模块的作用是做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。

在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。

patch Merging是一个类似于池化的操作,但是比Pooling操作复杂一些。池化会损失信息,patch Merging不会。

每次降采样是两倍,因此在行方向和列方向上,按位置间隔2选取元素,拼成新的patch,再把所有patch都concat起来作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。

3.源码

import torch
import torch.nn as nn
import math
import numpy as np


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()

     
### Swin TransformerPatch Merging 的工作原理 #### 工作原理概述 Patch Merging 层作为 Swin Transformer 架构中的重要组成部分,负责执行下采样的操作。具体来说,这一层会将相邻的多个 patch 合并成一个新的较大尺寸的 patch,从而减少特征图的空间维度,并增加通道数量[^3]。 这种设计不仅有助于构建更深层次的网络结构,还能够有效地降低后续计算成本,使得模型可以在不同尺度上捕捉到更加丰富的空间信息[^4]。 #### 实现细节 为了更好地理解其实现过程,下面给出 Python 伪代码来展示如何在一个二维输入张量 `x` 上应用 Patch Merging: ```python import torch.nn as nn import torch class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. Attributes: reduction (nn.Linear): Linear layer to reduce the number of features after merging patches. norm (nn.LayerNorm): Normalization layer applied before linear transformation. Methods: forward(x): Forward pass through this module, where x is expected to have shape [B,H,W,C]. Returns output tensor with reduced spatial dimensions and increased channel count. """ def __init__(self, input_resolution, dim): super().__init__() self.input_resolution = input_resolution self.dim = dim # Define a linear projection that reduces dimensionality by half while doubling channels self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = nn.LayerNorm(4 * dim) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) # Reshape into non-overlapping windows; each window contains four consecutive elements from original sequence x0 = x[:, 0::2, 0::2, :] # top-left corner of every block x1 = x[:, 1::2, 0::2, :] # bottom-left corner of every block x2 = x[:, 0::2, 1::2, :] # top-right corner of every block x3 = x[:, 1::2, 1::2, :] # bottom-right corner of every block x = torch.cat([x0, x1, x2, x3], -1) # concatenate all parts together along last axis x = x.view(B, -1, 4 * C) # reshape back to batched sequences x = self.norm(x) x = self.reduction(x) return x ``` 此段代码定义了一个名为 `PatchMerging` 的类,它接收两个参数:输入分辨率 (`input_resolution`) 和通道数 (`dim`)。在前向传播过程中,该函数首先验证输入数据是否符合预期大小;接着按照特定模式重组像素点位置关系,最后利用线性变换与标准化处理完成整个流程。
评论 7
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值