使用向量数据库Pinecone进行图片搜索

原理

首先提取每个图片的特征向量embedding,这里使用的是openai发布的一个多模态预训练模型CLIP(Contrastive Language-Image Pretraining);
然后将特征向量保存到向量数据库Pinecone;
最后根据查询图片的特征向量到向量数据库搜素与其相似的图片。

具体实现

准备工作

1、安装Pinecone

!pip install pinecone==4.0.0

2、下载image数据集
如下数据集为imageNet的子集,仅包含10个类别

!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz
!tar -xzvf imagenette2-160.tgz

这里运行代码使用的是Google colab,也可以使用jupyter notebook

提前特征向量embedding

这里仅以一张图片为例

import os
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
# Set up train directory
dataset_dir = 'imagenette2-160'
train_dir = os.path.join(dataset_dir, 'train')
# Pick an image from our dataset
subdir_path = os.path.join(train_dir, subdirs[0])
image_filename = os.listdir(subdir_path)[0]
image_path = os.path.join(subdir_path, image_filename)

# Preprocess the image with the CLIP image preprocessor
# and use the model to generate an embedding
image = Image.open(image_path).convert('RGB')

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
inputs = processor(images=image, return_tensors="pt", padding=True)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
with torch.no_grad():
    embedding = model.get_image_features(**inputs)


将特征向量保存到向量数据库Pinecone

1、要使用Pinecone需要先注册并获取免费的key 注册地址:https://www.pinecone.io/
2、创建索引

PINECONE_API_KEY = "TODO: Fill in your API key here"

pc = Pinecone(api_key=PINECONE_API_KEY)

index_name = "clip-index" # You can name the index whatever you want
###################################################################
# TODO: Fill in your code here
###################################################################
n_embd = 512 # 这里CLIP使用512维度
metric = "cosine" #这里使用cosine,可选的还有euclidean, cosine, and dotproduct
###################################################################
# END OF YOUR CODE
###################################################################

if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name,
        dimension=n_embd,
        metric=metric,
        spec=ServerlessSpec(
            cloud='aws',
            region='us-east-1'
        )
    )
    print(f"Created index: {index_name}")
else:
    print(f"Index already exists: {index_name}")

3、插入
上一步将所有图片的特征向量embeddings提取出来后以如下格式插入至Pinecone
vectors的格式如下:
vectors=[
{“id”: “/path/to/img1”, “values”: [-0.12, 0.05, -0.23, 0.18, -0.07, 0.31]},
{“id”: “/path/to/img2”, “values”: [0.42, -0.19, 0.27, -0.35, 0.11, -0.28]},
{“id”: “/path/to/img3”, “values”: [-0.51, 0.63, -0.17, 0.45, -0.38, 0.22]},
{“id”: “/path/to/img4”, “values”: [0.78, -0.41, 0.56, -0.72, 0.89, -0.25]}
]

# 使用列表推导式和zip函数将两个列表转换成所需格式  
vectors = [{"id": id, "values": value} for id, value in zip(image_files, embeddings)]
index.upsert(
    vectors = vectors,
    namespace= namespace
)

4、查询索引状态

index = pc.Index(index_name)
index_stats = index.describe_index_stats()

print(f"The index contains {index_stats['dimension']}-dimensional vectors/embeddings.")
print(f"A total of {index_stats['total_vector_count']} vectors have been uploaded to the index.")

5、图片搜索
首先获取待查询图片的特征向量

import requests
from io import BytesIO
def get_image_from_url(url):
  response = requests.get(url)
  return Image.open(BytesIO(response.content)).convert('RGB')

###################################################################
# TODO: Generate an embedding for the image you found
#       and query the vector database for the 3 most similar images
#       to your query image.
###################################################################

# Sample query image
image_url = "https://salient-imagenet.cs.umd.edu/feature_visualization/class_481/feature_1369/images/1.jpg"

image = get_image_from_url(image_url)
inputs = processor(images=[image], return_tensors="pt", padding=True)
with torch.no_grad():
  embeddings_new = model.get_image_features(**inputs)

results = index.query(
  namespace=namespace,
  vector=embeddings_new[0].tolist(),
  top_k=3,
  include_values=True
)

6、搜索结果可视化

%matplotlib inline
import matplotlib.pyplot as plt

def display_images_grid(query_image_url, image_paths, scores):
    # Calculate the number of images (query + results)
    num_images = len(image_paths) + 1

    # Create a figure with 2 rows: 1 for query, 1 for results
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle("Image Search", fontsize=16)

    # Display query image
    query_img = get_image_from_url(query_image_url)
    axes[0, 1].imshow(query_img)
    axes[0, 1].axis('off')
    axes[0, 1].set_title("Query Image", fontweight='bold')

    # Turn off unused subplots in the first row
    axes[0, 0].axis('off')
    axes[0, 2].axis('off')

    # Loop through the result image paths and display each image
    for i, (path, score) in enumerate(zip(image_paths, scores)):
        img = Image.open(path)
        containing_dir = os.path.basename(os.path.dirname(path))
        filename = os.path.basename(path)
        title = mapping[containing_dir]

        axes[1, i].imshow(img)
        axes[1, i].axis('off')
        axes[1, i].set_title(f"{title}, score={score:.4f}")

    plt.tight_layout()
    plt.show()
image_paths=[ r['id'] for r in results['matches'] ]
scores=[ r['score'] for r in results['matches'] ]
display_images_grid(image_url, image_paths, scores)

参考链接

https://www.trybackprop.com/blog/linalg101/part_3_build_image_search

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值