引言
Vision Transformer (ViT) 自2020年提出以来,在计算机视觉领域掀起了一场革命。然而,传统ViT在处理图像时将patch作为基本单位,这种粗粒度的处理方式可能会丢失重要的局部细节信息。为了解决这一问题,韩凯等人在2021年NeurIPS会议上提出了Transformer in Transformer (TNT) 架构,通过引入"视觉句子"和"视觉词汇"的概念,实现了更加精细的图像特征建模。
TNT的核心思想是将传统ViT中的patch(视觉句子)进一步细分为更小的sub-patch(视觉词汇),并通过嵌套的Transformer结构同时建模局部和全局的依赖关系。这种设计不仅保持了Transformer的强大表达能力,还能够捕获更丰富的图像细节信息。
本文将详细介绍TNT论文的代码复现过程,包括模型架构的实现、训练配置的优化以及在CIFAR-10数据集上的实验结果。通过完整的代码实现和实验验证,我们将深入理解TNT架构的工作原理和实际效果。
TNT模型架构深度解析
核心设计理念
TNT架构的设计灵感来源于自然语言处理中的句子-词汇层次结构。在传统的ViT中,图像被分割成固定大小的patch,每个patch被视为一个token。而TNT进一步将每个patch(视觉句子)细分为多个sub-patch(视觉词汇),从而实现了更细粒度的特征建模。
这种层次化的设计带来了两个主要优势:首先,内层Transformer能够建模patch内部的局部依赖关系,捕获细粒度的纹理和结构信息;其次,外层Transformer继续处理patch级别的全局依赖关系,保持了对整体图像结构的理解能力。
双层Transformer架构
TNT的核心是双层Transformer架构,包含内层Transformer(Inner Transformer)和外层Transformer(Outer Transformer)。内层Transformer专门处理视觉词汇之间的关系,而外层Transformer处理视觉句子之间的关系。
class TNT(nn.Module):
def __init__(self, *, image_size, patch_dim, pixel_dim, patch_size,
pixel_size, depth, num_classes, channels=3, heads=8,
dim_head=64, ff_dropout=0., attn_dropout=0.):
super().__init__()
assert divisible_by(image_size, patch_size), 'image size must be divisible by patch size'
assert divisible_by(patch_size, pixel_size), 'patch size must be divisible by pixel size'
num_patch_tokens = (image_size // patch_size) ** 2
self.image_size = image_size
self.patch_size = patch_size
# 初始化patch tokens (视觉句子)
self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
# 像素级token提取器
self.to_pixel_tokens = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2', p1=patch_size, p2=patch_size),
nn.Unfold(kernel_size=pixel_size, stride=pixel_size, padding=0),
Rearrange('... c n -> ... n c'),
nn.Linear(channels * pixel_size ** 2, pixel_dim)
)
在这个实现中,to_pixel_tokens
模块负责将输入图像转换为像素级的token序列。首先使用Rearrange
操作将图像重新组织为patch的形式,然后通过Unfold
操作进一步细分为sub-patch,最后通过线性层映射到像素维度的特征空间。
前向传播机制
TNT的前向传播过程体现了其双层架构的精妙设计。在每一层中,首先对像素级token进行自注意力计算和前馈网络处理,然后将处理后的像素级特征聚合到patch级别,最后对patch级token进行自注意力计算。
def forward(self, x):
b, _, h, w, patch_size, image_size = *x.shape, self.patch_