# Directory structure:
# image_matching.py:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# 1. Load a pretrained ResNet model and strip its classification layer
def get_resnet_model():
    """Build a pretrained ResNet-50 feature extractor.

    The network is loaded with pretrained weights, its last child module
    (the fully connected classification layer) is dropped, and the result is
    switched to evaluation mode.

    Returns:
        An ``nn.Sequential`` made of every child of ResNet-50 except the last.
    """
    # torchvision 0.13 deprecated ``pretrained=True`` in favour of ``weights=``.
    # Prefer the new API when it exists; IMAGENET1K_V1 is the checkpoint that
    # ``pretrained=True`` used to load, so the extracted features are unchanged.
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision releases only understand the legacy keyword.
        backbone = models.resnet50(pretrained=True)

    # Keep every child module except the final fully connected layer.
    feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
    feature_extractor.eval()  # inference behaviour for batch-norm / dropout
    return feature_extractor
# 2. Image preprocessing
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    The image is converted to RGB, resized to 256, center-cropped to 224,
    converted to a tensor, normalized per channel, and given a leading batch
    dimension.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The transformed image tensor with an extra leading dimension of size 1.
    """
    transform = transforms.Compose([
        transforms.Resize(256),        # resize
        transforms.CenterCrop(224),    # center crop
        transforms.ToTensor(),         # PIL image -> tensor
        transforms.Normalize(          # per-channel normalization
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    # Open inside a context manager so the underlying file handle is released
    # deterministically; convert() returns an independent in-memory copy, so
    # it is safe to use after the file has been closed.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")
    return transform(rgb_image).unsqueeze(0)  # add the batch dimension
# 3. Extract image features
def extract_features(model, image):
    """Run ``model`` on ``image`` and return the output as a flat NumPy array.

    Args:
        model: Callable applied to ``image``. When it exposes ``parameters()``
            (as ``nn.Module`` does), the input is moved to the device of its
            first parameter before the forward pass.
        image: Input tensor passed to ``model``.

    Returns:
        A 1-D ``numpy.ndarray`` holding every element of the model output.
    """
    # Put the input on the same device as the model weights so a model that
    # was moved to an accelerator still works with CPU-loaded images.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        image = image.to(first_param.device)

    with torch.no_grad():  # inference only, no autograd graph needed
        features = model(image)

    # flatten() already collapses every dimension (size-1 ones included), so a
    # separate squeeze() is unnecessary. Tensor.numpy() only works on CPU
    # tensors, hence the explicit cpu() copy.
    return features.flatten().cpu().numpy()
# 4. Compute similarity (cosine similarity)
def calculate_similarity(feature1, feature2):
    """Return the cosine similarity between two feature vectors.

    Each vector is wrapped in a list to form a single-row input for
    ``cosine_similarity``; the only entry of the resulting matrix is returned.
    """
    row_a = [feature1]
    row_b = [feature2]
    pairwise = cosine_similarity(row_a, row_b)
    first_row = pairwise[0]
    return first_row[0]
# 5. Load the images in the database directory and extract their features
def build_image_database(image_dir, model):
    """Extract a feature vector for every readable image file in a directory.

    Only regular files directly inside ``image_dir`` are considered;
    subdirectories are ignored. Files that cannot be opened as images are
    skipped with a warning instead of aborting the whole scan.

    Args:
        image_dir: Directory containing the database images.
        model: Feature extractor forwarded to ``extract_features``.

    Returns:
        A dict mapping each file name to the feature returned by
        ``extract_features`` for that file.
    """
    import warnings  # local import keeps this block self-contained

    features_database = {}
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # insertion order (and any downstream tie-breaking) reproducible.
    for image_name in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, image_name)
        if not os.path.isfile(image_path):
            continue
        try:
            image = preprocess_image(image_path)
        except OSError as err:
            # Stray non-image files (e.g. .DS_Store, Thumbs.db) or corrupt
            # images should not make the whole database unusable.
            warnings.warn(f"Skipping unreadable image {image_path!r}: {err}")
            continue
        features_database[image_name] = extract_features(model, image)
    return features_database
# 6. Query the database and find the best-matching image
def find_most_similar_image(query_image_path, image_dir, model, features_database=None):
    """Find the database image most similar to the query image.

    Args:
        query_image_path: Path of the query image.
        image_dir: Directory of database images; scanned with
            ``build_image_database`` when ``features_database`` is not given.
        model: Feature extractor used for the query (and for the scan).
        features_database: Optional precomputed mapping of image name to
            feature vector. Passing it avoids re-extracting features for the
            whole directory on every query.

    Returns:
        A ``(best_match, best_similarity)`` tuple. When there is no candidate
        (empty database, or no comparable score), ``(None, -1)`` is returned.
    """
    query_image = preprocess_image(query_image_path)
    query_feature = extract_features(model, query_image)

    if features_database is None:
        features_database = build_image_database(image_dir, model)

    best_match = None
    # Start below every possible score: cosine similarity can be exactly -1
    # (or a hair under it from rounding), which a starting value of -1 with a
    # strict ">" comparison would never select.
    best_similarity = float("-inf")
    for image_name, feature in features_database.items():
        similarity = calculate_similarity(query_feature, feature)
        if similarity > best_similarity:
            best_similarity = similarity
            best_match = image_name

    if best_match is None:
        # Keep the historical "no match" result for callers that format it.
        return None, -1
    return best_match, best_similarity
# Example usage
if __name__ == "__main__":
    # Query image path and the directory holding the database images
    query_image_path = 'images/wk1.jpg'
    image_dir = 'images_database'
    # Load the feature-extraction model
    model = get_resnet_model()
    # Look up the closest database image and report it
    result = find_most_similar_image(query_image_path, image_dir, model)
    best_match, similarity = result
    print(f"最匹配的图像是: {best_match},匹配度: {similarity:.4f}")
# Execution result: