本文来源公众号“OpenCV与AI深度学习”,仅用于学术分享,侵权删,干货满满。
原文链接:MobileCLIP:一种轻量级的零样本图像分类解决方案(介绍 + 代码演示)
介绍
在当今快速发展的计算机视觉领域,人们越来越需要能够快速高效地适应新情况的技术。该领域最令人兴奋的发展之一是零样本图像分类。但这具体意味着什么呢?
想象一下,向计算机展示一张它从未见过的物体图像。传统上,您必须使用大量带标签的图像来训练系统,以帮助它识别不同的物体。但是使用零样本分类,您可以绕过这个漫长的过程。您无需教模型每个物体的样子,只需给它描述并让它自己找出匹配项。
在新的类别频繁出现且很难(或成本高昂)获取足够的标记数据的环境中,此功能可以改变游戏规则。
实现这一目标的杰出模型之一是OpenAI 开发的CLIP(对比语言-图像预训练)。它将图像与其文本描述联系起来,使模型无需通常的繁重训练即可识别新物体。缺点是什么?CLIP 有点耗费资源。它需要大量的计算能力,因此很难在手机或物联网设备等小型设备上运行。
这正是MobileCLIP发挥作用的地方。它是 CLIP 的优化版,轻量级版本,专为资源较少的设备而设计。在本文中,我将向您展示 MobileCLIP 的工作原理以及它如何将零样本分类带到您的掌中。
MobileCLIP 论文:
https://arxiv.org/pdf/2311.17049
CLIP Paper 论文:
https://arxiv.org/abs/2103.00020
什么是MobileCLIP
基础:CLIP
CLIP 的核心是一种学习将图像与文本联系起来的模型。它经过训练可以预测哪段文本与给定图像最匹配,或者哪幅图像与描述最匹配。这个过程涉及包含图像和文本对的庞大数据集。
例如,如果您提供一张猫的图像,并且文本提示“狗”、“猫”和“汽车”,CLIP 可以确定“猫”是最可能的匹配,即使它以前从未见过该图像。
这种对视觉概念的一般理解使 CLIP 有别于传统的图像分类器。传统分类器通常需要对要识别的每个类别进行标记;如果添加新类别,则必须收集更多标签并重新训练。另一方面,CLIP 只需要文本标签或描述,不需要额外的图像数据。
CLIP 的强大之处在于它能够泛化各种视觉概念,这要归功于它对各种数据进行训练。然而,这种强大是有代价的:CLIP 需要大量的计算资源,这使得它不太适合资源有限的应用。
优化需求:进入MobileCLIP
为了让 CLIP 在资源有限的设备上更加实用,MobileCLIP 应运而生。它是 CLIP 的简化版本,经过微调以提高效率,同时又不损失原始模型的准确性,尤其是在零样本分类方面,它甚至优于传统的 CLIP 模型。
CLIP 和 MobileCLIP 之间的主要区别包括:
-
-
更小的模型尺寸:它被精简以使用更少的内存,这对于存储空间有限的移动或边缘设备至关重要。
-
计算效率: MobileCLIP 的设计即使在处理能力有限的设备上也能表现良好,例如智能手机或物联网设备。
-
低延迟: MobileCLIP 提供更低的推理延迟,这对于实时视频分析等实时应用至关重要。
-
MobileCLIP 的潜在用例
现在我们知道了 MobileCLIP 可以做什么,让我们来探索一下它可以在哪里使用。由于它是为资源有限的设备设计的,因此潜在的应用非常令人兴奋:
-
-
移动应用:想想您每天在手机上使用的应用。随着设备智能化的推进,MobileCLIP 可以增强您在增强现实应用、个人助理甚至实时照片分类方面的体验。您的手机无需将数据发送到云端进行处理(这需要时间和带宽),而是可以在本地完成所有艰苦的工作。
-
边缘计算:MobileCLIP 非常适合带宽和处理能力有限的边缘计算环境。无人机、机器人和远程传感器等设备可以利用该模型执行视觉识别任务,无需持续的云连接即可实现实时决策。
-
物联网设备:MobileCLIP 集成到物联网 (IoT) 设备(如安全摄像头或智能家居助手)中,使这些系统能够执行本地视觉识别。这在隐私、延迟和在互联网连接不稳定的环境中运行的能力方面带来了好处。
-
实现MobileCLIP
让我们深入了解如何实际使用 MobileCLIP 进行零样本分类。如果您已准备好亲自动手,这里有一份分步指南,可帮助您进行设置。
分步代码:使用 MobileCLIP 进行零样本图像分类
1. 环境设置
import os
import time
import argparse
from typing import List, Tuple
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import open_clip
from timm.utils import reparameterize_model
import numpy as np
# Check CUDA availability and set the device (GPU if available, otherwise CPU)
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
print(f"Torch device: {device}")
首先,我们需要导入必要的库,包括用于模型加载的 open_clip、用于张量运算的 torch、用于图像处理的 cv2 以及用于可视化结果的 matplotlib。如果您有 GPU,MobileCLIP 可以利用它来加快速度。如果没有,它仍然可以在 CPU 上运行良好。
2. 模型与预处理
# Load MobileCLIP model and preprocessing transforms
model, _, preprocess = open_clip.create_model_and_transforms(
'MobileCLIP-S1', pretrained='datacompdr'
)
tokenizer = open_clip.get_tokenizer('MobileCLIP-S1')
# Set model to evaluation mode, reparameterize for efficiency,
# and move it to the selected device
model.eval()
model = reparameterize_model(model)
model.to(device)
接下来,我们加载 MobileCLIP 模型(我们使用 MobileCLIP-S1,一个更轻量级的版本)。我们还需要加载 tokenizer,它将您的文本提示转换为模型可以理解的 token 序列。将模型设置为评估模式,以便准备好进行推理。
3. 图像分类功能
def classify_image(img: np.ndarray, labels_list: List[str]) -> Tuple[str, float]:
"""
Classify an image using MobileCLIP.
This function preprocesses the input image, tokenizes the provided
text prompts, extracts features from both image and text,
computes the similarity, and returns the label with the highest
probability along with the probability value.
Args:
img (numpy.ndarray): Input image in RGB format.
labels_list (list): List of labels to classify against.
Returns:
tuple: A tuple containing the predicted label (str) and
the probability (float).
"""
# Convert the image from a NumPy array to a PIL image, preprocess it,
# add batch dimension, and move to device.
preprocessed_img = preprocess(Image.fromarray(img)).unsqueeze(0).to(device)
# Tokenize the labels inside the function and move tokens to the device.
text = tokenizer(labels_list).to(device)
# Disable gradient calculation and enable automatic mixed precision
with torch.no_grad(), torch.cuda.amp.autocast():
# Extract features from the image using the model.
image_features = model.encode_image(preprocessed_img)
# Extract text features from the tokenized text.
text_features = model.encode_text(text)
# Normalize image and text features to unit vectors.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
# Compute the similarity (dot product) and apply softmax to
# obtain probabilities.
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
# Get the label with the highest probability from the provided label list.
selected_label = labels_list[text_probs.argmax(dim=-1)]
selected_prob = text_probs.max(dim=-1)[0].item()
return selected_label, selected_prob
该流程的核心是图像分类函数。该函数将图像作为输入,对其进行预处理,并将其传递给 mobileCLIP 编码器模型以提取图像特征。然后,它计算与给定标签(例如“猫”、“狗”、“汽车”)的相似度(也使用 mobileCLIP 编码),并返回最可能的标签及其相关概率。
4. 可视化结果
def plot_results(results: List[Tuple[np.ndarray, str, float, float]]) -> None:
"""
Plot the classification results.
This function creates a horizontal plot for each image in the results,
displaying the image along with its predicted label, probability,
and processing time.
Args:
results (list): List of tuples (img, label, probability, elapsed_time).
"""
# Create subplots with one image per subplot.
fig, axes = plt.subplots(1, len(results), figsize=(len(results) * 5, 5))
# If there is only one image, make axes a list to handle it uniformly.
if len(results) == 1:
axes = [axes]
# Iterate over results and plot each one.
for ax, (img, label, prob, elapsed_time) in zip(axes, results):
ax.imshow(img)
ax.set_title(
f"Label: {label},\nProbability: {prob:.2%},\nTime: {elapsed_time:.2f}s"
)
ax.axis('off')
plt.tight_layout()
plt.show()
本节介绍了一种可视化函数,绘制分类图像及其预测标签、概率和处理时间。
5. 图像分类主循环
def main(data_folder: str, labels_list: List[str]) -> None:
"""
Process images and perform zero-shot image classification.
This function processes each image in the specified folder,
classifies them using MobileCLIP, and then plots the results.
Args:
data_folder (str): Path to the folder containing input images.
labels_list (List[str]): List of labels to classify against.
"""
results: List[Tuple[np.ndarray, str, float, float]] = []
for image_file in os.listdir(data_folder):
image_path = os.path.join(data_folder, image_file)
# Read the image using OpenCV.
img = cv2.imread(image_path)
# Skip files that are not valid images.
if img is None:
print(f"Warning: Unable to read image {image_file}. Skipping.")
continue
# Convert the image from BGR (OpenCV default) to RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
start_time = time.time()
selected_label, selected_prob = classify_image(img, labels_list)
elapsed_time = time.time() - start_time
print(f"{image_file} - Label: {selected_label}, Prob: {selected_prob:.2%} (Time: {elapsed_time:.2f}s)")
results.append((img, selected_label, selected_prob, elapsed_time))
plot_results(results)
if __name__ == '__main__':
data_folder = 'data'
labels_list = ['dog', 'cat', 'car']
main(data_folder, labels_list)
我们迭代数据文件夹中的图像,使用 classify_image() 对每个图像进行分类,并附加结果进行可视化。然后将结果传递给 plot_results() 以生成可视化输出。
完整代码:
https://github.com/vargroup-datascience/medium-repo/tree/main/exploring_MobileCLIP_A_Lightweight_Solution_for_ZeroShot_Image_Classification
识别结果:
THE END !
文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。