摘要
贡献
- 从低级特征中提取patch
- 使用LeFF,促进空间维度上相邻token之间的关系
- Layer-wise Class token Attention (LCA)
介绍
CNN
- 提取低级特征,局部性
- 平移不变性与权重共享机制有关,它可以捕获有关视觉任务中的几何和拓扑结构的信息
- 局部性,相邻像素之间是相关的
transformer
- 很难提取图像中的低级特征。
- 自注意力模块专注于构建令牌之间的长期依赖关系,而忽略了空间维度上的局部性。
解决
- Image-to-Tokens (I2T) 模块,该模块从生成的低级特征中提取patch
- LeFF层,该层提高了空间维度上相邻令牌之间的相关性
方法
I2T
I2T 充分利用了 CNN 在提取低级特征方面的优势,并通过缩小patch大小来降低嵌入的训练难度。
LeFF
使用深度卷积来增强tokens在空间维度上的相关性。
code
I2T
# IoT
# I2T 模块是一个轻量级的 stem,由卷积层和最大池化层组成。
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4),
nn.BatchNorm2d(out_channels),
nn.MaxPool2d(pool_kernel, stride)
)
feature_size = image_size // 4
assert feature_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (feature_size // patch_size) ** 2
patch_dim = out_channels * patch_size ** 2
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
nn.Linear(patch_dim, dim),
)
LeFF
深度卷积来增强相邻tokens之间的空间相关性
class LeFF(nn.Module):
def __init__(self, dim=192, scale=4, depth_kernel=3):
super().__init__()
scale_dim = dim * scale
self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
Rearrange('b n c -> b c n'),
nn.BatchNorm1d(scale_dim),
nn.GELU(),
Rearrange('b c (h w) -> b c h w', h=14, w=14)
)
self.depth_conv = nn.Sequential(
nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
nn.BatchNorm2d(scale_dim),
nn.GELU(),
Rearrange('b c h w -> b (h w) c', h=14, w=14)
)
self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
Rearrange('b n c -> b c n'),
nn.BatchNorm1d(dim),
nn.GELU(),
Rearrange('b c n -> b n c')
)
def forward(self, x):
x = self.up_proj(x)
x = self.depth_conv(x)
x = self.down_proj(x)
return x