TNT (Transformer in Transformer) 论文代码简单复现:从理论到实践的完整实现

引言

Vision Transformer (ViT) 自2020年提出以来,在计算机视觉领域掀起了一场革命。然而,传统ViT在处理图像时将patch作为基本单位,这种粗粒度的处理方式可能会丢失重要的局部细节信息。为了解决这一问题,韩凯等人在2021年NeurIPS会议上提出了Transformer in Transformer (TNT) 架构,通过引入"视觉句子"和"视觉词汇"的概念,实现了更加精细的图像特征建模。

TNT的核心思想是将传统ViT中的patch(视觉句子)进一步细分为更小的sub-patch(视觉词汇),并通过嵌套的Transformer结构同时建模局部和全局的依赖关系。这种设计不仅保持了Transformer的强大表达能力,还能够捕获更丰富的图像细节信息。

本文将详细介绍TNT论文的代码复现过程,包括模型架构的实现、训练配置的优化以及在CIFAR-10数据集上的实验结果。通过完整的代码实现和实验验证,我们将深入理解TNT架构的工作原理和实际效果。

TNT模型架构深度解析

核心设计理念

TNT架构的设计灵感来源于自然语言处理中的句子-词汇层次结构。在传统的ViT中,图像被分割成固定大小的patch,每个patch被视为一个token。而TNT进一步将每个patch(视觉句子)细分为多个sub-patch(视觉词汇),从而实现了更细粒度的特征建模。

这种层次化的设计带来了两个主要优势:首先,内层Transformer能够建模patch内部的局部依赖关系,捕获细粒度的纹理和结构信息;其次,外层Transformer继续处理patch级别的全局依赖关系,保持了对整体图像结构的理解能力。

双层Transformer架构

TNT的核心是双层Transformer架构,包含内层Transformer(Inner Transformer)和外层Transformer(Outer Transformer)。内层Transformer专门处理视觉词汇之间的关系,而外层Transformer处理视觉句子之间的关系。

class TNT(nn.Module):
    def __init__(self, *, image_size, patch_dim, pixel_dim, patch_size, 
                 pixel_size, depth, num_classes, channels=3, heads=8, 
                 dim_head=64, ff_dropout=0., attn_dropout=0.):
        super().__init__()
        assert divisible_by(image_size, patch_size), 'image size must be divisible by patch size'
        assert divisible_by(patch_size, pixel_size), 'patch size must be divisible by pixel size'

        num_patch_tokens = (image_size // patch_size) ** 2
        self.image_size = image_size
        self.patch_size = patch_size
        
        # 初始化patch tokens (视觉句子)
        self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
        
        # 像素级token提取器
        self.to_pixel_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2', p1=patch_size, p2=patch_size),
            nn.Unfold(kernel_size=pixel_size, stride=pixel_size, padding=0),
            Rearrange('... c n -> ... n c'),
            nn.Linear(channels * pixel_size ** 2, pixel_dim)
        )

在这个实现中,to_pixel_tokens模块负责将输入图像转换为像素级的token序列。首先使用Rearrange操作将图像重新组织为patch的形式,然后通过Unfold操作进一步细分为sub-patch,最后通过线性层映射到像素维度的特征空间。

前向传播机制

TNT的前向传播过程体现了其双层架构的精妙设计。在每一层中,首先对像素级token进行自注意力计算和前馈网络处理,然后将处理后的像素级特征聚合到patch级别,最后对patch级token进行自注意力计算。

def forward(self, x):
    b, _, h, w, patch_size, image_size = *x.shape, self.patch_
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

智算菩萨

欢迎阅读最新融合AI编程内容

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值