1.图示

2.原理
Patch Merging层进行下采样。该模块的作用是做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。
patch Merging是一个类似于池化的操作,但是比Pooling操作复杂一些。池化会损失信息,patch Merging不会。
每次降采样是两倍,因此在行方向和列方向上,按位置间隔2选取元素,拼成新的patch,再把所有patch都concat起来作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。
3.源码
import torch
import torch.nn as nn
import math
import numpy as np
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()

本文详细介绍了Transformer中Patch Merging层的工作原理和实现方式,该层用于下采样,减少分辨率,调整通道数,并在保持信息完整性的前提下节省计算量。与传统的池化操作不同,Patch Merging通过拼接相邻的patch来实现降采样,最终通过线性层减少通道数。源代码展示了如何在PyTorch中实现这一过程。
最低0.47元/天 解锁文章
4612





