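How it works
First, extract a feature vector (embedding) for each image. This post uses CLIP (Contrastive Language-Image Pretraining), a multimodal pretrained model released by OpenAI.
Next, save the feature vectors to the Pinecone vector database.
Finally, use the feature vector of a query image to search the vector database for similar images.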
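To make the three steps concrete, here is a minimal sketch of the same idea without any model or database: a plain dict stands in for the vector database, and the "embeddings" are made-up 3-dimensional values (real CLIP vectors have 512 dimensions). The function names are hypothetical and only serve this illustration.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k_similar(query, stored, k=3):
    """Return the k (id, score) pairs with the highest cosine similarity."""
    scored = [(image_id, cosine_similarity(query, vec))
              for image_id, vec in stored.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]

# Step 1 and 2: toy embeddings "saved" in a dict (values are invented)
stored_vectors = {
    "cat_1.jpg": [0.9, 0.1, 0.0],
    "cat_2.jpg": [0.8, 0.2, 0.1],
    "truck_1.jpg": [0.0, 0.1, 0.9],
}

# Step 3: search with the embedding of a query image
query_vector = [1.0, 0.0, 0.0]
matches = top_k_similar(query_vector, stored_vectors, k=2)
print([image_id for image_id, _score in matches])
# ['cat_1.jpg', 'cat_2.jpg']
```

A vector database does the same job at scale: it stores the vectors and answers "which stored vectors are closest to this one" without comparing against every entry by hand.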
Implementation
Preparation
1. Install Pinecone
!pip install pinecone==4.0.0
2. Download the image dataset
The dataset below is a subset of ImageNet that contains only 10 classes.
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz
!tar -xzvf imagenette2-160.tgz
The code here was run in Google Colab; Jupyter Notebook works as well.
Extracting the feature vector (embedding)
Only a single image is used as an example here.
import os
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
# Set up train directory
dataset_dir = 'imagenette2-160'
train_dir = os.path.join(dataset_dir, 'train')
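# Pick an image from our dataset
# List the class subdirectories (one per category) so we can pick from the first one
subdirs = sorted(d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d)))
subdir_path = os.path.join(train_dir, subdirs[0])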
image_filename = os.listdir(subdir_path)[0]
image_path = os.path.join(subdir_path, image_filename)
# Preprocess the image with the CLIP image preprocessor
# and use the model to generate an embedding
image = Image.open(image_path).convert('RGB')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
inputs = processor(images=image, return_tensors="pt", padding=True)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
with torch.no_grad():
    embedding = model.get_image_features(**inputs)
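The single-image example extends naturally to the whole dataset. The following is a rough sketch, not the original author's code: `collect_image_paths`, `batched`, and `embed_images` are hypothetical helper names, the batch size of 32 is an arbitrary assumption, and `processor` and `model` are the CLIP objects created above.

```python
import os

IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")

def collect_image_paths(root_dir):
    """Walk root_dir and return a sorted list of image file paths."""
    paths = []
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        for name in filenames:
            if name.lower().endswith(IMAGE_EXTENSIONS):
                paths.append(os.path.join(dirpath, name))
    return sorted(paths)

def batched(items, batch_size):
    """Yield consecutive slices of items with at most batch_size elements."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def embed_images(image_paths, processor, model, batch_size=32):
    """Return one list of floats per image, in the same order as image_paths."""
    # Imported lazily so the two helpers above work without torch/Pillow installed.
    import torch
    from PIL import Image

    embeddings = []
    for batch_paths in batched(image_paths, batch_size):
        images = []
        for path in batch_paths:
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        inputs = processor(images=images, return_tensors="pt")
        with torch.no_grad():
            features = model.get_image_features(**inputs)
        # Convert the (batch, 512) tensor into plain Python lists
        embeddings.extend(features.tolist())
    return embeddings
```

With these helpers, `image_files = collect_image_paths(train_dir)` followed by `embeddings = embed_images(image_files, processor, model)` would produce two parallel lists: the image paths and their embeddings.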
Saving the feature vectors to the Pinecone vector database
1. To use Pinecone, first sign up and obtain a free API key. Sign-up page: https://www.pinecone.io/
2. Create the index
from pinecone import Pinecone, ServerlessSpec
PINECONE_API_KEY = "TODO: Fill in your API key here"
pc = Pinecone(api_key=PINECONE_API_KEY)
index_name = "clip-index"  # You can name the index whatever you want
n_embd = 512  # The CLIP model used here produces 512-dimensional embeddings
metric = "cosine"  # cosine is used here; the supported metrics are euclidean, cosine, and dotproduct
# Create the index only if it does not exist yet
if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name,
        dimension=n_embd,
        metric=metric,
        spec=ServerlessSpec(
            cloud='aws',
            region='us-east-1'
        )
    )
    print(f"Created index: {index_name}")
else:
    print(f"Index already exists: {index_name}")
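The `metric` setting decides how "similar" is measured. As a small illustration of how the three options differ in general (plain Python on two made-up vectors; how Pinecone computes and scales its scores internally may differ in detail):

```python
import math

def dot_product(a, b):
    return sum(x * y for x, y in zip(a, b))

def euclidean_distance(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_similarity(a, b):
    # Dot product divided by the product of the two vector lengths
    return dot_product(a, b) / (math.sqrt(dot_product(a, a)) * math.sqrt(dot_product(b, b)))

u = [3.0, 4.0]
v = [6.0, 8.0]  # same direction as u, but twice as long

print(cosine_similarity(u, v))   # 1.0  (direction is identical)
print(euclidean_distance(u, v))  # 5.0  (the points are still far apart)
print(dot_product(u, v))         # 50.0 (grows with vector length)
```

Cosine ignores vector length and compares direction only, which is why it is a common choice for embeddings such as CLIP's.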
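3. Insert the vectors
Once the embeddings of all images have been extracted as in the previous step, insert them into Pinecone in the format below.
The format of vectors is as follows: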
vectors = [
    {"id": "/path/to/img1", "values": [-0.12, 0.05, -0.23, 0.18, -0.07, 0.31]},
    {"id": "/path/to/img2", "values": [0.42, -0.19, 0.27, -0.35, 0.11, -0.28]},
    {"id": "/path/to/img3", "values": [-0.51, 0.63, -0.17, 0.45, -0.38, 0.22]},
    {"id": "/path/to/img4", "values": [0.78, -0.41, 0.56, -0.72, 0.89, -0.25]}
]
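# image_files is the list of image paths; embeddings holds one list of floats per image
# Use a list comprehension and zip to convert the two lists into the required format
vectors = [{"id": image_id, "values": values} for image_id, values in zip(image_files, embeddings)]
index = pc.Index(index_name)
namespace = "image-search"  # example namespace name (an assumption; pick your own)
index.upsert(
    vectors=vectors,
    namespace=namespace
)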
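Pinecone limits how large a single upsert request can be, so sending thousands of 512-dimensional vectors in one call may be rejected. A minimal sketch of chunked uploading follows. The helper names and the default `batch_size` of 200 are assumptions chosen for illustration; the only Pinecone call used is `index.upsert(vectors=..., namespace=...)`, as above.

```python
def to_pinecone_vectors(image_files, embeddings):
    """Pair each image path with its embedding as a plain list of floats."""
    return [
        {"id": path, "values": [float(x) for x in values]}
        for path, values in zip(image_files, embeddings)
    ]

def upsert_in_batches(index, vectors, namespace, batch_size=200):
    """Send vectors to index.upsert in chunks; return the number of requests made."""
    num_requests = 0
    for start in range(0, len(vectors), batch_size):
        chunk = vectors[start:start + batch_size]
        index.upsert(vectors=chunk, namespace=namespace)
        num_requests += 1
    return num_requests
```

Usage would look like `upsert_in_batches(index, to_pinecone_vectors(image_files, embeddings), namespace)`. The `float(x)` conversion also makes sure the values are plain Python numbers rather than tensor or NumPy scalars.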
4. Check the index status
index = pc.Index(index_name)
index_stats = index.describe_index_stats()
print(f"The index contains {index_stats['dimension']}-dimensional vectors/embeddings.")
print(f"A total of {index_stats['total_vector_count']} vectors have been uploaded to the index.")
5. Image search
First, get the feature vector of the query image.
import requests
from io import BytesIO
def get_image_from_url(url):
    response = requests.get(url)
    response.raise_for_status()  # Fail early if the download did not succeed
    return Image.open(BytesIO(response.content)).convert('RGB')
# Generate an embedding for the query image and query the vector database
# for the 3 most similar images.
# Sample query image
image_url = "https://salient-imagenet.cs.umd.edu/feature_visualization/class_481/feature_1369/images/1.jpg"
image = get_image_from_url(image_url)
inputs = processor(images=[image], return_tensors="pt", padding=True)
with torch.no_grad():
    embeddings_new = model.get_image_features(**inputs)
results = index.query(
    namespace=namespace,
    vector=embeddings_new[0].tolist(),
    top_k=3,
    include_values=True
)
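Each entry in `results['matches']` carries the `id` that was stored at insert time (here, the image path) and a similarity `score`. Because the class directory is part of the path, a quick sanity check is to count which classes the matches belong to. The sketch below uses a hand-written stand-in for the matches list; the file names and scores are made up, and `summarize_matches` is a hypothetical helper.

```python
import os
from collections import Counter

def summarize_matches(matches):
    """Return (class_dir, count) pairs, most frequent class first."""
    class_dirs = [os.path.basename(os.path.dirname(m["id"])) for m in matches]
    return Counter(class_dirs).most_common()

# Stand-in for results['matches'] (ids and scores are invented for illustration)
example_matches = [
    {"id": "imagenette2-160/train/n03028079/a.JPEG", "score": 0.91},
    {"id": "imagenette2-160/train/n03028079/b.JPEG", "score": 0.88},
    {"id": "imagenette2-160/train/n03417042/c.JPEG", "score": 0.74},
]
print(summarize_matches(example_matches))
# [('n03028079', 2), ('n03417042', 1)]
```

If most of the top matches fall into the class you expect for the query image, the embeddings and the index are likely set up correctly.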
6. Visualize the search results
%matplotlib inline
import matplotlib.pyplot as plt
# Imagenette class directories (WordNet IDs) and their human-readable labels
mapping = {
    'n01440764': 'tench', 'n02102040': 'English springer',
    'n02979186': 'cassette player', 'n03000684': 'chain saw',
    'n03028079': 'church', 'n03394916': 'French horn',
    'n03417042': 'garbage truck', 'n03425413': 'gas pump',
    'n03445777': 'golf ball', 'n03888257': 'parachute',
}
def display_images_grid(query_image_url, image_paths, scores):
    # Create a figure with 2 rows: 1 for the query, 1 for the results (at most 3 results fit)
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle("Image Search", fontsize=16)
    # Display query image
    query_img = get_image_from_url(query_image_url)
    axes[0, 1].imshow(query_img)
    axes[0, 1].axis('off')
    axes[0, 1].set_title("Query Image", fontweight='bold')
    # Turn off unused subplots in the first row
    axes[0, 0].axis('off')
    axes[0, 2].axis('off')
    # Loop through the result image paths and display each image
    for i, (path, score) in enumerate(zip(image_paths, scores)):
        img = Image.open(path)
        containing_dir = os.path.basename(os.path.dirname(path))
        # Fall back to the directory name if it is not in the mapping
        title = mapping.get(containing_dir, containing_dir)
        axes[1, i].imshow(img)
        axes[1, i].axis('off')
        axes[1, i].set_title(f"{title}, score={score:.4f}")
    plt.tight_layout()
    plt.show()
image_paths = [r['id'] for r in results['matches']]
scores = [r['score'] for r in results['matches']]
display_images_grid(image_url, image_paths, scores)
References
https://www.trybackprop.com/blog/linalg101/part_3_build_image_search