一、通俗易懂介绍
1.1 核心思想
ConvMixer是2022年CVPR论文《Patches Are All You Need?》提出的新型架构,用纯卷积操作实现类似Vision Transformer(ViT)的分块信息混合能力。
- ViT的启示:ViT将图像分块后通过自注意力融合全局信息,但自注意力计算成本高。
- ConvMixer的创新:
- 分块嵌入:将图像分割为小块,线性映射到特征空间。
- 深度卷积混合:用极深的卷积层(如14层)替代自注意力,大幅降低计算量。
1.2 举个栗子 🌰
假设要识别一张图片中的猫:
- ViT做法:将图片分为16x16的小块,通过自注意力计算所有块的关系。
- ConvMixer做法:同样分块,但用大卷积核的深度卷积(如9x9)逐层混合邻域信息,逐步捕获全局特征。
二、应用场景与优缺点
2.1 应用场景
领域 | 任务 | 优势体现 |
---|---|---|
轻量级分类 | 移动端图像分类(如手机相册自动标注) | 参数量少,计算效率高 |
医学影像 | 低分辨率X光片分析 | 对局部细节敏感,适合小尺寸图像 |
嵌入式设备 | 无人机实时目标检测 | 低延迟,适合资源受限环境 |
2.2 优缺点对比
优点 | 缺点 |
---|---|
✅ 参数量仅为ViT的1/10,计算量减少50%以上 | ❌ 深层次卷积可能导致梯度消失 |
✅ 无需位置编码,卷积天然具有平移等变性 | ❌ 大卷积核增加显存消耗 |
✅ 在小数据集上表现优异(如CIFAR-10) | ❌ 超深网络训练时间较长 |
三、模型结构详解
3.1 整体架构
输入图像 → 分块嵌入 → ConvMixer层(重复N次) → 全局池化 → 分类头
3.1.1 分块嵌入(Patch Embedding)
- 操作:将图像分割为p×p的块,展平后线性投影到d维。
- 示例:输入224x224图像,分块尺寸14x14 → 得到16x16=256个块,每个块投影为d=768维。
- 数学表达:
输入图像,
线性投影:
3.1.2 ConvMixer层
每层包含两个子模块:
-
深度卷积(Depthwise Convolution):
- 卷积核:k×k(如9x9),分组数=通道数,每个通道独立卷积。
- 作用:混合空间邻域信息。
- 输出尺寸:与输入相同。
-
逐点卷积(Pointwise Convolution):
- 卷积核:1x1,跨通道融合特征。
- 作用:增强通道间交互。
完整层结构:
输入 → 深度卷积 → GELU激活 → 逐点卷积 → Add残差连接 → 输出
3.1.3 参数配置(以ConvMixer-1536/20为例)
- 分块尺寸p:14
- 隐藏维度d:1536
- 层数N:20
- 深度卷积核k:9x9
四、数学原理
4.1 深度卷积公式
设输入张量 ,深度卷积核
,输出:
其中 为偏置项,每个通道独立计算。
4.2 残差连接
第l层输出:
五、代表性变体及改进
5.1 ConvMixer-Lite
- 改进点:减少隐藏维度(如d=512),深度卷积核缩小为7x7。
- 优势:显存占用减少40%,适合移动端部署。
5.2 ConvMixer-Attention
- 改进点:在逐点卷积前加入SE(Squeeze-Excite)注意力模块。
- SE模块公式:
5.3 DynamicConvMixer
- 改进点:动态调整卷积核大小(如5x5或7x7),根据输入内容自适应。
- 实现:使用轻量级网络预测卷积核权重。
六、PyTorch代码示例
6.1 自定义ConvMixer模型
import torch
import torch.nn as nn
class ConvMixerBlock(nn.Module):
def __init__(self, dim, kernel_size=9):
super().__init__()
# 深度卷积(分组数=输入通道数)
self.dw_conv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)
self.act = nn.GELU()
# 逐点卷积
self.pw_conv = nn.Conv2d(dim, dim, 1)
def forward(self, x):
return x + self.pw_conv(self.act(self.dw_conv(x)))
class ConvMixer(nn.Module):
def __init__(self, img_size=224, patch_size=14, dim=768, depth=20, num_classes=1000):
super().__init__()
# 分块嵌入
self.patch_embed = nn.Sequential(
nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
nn.GELU(),
nn.BatchNorm2d(dim)
)
# ConvMixer层堆叠
self.blocks = nn.Sequential(*[
ConvMixerBlock(dim=dim, kernel_size=9)
for _ in range(depth)
])
# 分类头
self.head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(dim, num_classes)
)
def forward(self, x):
x = self.patch_embed(x) # [B, dim, h, w]
x = self.blocks(x) # 通过所有ConvMixer块
x = self.head(x) # 分类输出
return x
# 示例:输入为224x224 RGB图像
model = ConvMixer(img_size=224, patch_size=14, dim=768, depth=20)
x = torch.randn(4, 3, 224, 224) # 输入形状 [4,3,224,224]
out = model(x) # 输出形状 [4,1000]
print(out.shape)
6.2 使用HuggingFace预训练模型
from transformers import ConvMixerConfig, ConvMixerForImageClassification
from PIL import Image
import requests
# 加载预训练模型(假设HuggingFace已支持)
model = ConvMixerForImageClassification.from_pretrained("timm/convmixer_1536_20")
processor = ConvMixerConfig().build_preprocessor()
# 预处理输入
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(image, return_tensors="pt")
# 推理
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class])
七、总结
ConvMixer通过极简的纯卷积架构+深度分块混合,挑战了Transformer在视觉任务中的统治地位,证明了卷积在全局建模中的潜力。其变体在效率与精度的平衡上持续改进,未来可能的发展方向包括:
- 动态卷积核:根据输入动态调整参数。
- 跨模态扩展:适配视频、点云等数据类型。
- 自监督预训练:结合MAE等方法提升表征能力。