从MobileViT到EfficientViT:pytorch-image-models中的移动Transformer

从MobileViT到EfficientViT:pytorch-image-models中的移动Transformer

【免费下载链接】pytorch-image-models huggingface/pytorch-image-models: 是一个由 Hugging Face 开发维护的 PyTorch 视觉模型库,包含多个高性能的预训练模型,适用于图像识别、分类等视觉任务。 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

在移动设备上部署高性能视觉模型一直是开发者面临的挑战。传统的卷积神经网络(CNN)虽然在移动设备上表现出良好的效率,但在复杂视觉任务中往往难以与Transformer模型抗衡。而Transformer模型虽然性能强大,但计算复杂度和内存占用较高,难以直接应用于移动场景。

pytorch-image-models(简称timm)库中集成了多种专为移动设备设计的高效Transformer模型,其中MobileViT和EfficientViT系列尤为引人注目。这些模型通过创新的网络结构设计,在保持高性能的同时大幅降低了计算成本,为移动视觉应用开辟了新的可能性。

MobileViT:CNN与Transformer的完美融合

MobileViT是由Apple团队提出的一种新型视觉Transformer模型,旨在弥合CNN和Transformer之间的性能差距,同时保持移动设备友好的特性。

MobileViT的核心创新

MobileViT的核心思想是将CNN的局部特征提取能力与Transformer的全局建模能力相结合。它通过以下关键组件实现这一目标:

  1. 局部表示模块:使用深度可分离卷积(DSConv)提取局部特征,这与MobileNet等高效CNN模型的设计理念一致。

  2. 全局表示模块:将特征图分割为非重叠补丁,然后通过Transformer处理这些补丁,以捕获长距离依赖关系。

  3. 融合模块:将局部和全局特征进行融合,以充分利用两种表示的优势。

MobileViT的网络结构

MobileViT的网络结构可以分为几个主要部分:

  1. ** stem模块 **:使用3x3卷积层对输入图像进行初始特征提取。

2.** 多个MobileViT块 **:每个块由一个倒残差块(Inverted Residual Block)和一个MobileViT块组成。倒残差块用于局部特征提取,而MobileViT块则用于全局特征建模。

3.** 分类头 **:使用全局平均池化和全连接层进行最终分类。

MobileViT的具体实现可以在timm/models/mobilevit.py中找到。以下是MobileViT块的核心代码:

class MobileVitBlock(nn.Module):
    """ MobileViT block
        Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 3,
            stride: int = 1,
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            num_heads: int = 4,
            attn_drop: float = 0.,
            drop: int = 0.,
            no_fusion: bool = False,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Callable = nn.LayerNorm,** kwargs,  # eat unused args
    ):
        super(MobileVitBlock, self).__init__()

        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        self.conv_kxk = layers.conv_norm_act(
            in_chs, in_chs, kernel_size=kernel_size,
            stride=stride, groups=groups, dilation=dilation[0])
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)

        self.transformer = nn.Sequential(*[
            TransformerBlock(
                transformer_dim,
                mlp_ratio=mlp_ratio,
                num_heads=num_heads,
                qkv_bias=True,
                attn_drop=attn_drop,
                proj_drop=drop,
                drop_path=drop_path_rate,
                act_layer=layers.act,
                norm_layer=transformer_norm_layer,
            )
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim)

        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1)

        if no_fusion:
            self.conv_fusion = None
        else:
            self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

MobileViT的前向传播

MobileViT的前向传播过程体现了其融合CNN和Transformer的核心思想:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    shortcut = x

    # Local representation
    x = self.conv_kxk(x)
    x = self.conv_1x1(x)

    # Unfold (feature map -> patches)
    patch_h, patch_w = self.patch_size
    B, C, H, W = x.shape
    new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
    num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
    num_patches = num_patch_h * num_patch_w  # N
    interpolate = False
    if new_h != H or new_w != W:
        # Note: Padding can be done, but then it needs to be handled in attention function.
        x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
        interpolate = True

    # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
    x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)
    # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(B * C * num_patch_h, num_patch_w * patch_h * patch_w).transpose(1, 2).reshape(B * self.patch_area, num_patches, -1)

    # Global representations
    x = self.transformer(x)
    x = self.norm(x)

    # Fold (patch -> feature map)
    # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
    x = x.contiguous().view(B, self.patch_area, num_patches, -1)
    x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)
    # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
    x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
    if interpolate:
        x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)

    x = self.conv_proj(x)
    if self.conv_fusion is not None:
        x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
    return x

在forward方法中,输入特征首先通过卷积层进行局部特征提取,然后被重塑为补丁序列并送入Transformer模块进行全局特征建模。处理后的特征被折叠回原始空间维度,并与 shortcut 连接进行特征融合。

MobileViT的改进版本:MobileViT v2

MobileViT v2是MobileViT的改进版本,主要引入了线性注意力机制(Linear Self-Attention)来进一步提高效率。这一改进使得模型在保持性能的同时,计算复杂度从O(n²)降低到O(n),其中n是序列长度。

线性注意力的核心实现如下:

class LinearSelfAttention(nn.Module):
    """
    This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
    This layer can be used for self- as well as cross-attention.
    """

    def __init__(
        self,
        embed_dim: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        self.qkv_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=bias,
            kernel_size=1,
        )
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=bias,
            kernel_size=1,
        )
        self.out_drop = nn.Dropout(proj_drop)

    def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
        # [B, C, P, N] --> [B, h + 2d, P, N]
        qkv = self.qkv_proj(x)

        # Project x into query, key and value
        # Query --> [B, 1, P, N]
        # value, key --> [B, d, P, N]
        query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)

        # apply softmax along N dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_drop(context_scores)

        # Compute context vector
        # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
        context_vector = (key * context_scores).sum(dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        out = self.out_drop(out)
        return out

通过这种方式,MobileViT v2能够在保持全局建模能力的同时,显著降低计算复杂度,使其更适合在移动设备上部署。

EfficientViT:迈向更高效率的视觉Transformer

继MobileViT之后,EfficientViT系列进一步推动了移动视觉Transformer的发展。timm库中包含了来自MIT和微软亚洲研究院(MSRA)的两种EfficientViT实现,它们各自采用了不同的优化策略。

MIT的EfficientViT:轻量级多尺度线性注意力

MIT提出的EfficientViT(在timm中对应efficientvit_mit.py)引入了轻量级多尺度线性注意力(LiteMLA)机制,旨在在保持高性能的同时进一步提高效率。

LiteMLA的核心思想

LiteMLA的核心创新在于:

  1. 多尺度特征聚合:通过不同尺度的深度卷积对查询、键和值进行处理,以捕获多尺度上下文信息。

  2. 线性注意力:采用线性复杂度的注意力机制,避免了标准Transformer中的二次复杂度。

  3. 高效投影:使用1x1卷积对注意力输出进行投影,以融合不同尺度的特征。

以下是LiteMLA的核心实现代码:

class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int or None = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm_layer=(None, nn.BatchNorm2d),
        act_layer=(None, None),
        kernel_func=nn.ReLU,
        scales=(5,),
        eps=1e-5,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)

        self.dim = dim
        self.qkv = ConvNormAct(
            in_channels,
            3 * total_dim,
            1,
            bias=use_bias[0],
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
        )
        self.aggreg = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    3 * total_dim,
                    3 * total_dim,
                    scale,
                    padding=get_same_padding(scale),
                    groups=3 * total_dim,
                    bias=use_bias[0],
                ),
                nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
            )
            for scale in scales
        ])
        self.kernel_func = kernel_func(inplace=False)

        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            bias=use_bias[1],
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
        )

    def _attn(self, q, k, v):
        dtype = v.dtype
        q, k, v = q.float(), k.float(), v.float()
        kv = k.transpose(-1, -2) @ v
        out = q @ kv
        out = out[..., :-1] / (out[..., -1:] + self.eps)
        return out.to(dtype)

    def forward(self, x):
        B, _, H, W = x.shape

        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv]
        for op in self.aggreg:
            multi_scale_qkv.append(op(qkv))
        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
        multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
        q, k, v = multi_scale_qkv.chunk(3, dim=-1)

        # lightweight global attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        v = F.pad(v, (0, 1), mode="constant", value=1.)

        if not torch.jit.is_scripting():
            with torch.autocast(device_type=v.device.type, enabled=False):
                out = self._attn(q, k, v)
        else:
            out = self._attn(q, k, v)

        # final projection
        out = out.transpose(-1, -2).reshape(B, -1, H, W)
        out = self.proj(out)
        return out

在EfficientVitBlock中,LiteMLA与MBConv(Mobile Inverted Residual Block)相结合,形成了一个高效的混合结构:

class EfficientVitBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        heads_ratio=1.0,
        head_dim=32,
        expand_ratio=4,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super(EfficientVitBlock, self).__init__()
        self.context_module = ResidualBlock(
            LiteMLA(
                in_channels=in_channels,
                out_channels=in_channels,
                heads_ratio=heads_ratio,
                dim=head_dim,
                norm_layer=(None, norm_layer),
            ),
            nn.Identity(),
        )
        self.local_module = ResidualBlock(
            MBConv(
                in_channels=in_channels,
                out_channels=in_channels,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False),
                norm_layer=(None, None, norm_layer),
                act_layer=(act_layer, act_layer, None),
            ),
            nn.Identity(),
        )

    def forward(self, x):
        x = self.context_module(x)
        x = self.local_module(x)
        return x

这种结构设计使得EfficientViT能够在保持高性能的同时,显著降低计算复杂度和内存占用,非常适合移动设备部署。

MSRA的EfficientViT:级联组注意力机制

微软亚洲研究院提出的EfficientViT(在timm中对应efficientvit_msra.py)则采用了另一种创新思路——级联组注意力(Cascaded Group Attention)机制。

级联组注意力的核心思想

级联组注意力的主要创新点包括:

  1. 分组注意力:将输入特征分成多个组,每个组独立进行注意力计算,以降低计算复杂度。

  2. 级联融合:不同组的注意力输出进行级联融合,以捕获跨组的依赖关系。

  3. 局部窗口注意力:将特征图分割为局部窗口,在每个窗口内进行注意力计算,进一步降低计算量。

以下是级联组注意力的核心实现代码:

class CascadedGroupAttention(torch.nn.Module):
    attention_bias_cache: Dict[str, torch.Tensor]

    r""" Cascaded Group Attention.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution, correspond to the window size.
        kernels (List[int]): The kernel size of the dw conv on query.
    """
    def __init__(
            self,
            dim,
            key_dim,
            num_heads=8,
            attn_ratio=4,
            resolution=14,
            kernels=(5, 5, 5, 5),
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.val_dim = int(attn_ratio * key_dim)
        self.attn_ratio = attn_ratio

        qkvs = []
        dws = []
        for i in range(num_heads):
            qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim))
            dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
        self.qkvs = torch.nn.ModuleList(qkvs)
        self.dws = torch.nn.ModuleList(dws)
        self.proj = torch.nn.Sequential(
            torch.nn.ReLU(),
            ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
        )

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
        self.attention_bias_cache = {}

    def forward(self, x):
        B, C, H, W = x.shape
        feats_in = x.chunk(len(self.qkvs), dim=1)
        feats_out = []
        feat = feats_in[0]
        attn_bias = self.get_attention_biases(x.device)
        for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
            if head_idx > 0:
                feat = feat + feats_in[head_idx]
            feat = qkv(feat)
            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
            q = dws(q)
            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            attn = attn + attn_bias[head_idx]
            attn = attn.softmax(dim=-1)
            feat = v @ attn.transpose(-2, -1)
            feat = feat.view(B, self.val_dim, H, W)
            feats_out.append(feat)
        x = self.proj(torch.cat(feats_out, 1))
        return x

级联组注意力通过将输入特征分组处理,并在每个组内进行注意力计算,然后融合各组结果,有效降低了计算复杂度。同时,引入位置偏移偏置(attention_biases)来建模空间关系,进一步提升了模型性能。

MobileViT与EfficientViT的对比分析

MobileViT和EfficientViT系列作为移动友好型视觉Transformer的代表,各有其独特的设计理念和优势。

架构设计对比

模型核心创新计算复杂度内存占用适用场景
MobileViTCNN-Transformer混合结构对性能和效率有均衡要求的场景
MobileViT v2线性注意力资源受限的移动设备
EfficientViT (MIT)轻量级多尺度线性注意力需要多尺度特征的应用
EfficientViT (MSRA)级联组注意力对内存敏感的场景

性能对比

虽然具体的性能对比需要在统一的基准测试上进行,但根据模型设计和官方报告,我们可以得出以下大致结论:

  1. 精度:EfficientViT系列在大多数视觉任务上略优于MobileViT,尤其是在高分辨率图像上。

  2. 速度:MobileViT v2和MIT的EfficientViT在移动设备上的推理速度更快。

  3. 内存效率:MSRA的EfficientViT在内存使用上更具优势,适合内存受限的设备。

适用场景推荐

  • MobileViT:适合对性能和效率有均衡要求的通用移动视觉应用。

  • MobileViT v2:适合资源极其受限的低端移动设备。

  • EfficientViT (MIT):适合需要处理多尺度特征的应用,如目标检测和语义分割。

  • EfficientViT (MSRA):适合内存受限但对性能有较高要求的场景,如实时视频处理。

如何在pytorch-image-models中使用这些模型

pytorch-image-models库为这些高效移动Transformer模型提供了统一的接口,使得开发者可以轻松地使用和比较不同的模型。

模型加载与推理

以下是使用MobileViT和EfficientViT进行图像分类的示例代码:

import torch
from timm import create_model
from PIL import Image
import torchvision.transforms as transforms

# 加载预训练模型
model_mobilevit = create_model('mobilevit_s', pretrained=True)
model_efficientvit_mit = create_model('efficientvit_b0', pretrained=True)
model_efficientvit_msra = create_model('efficientvit_m0', pretrained=True)

# 设置为评估模式
model_mobilevit.eval()
model_efficientvit_mit.eval()
model_efficientvit_msra.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载并预处理图像
image = Image.open("test_image.jpg")
image = transform(image).unsqueeze(0)

# 推理
with torch.no_grad():
    output_mobilevit = model_mobilevit(image)
    output_efficientvit_mit = model_efficientvit_mit(image)
    output_efficientvit_msra = model_efficientvit_msra(image)

# 获取预测结果
pred_mobilevit = torch.argmax(output_mobilevit, dim=1)
pred_efficientvit_mit = torch.argmax(output_efficientvit_mit, dim=1)
pred_efficientvit_msra = torch.argmax(output_efficientvit_msra, dim=1)

print(f"MobileViT prediction: {pred_mobilevit.item()}")
print(f"EfficientViT (MIT) prediction: {pred_efficientvit_mit.item()}")
print(f"EfficientViT (MSRA) prediction: {pred_efficientvit_msra.item()}")

模型配置与定制

timm库还允许开发者根据需求定制模型配置,例如:

# 定制MobileViT模型
custom_mobilevit = create_model(
    'mobilevit_s',
    pretrained=True,
    num_classes=10,  # 修改分类头以适应自定义数据集
    drop_rate=0.2,   # 调整 dropout 率
    img_size=384     # 支持更高分辨率输入
)

# 定制EfficientViT模型
custom_efficientvit = create_model(
    'efficientvit_m4',
    pretrained=True,
    num_classes=100,
    drop_path_rate=0.1
)

总结与展望

从MobileViT到EfficientViT,移动视觉Transformer的发展见证了研究者们在效率与性能之间寻求平衡的不懈努力。这些模型通过创新的注意力机制设计、网络结构优化和混合建模策略,不断推动着移动视觉智能的边界。

pytorch-image-models库将这些先进模型集中整合,为开发者提供了便捷的工具来探索和应用这些技术。无论是MobileViT的CNN-Transformer融合思想,还是EfficientViT的高效注意力机制,都为移动视觉应用开辟了新的可能性。

未来,随着模型压缩、量化技术和硬件加速的进一步发展,我们有理由相信移动视觉Transformer将在更多领域得到应用,从智能手机到物联网设备,从增强现实到自动驾驶,为我们的生活带来更多智能和便利。

作为开发者,我们应该密切关注这些模型的发展,根据具体应用场景选择最合适的模型,并通过实践不断优化模型在特定任务上的性能和效率。

官方文档:README.md 模型实现:timm/models/

【免费下载链接】pytorch-image-models huggingface/pytorch-image-models: 是一个由 Hugging Face 开发维护的 PyTorch 视觉模型库,包含多个高性能的预训练模型,适用于图像识别、分类等视觉任务。 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值