paper:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
official implementation:https://github.com/google-research/vision_transformer
third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
背景
虽然Transformer结构已经在自然语言处理任务中得到了广泛应用,但在计算机视觉任务中的应用仍然有限。在视觉中,注意力要么与卷积网络结合应用,要么用于替换卷积网络的某些组成部分,同时保持原本的整体结构。
本文的创新点
- 直接应用Transformer:与以往将 Transformer 与 CNN 结合或用 Transformer 替代 CNN 的某些部分不同,本文直接将标准的 Transformer 应用于图像,几乎没有修改。
- 图像分块处理:将图像分割成固定大小的patch,并将这些块的线性嵌入序列作为 Transformer 的输入,图像块被视为 NLP 应用中的tokens。
- 大规模预训练:在大规模数据集(如 ImageNet-21k 和 JFT-300M)上进行预训练的 Vision Transformer 在迁移到数据量较少的任务时能够取得优异的结果。
- 计算效率:在预训练阶段,ViT 显示出比 ResNet 等模型更低的计算成本,同时在多个识别基准测试中达到了最先进的性能。
- 简化模型设计:通过尽可能少地修改原始 Transformer 架构,实现了对图像的高效处理,简化了模型设计。
文章通过这些创新点,推动了 Transformer 模型在计算机视觉领域的应用,特别是在大规模图像识别任务中的表现。
方法介绍
Vision Transformer 的整体结构如图1所示。
标准Transformer接收一个一维的token embedding作为输入。为了处理二维图片,我们将图片 \(\mathbf{x}\in \mathbb{R}^{H\times W\times C}\) reshape成一系列展平的二维图像块 \(\mathbf{x}_p\in \mathbb{R}^{N\times (P^{2}\cdot C)}\),其中 \((H,W)\) 是原始图像的分辨率,\(C\) 是通道数,\((P,P)\) 是每个图像块的分辨率,\(N=HW/P^2\) 是图像块的数量同时也作为Transformer的有效输入序列长度。Transformer在其所有层中使用恒定的潜在向量大小 \(D\),因此我们展平图像块并用一个可训练的线性投影(式1)将其投影到 \(D\) 维,我们将这个投影的输出称为patch embeddings。
类似于BERT的[class] token,我们在图像块embedding序列前加上一个可学习的embedding \((\mathbf{z}^0_0=\mathbf{x}_{class})\),它在Transformer编码器输出的状态 \((\mathbf{z}^0_L)\) 作为图像表示 \(\mathbf{y}\)。在预训练和微调时,一个分类head紧跟在 \(\mathbf{z}^0_L\) 后。分类头在预训练时是一个有一个隐藏层的MLP,在微调时是单个线性层。
Position embedding被添加到图像块embedding中以保留位置信息。本文使用可学习的一维postion embedding,因为作者通过实验发现使用更高级的二维position embedding并不会带来性能上的提升。最终得到的embedding vector作为编码器的输入。
Transformer编码器交替使用multihead self-attention(MSA)和MLP blocks(式2, 3)。在每个block之前应用Layernorm(LN),在每个block之后应用residual connection。MLP包含两层和一个非线性激活函数GELU。
微调和更高的分辨率
通常,我们在大型数据集上预训练ViT,并对其进行微调以适应(较小的)下游任务。为此,我们移除预训练的prediction head,并添加一个0初始化的 \(D\times K\) 前向层,其中 \(K\) 是下游任务的类别数量。通常在比预训练更高的分辨率下进行微调是有益的。当输入更高分辨率的图像时,我们保持图像块大小不变,这导致有效序列长度变大。Vision Transformer可以处理任意序列长度(直到内存限制),但是预训练的position embedding可能不再有意义。因此,我们根据它们在原始图像中的位置,对预训练的position embedding进行二维插值。注意,这个分辨率调整和图像块提取是向Vision Transformer中唯一人工注入关于图像二维结构的归纳偏差。
实验结果
ViT的不同变种配置如表1所示。
表2展示了ViT在JFT-300M和ImageNet-21k数据上预训练的结果和其它SOTA方法的对比。
class token是继承了文本Transformer中的设计,另一种方法是使用全局平均池化然后接一个线性分类器,就像在卷积分类网络中做的那样,可是效果很差。作者发现这既不是因为额外的token的作用也不是因为GAP,而是学习率,在调整了学习率后,两者的效果是一样的。
代码解析
这里以timm中的实现为例,介绍一下代码。
第一步是对输入进行分块,代码如下,核心就是通过卷积实现分块以及展平flatten成二维的sequence。卷积kernel_size=stride=patch_size=16,输入大小224x224,则sequence_len=224x224/(16x16)=196。在PatchEmbed后,特征变成 (batch_size, seq_len, embed_dim) 比如 (1, 196, 192)。
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
output_fmt: Format
dynamic_img_pad: torch.jit.Final[bool]
def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
):
super().__init__()
self.patch_size = to_2tuple(patch_size)
if img_size is not None:
self.img_size = to_2tuple(img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
if as_scalar:
return max(self.patch_size)
else:
return self.patch_size
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
""" Get grid (feature) size for given image size taking account of dynamic padding.
NOTE: must be torchscript compatible so using fixed tuple indexing
"""
if self.dynamic_img_pad:
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
else:
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
def forward(self, x):
B, C, H, W = x.shape
if self.img_size is not None:
if self.strict_img_size:
_assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
elif not self.dynamic_img_pad:
_assert(
H % self.patch_size[0] == 0,
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
)
_assert(
W % self.patch_size[1] == 0,
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
)
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = F.pad(x, (0, pad_w, 0, pad_h))
x = self.proj(x) # (1,192,14,14)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
# (1,196,192), 将从第2维开始到最后一维的所有维度展平成一个单一的维度
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
return x
第二步就是添加class token和postion embedding。代码如下,其中cls_token和pos_embed都是通过nn.Parameter定义的可学习的参数,每张图片都有自己独立的一个class token,将class token拼接到上一步分块并展平后的特征embedding前,然后再与pos_embed相加。
- 添加cls_token。维度为 (batch_size, 1, embed_dim) 比如 (1, 1, 192)。这是一个常见的操作,特别是在 Transformer 模型中,其中 cls_token 用作分类标记(classification token),并与输入序列一起处理。通过扩展 cls_token,确保每个输入序列都有一个对应的分类标记,这样可以在后续的 Transformer 层中一起处理。后续是将cls_token拼接到特征前面,进行后续的操作。最后的classification head是接在cls_token后面的,即特征的第0维。
- 位置编码。learnable 1D position embedding,维度是 (1, seq_len, embed_dim),初始化时乘以0.02缩小初始化权重的范围。这是为了避免初始值过大,从而有助于训练的稳定性。此外,pos_embed是与特征相加,而不是像cls_token与特征拼接。另外,cls_token需要expand和batch_size大小,即每张图片都有一个单独的cls_token,而pos_embed是共享的,不用expand到batch_size。
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) # 缩小初始化权重的范围。这是为了避免初始值过大,从而有助于训练的稳定性。
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.dynamic_img_size: # False
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
self.pos_embed,
(H, W),
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
)
x = x.view(B, -1, C)
else:
pos_embed = self.pos_embed
to_cat = []
if self.cls_token is not None: # not None
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
# 具体来说,cls_token.expand(x.shape[0], -1, -1) 扩展了 cls_token,使其第一个维度与输入张量x的第一个维度大小相同(通常是批次大小)。
# 这样一来,每个样本都将有一个独立的cls_token。举例来说,如果x的形状是(batch_size, seq_len, embed_dim),那么扩展后的cls_token的形状将是 (batch_size, 1, embed_dim)。
# 这是一个常见的操作,特别是在 Transformer 模型中,其中 cls_token 用作分类标记(classification token),并与输入序列一起处理。通过扩展 cls_token,确保每个输入序列都有一个对应的分类标记,这样可以在后续的 Transformer 层中一起处理。
if self.reg_token is not None: # None
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
if self.no_embed_class: # False
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if to_cat:
x = torch.cat(to_cat + [x], dim=1) # [(1,1,92)] + [(1,196,192)] -> (1,197,192)
x = x + pos_embed # (1,197,192)
return self.pos_drop(x)
第三步就是经过若干个block,代码如下,图1右侧是单个block的示例。forward函数中norm1和norm2都是LayerNorm,ls1和ls2都是nn.Identity,drop_path1和drop_path2也都是nn.Identity,核心实现就是一个attention加一个mlp。
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
init_values: Optional[float] = None,
drop_path: float = 0.,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.LayerNorm,
mlp_layer: nn.Module = Mlp,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() # None
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() # 0
self.norm2 = norm_layer(dim)
self.mlp = mlp_layer(
in_features=dim, # 192
hidden_features=int(dim * mlp_ratio), # 4
act_layer=act_layer, # GELU
drop=proj_drop, # 0
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
我们首先来看attention的实现,代码如下。首先self.qkv是一个线性矩阵操作,相当于三个不同的映射矩阵分别与输入相乘得到q、k、v。然后从L42-L46是计算attention的过程,对应式 \(attention=softmax(\frac{q\cdot k^T}{\sqrt{d_k}})\)。最后再经过一个线性层self.proj就得到了attention的输出。
class Attention(nn.Module):
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads # 3
self.head_dim = dim // num_heads # 192/3=64
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() # False
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop) # 0
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop) # 0
def forward(self, x: torch.Tensor) -> torch.Tensor: # (1,197,192)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
# (1,197,576)->(1,197,3,3,64)->(3,1,3,197,64), (3, batch_size, num_heads, seq_len, head_dim), 3表示qkv
q, k, v = qkv.unbind(0) # (1,3,197,64)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn: # False
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
# attn=softmax(qk)
q = q * self.scale
attn = q @ k.transpose(-2, -1) # (1,3,197,197)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # (1,3,197,64)
x = x.transpose(1, 2).reshape(B, N, C) # (1,197,3,64)->(1,197,192)
x = self.proj(x) # (1,197,192)
x = self.proj_drop(x)
return x
mlp就更简单了,就是两个全连接层。
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
最后就是取出class token,用一个线性层映射到num_classes维度。
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
x = x[:, 0] # class token
x = self.head(x)