从ViT到CvT:pytorch-image-models中的卷积视觉Transformer
你是否在处理图像分类任务时遇到过这些问题:传统卷积神经网络(CNN)难以捕捉长距离依赖关系,而纯Transformer模型(如ViT)又对数据量要求极高且计算成本昂贵?本文将带你深入了解pytorch-image-models库中两种革命性的视觉模型——Vision Transformer(ViT)和CrossViT(CvT),看看它们如何通过创新设计解决这些痛点。读完本文,你将掌握两者的核心原理、结构差异及在实际项目中的应用方法。
ViT:开创视觉Transformer时代
Vision Transformer(ViT)是由Google团队在2020年提出的开创性模型,它首次将Transformer架构成功应用于计算机视觉领域。与传统CNN不同,ViT直接将图像分割为 patches,通过自注意力机制捕捉全局特征。
ViT的核心结构
ViT的核心结构主要包括以下几个部分:
-
Patch Embedding(补丁嵌入):将输入图像分割为固定大小的补丁,并将每个补丁线性投影到高维空间。这一步对应代码中的
PatchEmbed类,位于timm/models/vision_transformer.py文件中。 -
位置嵌入(Positional Embedding):为每个补丁添加位置信息,使模型能够感知图像的空间结构。
-
Transformer编码器:由多个Transformer块组成,每个块包含多头自注意力机制和前馈神经网络。代码中的
Block类实现了这一结构。 -
分类头:通常使用一个额外的类标记(class token)或全局池化操作,将Transformer编码器的输出映射到分类结果。
ViT的代码实现
在pytorch-image-models库中,ViT的实现位于timm/models/vision_transformer.py文件。以下是ViT模型的核心代码片段:
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
# ... 其他参数
) -> None:
super().__init__()
# ... 初始化代码
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
# ... 前向传播代码
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
# ... 分类头代码
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
ViT的创新之处在于将图像视为序列数据,直接应用Transformer架构。然而,纯Transformer架构缺乏CNN固有的归纳偏置(如平移不变性),导致在数据量有限时性能不如CNN。为了解决这一问题,CrossViT应运而生。
CrossViT:融合卷积与Transformer的创新设计
CrossViT(Cross-Attention Multi-Scale Vision Transformer)是对ViT的改进,它通过引入多尺度特征提取和跨注意力机制,有效结合了CNN的局部特征提取能力和Transformer的全局依赖建模能力。
CrossViT的核心改进
-
多尺度补丁嵌入:CrossViT使用不同大小的补丁(如12x12和16x16)提取多尺度特征,这类似于CNN中的多尺度卷积操作。
-
跨注意力融合:通过跨注意力机制(Cross-Attention)融合不同尺度的特征,使模型能够同时捕捉局部细节和全局上下文。
-
并行分支结构:采用并行的Transformer分支处理不同尺度的特征,然后通过融合模块整合结果,提高模型的表达能力。
CrossViT的代码实现
CrossViT的实现位于timm/models/crossvit.py文件。以下是其核心结构的代码片段:
class CrossViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
img_scale=(1.0, 1.0),
patch_size=(8, 16),
in_chans=3,
num_classes=1000,
embed_dim=(192, 384),
depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
num_heads=(6, 12),
mlp_ratio=(2., 2., 4.),
# ... 其他参数
):
super().__init__()
# ... 初始化代码
def forward_features(self, x) -> List[torch.Tensor]:
B = x.shape[0]
xs = []
for i, patch_embed in enumerate(self.patch_embed):
x_ = x
ss = self.img_size_scaled[i]
x_ = scale_image(x_, ss, self.crop_scale)
x_ = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1
cls_tokens = cls_tokens.expand(B, -1, -1)
x_ = torch.cat((cls_tokens, x_), dim=1)
pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1
x_ = x_ + pos_embed
x_ = self.pos_drop(x_)
xs.append(x_)
for i, blk in enumerate(self.blocks):
xs = blk(xs)
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
return xs
def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
xs = [self.head_drop(x) for x in xs]
if pre_logits or isinstance(self.head[0], nn.Identity):
return torch.cat([x for x in xs], dim=1)
return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)
CrossViT的多尺度特征融合
CrossViT通过MultiScaleBlock实现多尺度特征的融合,代码如下:
class MultiScaleBlock(nn.Module):
def __init__(
self,
dim,
patches,
depth,
num_heads,
mlp_ratio,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
# ... 初始化代码
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
outs_b = []
for i, block in enumerate(self.blocks):
outs_b.append(block(x[i]))
proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
for i, proj in enumerate(self.projs):
proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
outs = []
for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
tmp = fusion(tmp)
reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
outs.append(tmp)
return outs
这一结构使CrossViT能够有效整合不同尺度的特征信息,在保持Transformer全局建模能力的同时,增强了对局部细节的捕捉能力。
ViT与CrossViT的性能对比
根据pytorch-image-models库的官方文档,CrossViT在多个图像分类数据集上表现出优于ViT的性能。特别是在数据量有限的情况下,CrossViT的多尺度设计和卷积特征提取能力使其能够更快收敛并取得更高的准确率。
官方提供的模型性能数据可以在hfdocs/source/models.mdx中找到。这些数据显示,CrossViT在ImageNet数据集上的Top-1准确率比同量级的ViT模型高出2-3个百分点,同时计算效率也有所提升。
如何在项目中使用ViT和CrossViT
pytorch-image-models库提供了简单易用的API,使开发者能够轻松使用预训练的ViT和CrossViT模型。以下是使用示例:
import torch
from timm import create_model
# 加载预训练的ViT模型
vit_model = create_model('vit_base_patch16_224', pretrained=True)
vit_model.eval()
# 加载预训练的CrossViT模型
crossvit_model = create_model('crossvit_15_240', pretrained=True)
crossvit_model.eval()
# 准备输入图像
input_tensor = torch.randn(1, 3, 224, 224) # 示例输入,实际使用时需预处理
# 使用ViT进行推理
with torch.no_grad():
vit_output = vit_model(input_tensor)
crossvit_output = crossvit_model(input_tensor)
print("ViT输出形状:", vit_output.shape)
print("CrossViT输出形状:", crossvit_output.shape)
在实际应用中,你可以根据任务需求和计算资源选择合适的模型。如果你的任务需要处理高分辨率图像或对局部细节敏感,CrossViT可能是更好的选择;如果数据量充足且追求极致的全局特征建模,ViT也是一个不错的选择。
总结与展望
从ViT到CrossViT,我们看到了视觉Transformer模型的快速发展。ViT开创了将Transformer应用于计算机视觉的先河,而CrossViT则通过融合卷积和Transformer的优势,进一步提升了模型性能和效率。
随着研究的深入,我们有理由相信未来会出现更多创新的视觉Transformer模型。pytorch-image-models库作为一个活跃的开源项目,将持续整合这些最新研究成果,为开发者提供更多强大的视觉模型选择。
如果你想了解更多关于ViT和CrossViT的技术细节,可以查阅以下资源:
- ViT论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- CrossViT论文:CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification
- pytorch-image-models官方文档:hfdocs/source/models.mdx
希望本文能够帮助你更好地理解和应用这两种先进的视觉Transformer模型。如果你有任何问题或建议,欢迎在项目的GitHub仓库提交issue或PR,让我们一起推动视觉识别技术的发展!
如果你觉得本文对你有帮助,请点赞、收藏并关注我们的项目,以便获取更多关于视觉模型的最新资讯和教程。下期我们将介绍如何使用pytorch-image-models库进行模型微调,敬请期待!
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



