基于ResNet的图像识别

目录结构:

image_matching.py:

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


# 1. 加载预训练的 ResNet 模型,并去掉分类层
def get_resnet_model():
    model = models.resnet50(pretrained=True)  # 使用 ResNet-50 作为基础模型
    # 去掉最后的全连接层
    model = nn.Sequential(*list(model.children())[:-1])
    model.eval()  # 切换到评估模式
    return model


# 2. 图像预处理
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize(256),  # 调整大小
        transforms.CenterCrop(224),  # 中心裁剪
        transforms.ToTensor(),  # 转为张量
        transforms.Normalize(  # 标准化
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)  # 增加批次维度
    return image


# 3. 提取图像特征
def extract_features(model, image):
    with torch.no_grad():
        features = model(image).squeeze()  # 提取特征
        features = features.flatten()  # 扁平化
    return features.numpy()


# 4. 计算相似度(使用余弦相似度)
def calculate_similarity(feature1, feature2):
    return cosine_similarity([feature1], [feature2])[0][0]


# 5. 加载数据库中的图像并提取特征
def build_image_database(image_dir, model):
    features_database = {}
    for image_name in os.listdir(image_dir):
        image_path = os.path.join(image_dir, image_name)
        if os.path.isfile(image_path):
            image = preprocess_image(image_path)
            feature = extract_features(model, image)
            features_database[image_name] = feature
    return features_database


# 6. 查询数据库,找出最匹配的图像
def find_most_similar_image(query_image_path, image_dir, model):
    query_image = preprocess_image(query_image_path)
    query_feature = extract_features(model, query_image)

    features_database = build_image_database(image_dir, model)

    best_match = None
    best_similarity = -1  # 初始时设定最差的相似度(-1)

    for image_name, feature in features_database.items():
        similarity = calculate_similarity(query_feature, feature)
        if similarity > best_similarity:
            best_similarity = similarity
            best_match = image_name

    return best_match, best_similarity


# 使用示例
if __name__ == '__main__':
    # 加载模型
    model = get_resnet_model()

    # 定义查询图像路径和库中图像的目录路径
    query_image_path = 'images/wk1.jpg'
    image_dir = 'images_database'

    # 查找最匹配的图像
    best_match, similarity = find_most_similar_image(query_image_path, image_dir, model)

    print(f"最匹配的图像是: {best_match},匹配度: {similarity:.4f}")

执行结果:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值