SMA2:代码实现详解——Image Encoder篇(Hiera章)

SMA2:代码实现详解——Image Encoder篇(Hiera)

写在前面

大家在SMA2:代码实现详解——Image Encoder篇(FpnNeck)下的留言我已收到,感谢大家的支持,后面如果遇到比较难以讲清的部分可能会使用视频的形式。博主最近要准备秋招,更新可能会慢许多,希望大家能谅解。

言归正传,在SMA2:代码实现详解——Image Encoder篇(FpnNeck)中,我们已经知道了SMA2的整体架构,并且介绍了Image Encoder组件中的FpnNeck。这一篇博客我们就来详细介绍Image Encoder的基本骨架backbone——Hiera

Hiera介绍

Hiera是文章Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles中提出的一种分层视觉Transformer架构。它不仅可以处理图像,而且这个架构可以应用于视频。Hiera是一个纯粹的简单分层ViT模型,不存在任何卷积、移位或者十字窗口操作,仅有Transformer结构组件。它比之前跨多个模型大小、领域和任务的工作更快、更准确。
在这里插入图片描述

Hiera与MAE(Masked AutoEncoder)

MAE(Masked AutoEncoder, 掩码自编码器)

图像MAE由论文Masked Autoencoders Are Scalable Vision Learners提出,它表明,MAE是计算机视觉的可扩展自监督学习器。方法非常简单:屏蔽输入图像的随机Patch并重建丢失的像素。它基于两个核心设计。首先,作者开发了一种非对称编码器-解码器架构,其中的编码器仅对Patch的可见子集(没有掩码标记)进行操作,而轻量级解码器可根据潜在表示和掩码标记重建原始图像。作者发现屏蔽高比例的输入图像(例如 75%)会产生一项不简单且有意义的自我监督任务。将这两种设计结合起来能够高效且有效地训练大型模型:加速训练(3 倍或更多)并提高准确性。可扩展方法允许学习泛化良好的高容量模型:例如,在仅使用 ImageNet-1K 数据的方法中,普通 ViT-Huge 模型实现了最佳准确率 (87.8%)。下游任务中的传输性能优于监督预训练,并显示出有希望的扩展行为。

Hiera便使用了MAE的方式进行训练。

Hiera架构

在这里插入图片描述

选择使用像MAE(如图所示)这样的强代理任务(pretext task)来教导模型。 Hiera完全由标准ViT块组成。为了提高效率,在前两个阶段使用“掩模单元”内的局部注意力,其余阶段使用全局注意力(Global Attention)。在每个阶段转换中,Q和跳跃连接的特征通过线性层加倍,空间维度通过2×2最大池池化。

SMA2中Hiera(HieraDet)的实现

class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(self, ...):
        ...
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            ...
            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )
            embed_dim = dim_out
            self.blocks.append(block)

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

首先,Hiera先将图片划分并映射为patch嵌入向量(上述代码35行),然后计算位置信息并相加(代码第39行)。值得注意的是,SMA2在实现Hiera中位置嵌入时,参照了Window Attention is Bugged: How not to Interpolate Position Embeddings一文,他们发现在使用窗口注意力的同时插值位置嵌入是错误的。Hiera和ViTDet两者确实都存在此错误。于是作者提出了一种简单的绝对窗口位置嵌入策略,它彻底解决了Hiera中的错误,并提高了ViTDet中模型的速度和性能。

代码的41-48行实际上就是Hiera主体ViT块的处理,值得关注的只有带有Q pooling的ViT块,这是在MultiScaleBlock中实现的。

class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int):  embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x

PatchEmbed模块将图片的形状(B,C,H,W)转化为更常见的适用于Transformer处理的形状(B, H, W, C),因为后面经过VIT块时会要求(B,L,C)的形式。实际上,这个模块的卷积映射继承了ViT的做法,直接利用了卷积的特性,通过指定Kernel_size与strides隐式划分了窗口,并且完成了线性变换得到patch enmbedding

值得注意的是位置嵌入的计算:

class Hiera(nn.Module):
 
    def __init__(...)
        super().__init__()
        ...
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值