从MobileViT到EfficientViT:pytorch-image-models中的移动Transformer
在移动设备上部署高性能视觉模型一直是开发者面临的挑战。传统的卷积神经网络(CNN)虽然在移动设备上表现出良好的效率,但在复杂视觉任务中往往难以与Transformer模型抗衡。而Transformer模型虽然性能强大,但计算复杂度和内存占用较高,难以直接应用于移动场景。
pytorch-image-models(简称timm)库中集成了多种专为移动设备设计的高效Transformer模型,其中MobileViT和EfficientViT系列尤为引人注目。这些模型通过创新的网络结构设计,在保持高性能的同时大幅降低了计算成本,为移动视觉应用开辟了新的可能性。
MobileViT:CNN与Transformer的完美融合
MobileViT是由Apple团队提出的一种新型视觉Transformer模型,旨在弥合CNN和Transformer之间的性能差距,同时保持移动设备友好的特性。
MobileViT的核心创新
MobileViT的核心思想是将CNN的局部特征提取能力与Transformer的全局建模能力相结合。它通过以下关键组件实现这一目标:
-
局部表示模块:使用深度可分离卷积(DSConv)提取局部特征,这与MobileNet等高效CNN模型的设计理念一致。
-
全局表示模块:将特征图分割为非重叠补丁,然后通过Transformer处理这些补丁,以捕获长距离依赖关系。
-
融合模块:将局部和全局特征进行融合,以充分利用两种表示的优势。
MobileViT的网络结构
MobileViT的网络结构可以分为几个主要部分:
- ** stem模块 **:使用3x3卷积层对输入图像进行初始特征提取。
2.** 多个MobileViT块 **:每个块由一个倒残差块(Inverted Residual Block)和一个MobileViT块组成。倒残差块用于局部特征提取,而MobileViT块则用于全局特征建模。
3.** 分类头 **:使用全局平均池化和全连接层进行最终分类。
MobileViT的具体实现可以在timm/models/mobilevit.py中找到。以下是MobileViT块的核心代码:
class MobileVitBlock(nn.Module):
""" MobileViT block
Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
"""
def __init__(
self,
in_chs: int,
out_chs: Optional[int] = None,
kernel_size: int = 3,
stride: int = 1,
bottle_ratio: float = 1.0,
group_size: Optional[int] = None,
dilation: Tuple[int, int] = (1, 1),
mlp_ratio: float = 2.0,
transformer_dim: Optional[int] = None,
transformer_depth: int = 2,
patch_size: int = 8,
num_heads: int = 4,
attn_drop: float = 0.,
drop: int = 0.,
no_fusion: bool = False,
drop_path_rate: float = 0.,
layers: LayerFn = None,
transformer_norm_layer: Callable = nn.LayerNorm,** kwargs, # eat unused args
):
super(MobileVitBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
out_chs = out_chs or in_chs
transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)
self.conv_kxk = layers.conv_norm_act(
in_chs, in_chs, kernel_size=kernel_size,
stride=stride, groups=groups, dilation=dilation[0])
self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)
self.transformer = nn.Sequential(*[
TransformerBlock(
transformer_dim,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
qkv_bias=True,
attn_drop=attn_drop,
proj_drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer,
)
for _ in range(transformer_depth)
])
self.norm = transformer_norm_layer(transformer_dim)
self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1)
if no_fusion:
self.conv_fusion = None
else:
self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1)
self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1]
MobileViT的前向传播
MobileViT的前向传播过程体现了其融合CNN和Transformer的核心思想:
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
# Local representation
x = self.conv_kxk(x)
x = self.conv_1x1(x)
# Unfold (feature map -> patches)
patch_h, patch_w = self.patch_size
B, C, H, W = x.shape
new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w
num_patches = num_patch_h * num_patch_w # N
interpolate = False
if new_h != H or new_w != W:
# Note: Padding can be done, but then it needs to be handled in attention function.
x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
interpolate = True
# [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)
# [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w
x = x.reshape(B * C * num_patch_h, num_patch_w * patch_h * patch_w).transpose(1, 2).reshape(B * self.patch_area, num_patches, -1)
# Global representations
x = self.transformer(x)
x = self.norm(x)
# Fold (patch -> feature map)
# [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
x = x.contiguous().view(B, self.patch_area, num_patches, -1)
x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)
# [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
if interpolate:
x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)
x = self.conv_proj(x)
if self.conv_fusion is not None:
x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
return x
在forward方法中,输入特征首先通过卷积层进行局部特征提取,然后被重塑为补丁序列并送入Transformer模块进行全局特征建模。处理后的特征被折叠回原始空间维度,并与 shortcut 连接进行特征融合。
MobileViT的改进版本:MobileViT v2
MobileViT v2是MobileViT的改进版本,主要引入了线性注意力机制(Linear Self-Attention)来进一步提高效率。这一改进使得模型在保持性能的同时,计算复杂度从O(n²)降低到O(n),其中n是序列长度。
线性注意力的核心实现如下:
class LinearSelfAttention(nn.Module):
"""
This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
This layer can be used for self- as well as cross-attention.
"""
def __init__(
self,
embed_dim: int,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
self.embed_dim = embed_dim
self.qkv_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=1 + (2 * embed_dim),
bias=bias,
kernel_size=1,
)
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = nn.Conv2d(
in_channels=embed_dim,
out_channels=embed_dim,
bias=bias,
kernel_size=1,
)
self.out_drop = nn.Dropout(proj_drop)
def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
# [B, C, P, N] --> [B, h + 2d, P, N]
qkv = self.qkv_proj(x)
# Project x into query, key and value
# Query --> [B, 1, P, N]
# value, key --> [B, d, P, N]
query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
# apply softmax along N dimension
context_scores = F.softmax(query, dim=-1)
context_scores = self.attn_drop(context_scores)
# Compute context vector
# [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)
# combine context vector with values
# [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
out = F.relu(value) * context_vector.expand_as(value)
out = self.out_proj(out)
out = self.out_drop(out)
return out
通过这种方式,MobileViT v2能够在保持全局建模能力的同时,显著降低计算复杂度,使其更适合在移动设备上部署。
EfficientViT:迈向更高效率的视觉Transformer
继MobileViT之后,EfficientViT系列进一步推动了移动视觉Transformer的发展。timm库中包含了来自MIT和微软亚洲研究院(MSRA)的两种EfficientViT实现,它们各自采用了不同的优化策略。
MIT的EfficientViT:轻量级多尺度线性注意力
MIT提出的EfficientViT(在timm中对应efficientvit_mit.py)引入了轻量级多尺度线性注意力(LiteMLA)机制,旨在在保持高性能的同时进一步提高效率。
LiteMLA的核心思想
LiteMLA的核心创新在于:
-
多尺度特征聚合:通过不同尺度的深度卷积对查询、键和值进行处理,以捕获多尺度上下文信息。
-
线性注意力:采用线性复杂度的注意力机制,避免了标准Transformer中的二次复杂度。
-
高效投影:使用1x1卷积对注意力输出进行投影,以融合不同尺度的特征。
以下是LiteMLA的核心实现代码:
class LiteMLA(nn.Module):
"""Lightweight multi-scale linear attention"""
def __init__(
self,
in_channels: int,
out_channels: int,
heads: int or None = None,
heads_ratio: float = 1.0,
dim=8,
use_bias=False,
norm_layer=(None, nn.BatchNorm2d),
act_layer=(None, None),
kernel_func=nn.ReLU,
scales=(5,),
eps=1e-5,
):
super(LiteMLA, self).__init__()
self.eps = eps
heads = heads or int(in_channels // dim * heads_ratio)
total_dim = heads * dim
use_bias = val2tuple(use_bias, 2)
norm_layer = val2tuple(norm_layer, 2)
act_layer = val2tuple(act_layer, 2)
self.dim = dim
self.qkv = ConvNormAct(
in_channels,
3 * total_dim,
1,
bias=use_bias[0],
norm_layer=norm_layer[0],
act_layer=act_layer[0],
)
self.aggreg = nn.ModuleList([
nn.Sequential(
nn.Conv2d(
3 * total_dim,
3 * total_dim,
scale,
padding=get_same_padding(scale),
groups=3 * total_dim,
bias=use_bias[0],
),
nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
)
for scale in scales
])
self.kernel_func = kernel_func(inplace=False)
self.proj = ConvNormAct(
total_dim * (1 + len(scales)),
out_channels,
1,
bias=use_bias[1],
norm_layer=norm_layer[1],
act_layer=act_layer[1],
)
def _attn(self, q, k, v):
dtype = v.dtype
q, k, v = q.float(), k.float(), v.float()
kv = k.transpose(-1, -2) @ v
out = q @ kv
out = out[..., :-1] / (out[..., -1:] + self.eps)
return out.to(dtype)
def forward(self, x):
B, _, H, W = x.shape
# generate multi-scale q, k, v
qkv = self.qkv(x)
multi_scale_qkv = [qkv]
for op in self.aggreg:
multi_scale_qkv.append(op(qkv))
multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
q, k, v = multi_scale_qkv.chunk(3, dim=-1)
# lightweight global attention
q = self.kernel_func(q)
k = self.kernel_func(k)
v = F.pad(v, (0, 1), mode="constant", value=1.)
if not torch.jit.is_scripting():
with torch.autocast(device_type=v.device.type, enabled=False):
out = self._attn(q, k, v)
else:
out = self._attn(q, k, v)
# final projection
out = out.transpose(-1, -2).reshape(B, -1, H, W)
out = self.proj(out)
return out
在EfficientVitBlock中,LiteMLA与MBConv(Mobile Inverted Residual Block)相结合,形成了一个高效的混合结构:
class EfficientVitBlock(nn.Module):
def __init__(
self,
in_channels,
heads_ratio=1.0,
head_dim=32,
expand_ratio=4,
norm_layer=nn.BatchNorm2d,
act_layer=nn.Hardswish,
):
super(EfficientVitBlock, self).__init__()
self.context_module = ResidualBlock(
LiteMLA(
in_channels=in_channels,
out_channels=in_channels,
heads_ratio=heads_ratio,
dim=head_dim,
norm_layer=(None, norm_layer),
),
nn.Identity(),
)
self.local_module = ResidualBlock(
MBConv(
in_channels=in_channels,
out_channels=in_channels,
expand_ratio=expand_ratio,
use_bias=(True, True, False),
norm_layer=(None, None, norm_layer),
act_layer=(act_layer, act_layer, None),
),
nn.Identity(),
)
def forward(self, x):
x = self.context_module(x)
x = self.local_module(x)
return x
这种结构设计使得EfficientViT能够在保持高性能的同时,显著降低计算复杂度和内存占用,非常适合移动设备部署。
MSRA的EfficientViT:级联组注意力机制
微软亚洲研究院提出的EfficientViT(在timm中对应efficientvit_msra.py)则采用了另一种创新思路——级联组注意力(Cascaded Group Attention)机制。
级联组注意力的核心思想
级联组注意力的主要创新点包括:
-
分组注意力:将输入特征分成多个组,每个组独立进行注意力计算,以降低计算复杂度。
-
级联融合:不同组的注意力输出进行级联融合,以捕获跨组的依赖关系。
-
局部窗口注意力:将特征图分割为局部窗口,在每个窗口内进行注意力计算,进一步降低计算量。
以下是级联组注意力的核心实现代码:
class CascadedGroupAttention(torch.nn.Module):
attention_bias_cache: Dict[str, torch.Tensor]
r""" Cascaded Group Attention.
Args:
dim (int): Number of input channels.
key_dim (int): The dimension for query and key.
num_heads (int): Number of attention heads.
attn_ratio (int): Multiplier for the query dim for value dimension.
resolution (int): Input resolution, correspond to the window size.
kernels (List[int]): The kernel size of the dw conv on query.
"""
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=14,
kernels=(5, 5, 5, 5),
):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.val_dim = int(attn_ratio * key_dim)
self.attn_ratio = attn_ratio
qkvs = []
dws = []
for i in range(num_heads):
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim))
dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
self.qkvs = torch.nn.ModuleList(qkvs)
self.dws = torch.nn.ModuleList(dws)
self.proj = torch.nn.Sequential(
torch.nn.ReLU(),
ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
)
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
self.attention_bias_cache = {}
def forward(self, x):
B, C, H, W = x.shape
feats_in = x.chunk(len(self.qkvs), dim=1)
feats_out = []
feat = feats_in[0]
attn_bias = self.get_attention_biases(x.device)
for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
if head_idx > 0:
feat = feat + feats_in[head_idx]
feat = qkv(feat)
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
q = dws(q)
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
q = q * self.scale
attn = q.transpose(-2, -1) @ k
attn = attn + attn_bias[head_idx]
attn = attn.softmax(dim=-1)
feat = v @ attn.transpose(-2, -1)
feat = feat.view(B, self.val_dim, H, W)
feats_out.append(feat)
x = self.proj(torch.cat(feats_out, 1))
return x
级联组注意力通过将输入特征分组处理,并在每个组内进行注意力计算,然后融合各组结果,有效降低了计算复杂度。同时,引入位置偏移偏置(attention_biases)来建模空间关系,进一步提升了模型性能。
MobileViT与EfficientViT的对比分析
MobileViT和EfficientViT系列作为移动友好型视觉Transformer的代表,各有其独特的设计理念和优势。
架构设计对比
| 模型 | 核心创新 | 计算复杂度 | 内存占用 | 适用场景 |
|---|---|---|---|---|
| MobileViT | CNN-Transformer混合结构 | 中 | 中 | 对性能和效率有均衡要求的场景 |
| MobileViT v2 | 线性注意力 | 低 | 低 | 资源受限的移动设备 |
| EfficientViT (MIT) | 轻量级多尺度线性注意力 | 低 | 中 | 需要多尺度特征的应用 |
| EfficientViT (MSRA) | 级联组注意力 | 中 | 低 | 对内存敏感的场景 |
性能对比
虽然具体的性能对比需要在统一的基准测试上进行,但根据模型设计和官方报告,我们可以得出以下大致结论:
-
精度:EfficientViT系列在大多数视觉任务上略优于MobileViT,尤其是在高分辨率图像上。
-
速度:MobileViT v2和MIT的EfficientViT在移动设备上的推理速度更快。
-
内存效率:MSRA的EfficientViT在内存使用上更具优势,适合内存受限的设备。
适用场景推荐
-
MobileViT:适合对性能和效率有均衡要求的通用移动视觉应用。
-
MobileViT v2:适合资源极其受限的低端移动设备。
-
EfficientViT (MIT):适合需要处理多尺度特征的应用,如目标检测和语义分割。
-
EfficientViT (MSRA):适合内存受限但对性能有较高要求的场景,如实时视频处理。
如何在pytorch-image-models中使用这些模型
pytorch-image-models库为这些高效移动Transformer模型提供了统一的接口,使得开发者可以轻松地使用和比较不同的模型。
模型加载与推理
以下是使用MobileViT和EfficientViT进行图像分类的示例代码:
import torch
from timm import create_model
from PIL import Image
import torchvision.transforms as transforms
# 加载预训练模型
model_mobilevit = create_model('mobilevit_s', pretrained=True)
model_efficientvit_mit = create_model('efficientvit_b0', pretrained=True)
model_efficientvit_msra = create_model('efficientvit_m0', pretrained=True)
# 设置为评估模式
model_mobilevit.eval()
model_efficientvit_mit.eval()
model_efficientvit_msra.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载并预处理图像
image = Image.open("test_image.jpg")
image = transform(image).unsqueeze(0)
# 推理
with torch.no_grad():
output_mobilevit = model_mobilevit(image)
output_efficientvit_mit = model_efficientvit_mit(image)
output_efficientvit_msra = model_efficientvit_msra(image)
# 获取预测结果
pred_mobilevit = torch.argmax(output_mobilevit, dim=1)
pred_efficientvit_mit = torch.argmax(output_efficientvit_mit, dim=1)
pred_efficientvit_msra = torch.argmax(output_efficientvit_msra, dim=1)
print(f"MobileViT prediction: {pred_mobilevit.item()}")
print(f"EfficientViT (MIT) prediction: {pred_efficientvit_mit.item()}")
print(f"EfficientViT (MSRA) prediction: {pred_efficientvit_msra.item()}")
模型配置与定制
timm库还允许开发者根据需求定制模型配置,例如:
# 定制MobileViT模型
custom_mobilevit = create_model(
'mobilevit_s',
pretrained=True,
num_classes=10, # 修改分类头以适应自定义数据集
drop_rate=0.2, # 调整 dropout 率
img_size=384 # 支持更高分辨率输入
)
# 定制EfficientViT模型
custom_efficientvit = create_model(
'efficientvit_m4',
pretrained=True,
num_classes=100,
drop_path_rate=0.1
)
总结与展望
从MobileViT到EfficientViT,移动视觉Transformer的发展见证了研究者们在效率与性能之间寻求平衡的不懈努力。这些模型通过创新的注意力机制设计、网络结构优化和混合建模策略,不断推动着移动视觉智能的边界。
pytorch-image-models库将这些先进模型集中整合,为开发者提供了便捷的工具来探索和应用这些技术。无论是MobileViT的CNN-Transformer融合思想,还是EfficientViT的高效注意力机制,都为移动视觉应用开辟了新的可能性。
未来,随着模型压缩、量化技术和硬件加速的进一步发展,我们有理由相信移动视觉Transformer将在更多领域得到应用,从智能手机到物联网设备,从增强现实到自动驾驶,为我们的生活带来更多智能和便利。
作为开发者,我们应该密切关注这些模型的发展,根据具体应用场景选择最合适的模型,并通过实践不断优化模型在特定任务上的性能和效率。
官方文档:README.md 模型实现:timm/models/
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



