传统的特征匹配算法:
使用 OpenCV 自带的 matchTemplate 方法进行识别时,发现它对形变、旋转的适应性不是很好;后来尝试利用 ORB 特征、SIFT 特征进行匹配,但由于不同车辆的很多特征非常相似,也不能很好地区分。以 SIFT 特征匹配为例,代码与效果如下:
代码:
import shutil
import cv2
import numpy as np
import os
def calculate_match_score(img1, img2):
    """Return a SIFT-based match score for two images.

    Keypoints and descriptors are computed for both images and matched
    with a brute-force L2 matcher using cross-checking. The score is the
    number of matches divided by the total number of keypoints in both
    images, scaled by 1000.

    Args:
        img1: First image, passed straight to ``sift.detectAndCompute``.
        img2: Second image, passed straight to ``sift.detectAndCompute``.

    Returns:
        The scaled match ratio, or 0 when either image yields no
        descriptors.
    """
    detector = cv2.SIFT_create()
    kp_a, desc_a = detector.detectAndCompute(img1, None)
    kp_b, desc_b = detector.detectAndCompute(img2, None)

    # Without descriptors on both sides there is nothing to match.
    if desc_a is None or desc_b is None:
        return 0

    matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matched = matcher.match(desc_a, desc_b)

    # Ratio of matched pairs to all detected keypoints.
    keypoint_total = len(kp_a) + len(kp_b)
    ratio = len(matched) / keypoint_total if keypoint_total > 0 else 0
    return ratio * 1000
def template_match_folder(template_img, folder, template_name=None):
    """Copy every image in ``folder`` that matches the template.

    Each readable image whose area is at least 50 * 65 pixels is scored
    against the template with ``calculate_match_score``. Images scoring
    above 200 are copied into a per-template sub-folder of ``G:\\ss`` and
    recorded in the returned mapping.

    Args:
        template_img: Either a decoded template image (as returned by
            ``cv2.imread``) or the path of the template file.
        folder: Directory containing the candidate images.
        template_name: File name of the template. The part before the
            first ``_`` names the output sub-folder. Optional when
            ``template_img`` is a path (the base name is used); when it is
            omitted for a decoded image the sub-folder is ``unknown``.

    Returns:
        Dict mapping matched file names to their match scores. Empty when
        the template path cannot be read.
    """
    # The original code called os.path.basename() on the decoded image,
    # which raises TypeError; the name now comes from the path or from the
    # explicit ``template_name`` argument.
    if isinstance(template_img, str):
        template_path = template_img
        if template_name is None:
            template_name = os.path.basename(template_path)
        template_img = cv2.imread(template_path)
        if template_img is None:
            print(f"无法读取模板图像 {template_path}")
            return {}

    folder_name = template_name.split("_")[0] if template_name else "unknown"
    # os.path.join("G:", "ss") yields the drive-relative path "G:ss" on
    # Windows; anchor the output directory at the drive root instead.
    save_folder = os.path.join(r"G:\ss", folder_name)
    os.makedirs(save_folder, exist_ok=True)

    all_img_list = {}
    for des_img_name in os.listdir(folder):
        des_img_path = os.path.join(folder, des_img_name)
        des_img = cv2.imread(des_img_path)
        if des_img is None:
            print(f"无法读取图像 {des_img_path}")
            continue

        # Skip crops smaller than 50 * 65 pixels.
        height, width = des_img.shape[:2]
        if height * width < 50 * 65:
            continue

        match_score = calculate_match_score(template_img, des_img)
        if match_score > 200:
            all_img_list[des_img_name] = match_score
            shutil.copy(des_img_path, os.path.join(save_folder, des_img_name))
    return all_img_list


def template_folder_match_des_folder(template_folder, folder):
    """Match every template in ``template_folder`` against ``folder``.

    For each readable template the matches found by
    ``template_match_folder`` are appended as one line to ``1.txt`` in the
    current working directory. Unreadable templates are reported and
    skipped.

    Args:
        template_folder: Directory containing the template images.
        folder: Directory containing the candidate images.
    """
    for template_name in os.listdir(template_folder):
        template_path = os.path.join(template_folder, template_name)
        template_img = cv2.imread(template_path)
        if template_img is None:
            print(f"无法读取模板图像 {template_path}")
            continue

        # Pass the file name along so the output sub-folder can be derived.
        all_img_list = template_match_folder(
            template_img, folder, template_name=template_name
        )
        with open("1.txt", "a", encoding="utf-8") as f:
            f.write(str(all_img_list))
            f.write("\n")
# Script entry point: run every template against the cropped target images.
template_folder = r"G:\dataset\M3FD\M3FD_Detection\templates"
folder = r"G:\dataset\M3FD\M3FD_Detection\cut_imgs"
template_folder_match_des_folder(template_folder=template_folder, folder=folder)
效果:
模版图像:
算法匹配结果:
模版图像:
算法匹配结果:
深度学习匹配算法:
通过 ResNet(此处为 ResNet50)提取图像特征,计算余弦相似度;再将图像映射至 HSV 和 Lab 颜色空间计算颜色相似度,二者共同评估模版与目标的相似度。
代码:
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import cv2
import shutil
import os
import concurrent.futures
from tqdm import tqdm
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Load a pretrained ResNet-50, place it on the selected device and switch
# it to evaluation mode.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(device)
model.eval()

# Preprocessing pipeline applied to every image before it enters the model:
# resize, centre-crop to 224, convert to a tensor, normalise per channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def preprocess_image(image):
    """Convert an image into a batched model-input tensor on ``device``.

    Args:
        image: A file path, a ``PIL.Image.Image`` or a NumPy array. NumPy
            arrays are treated as OpenCV images (BGR / BGRA channel order),
            which is how ``get_color_features`` interprets them too and
            what ``cv2.imread`` produces.

    Returns:
        The preprocessed tensor with a leading batch dimension, moved to
        ``device``.

    Raises:
        TypeError: If ``image`` is none of the supported types.
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        # Image.fromarray assumes RGB(A) order, so reorder OpenCV's BGR(A)
        # channels first; otherwise red and blue reach the network swapped.
        # NOTE(review): this changes the ResNet similarity values compared
        # with the channel-swapped input used before; re-check thresholds.
        if image.ndim == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
        image = Image.fromarray(image)

    if not isinstance(image, Image.Image):
        raise TypeError("Unsupported image type: {}".format(type(image)))

    # Normalize() is configured for three channels, so force RGB for
    # grayscale / RGBA / palette inputs as well.
    tensor = preprocess(image.convert('RGB'))
    # Add the batch dimension and move the tensor to the model's device.
    return tensor.unsqueeze(0).to(device)
def get_features(image):
    """Run ``model`` on an image and return its output as a flat array.

    Args:
        image: Anything accepted by ``preprocess_image``.

    Returns:
        A 1-D NumPy array holding the flattened model output.
    """
    batch = preprocess_image(image)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(batch)
    # Bring the result back to host memory before flattening it.
    host_array = output.cpu().numpy()
    return host_array.flatten()
def get_color_features(image):
    """Build a standardised colour descriptor from HSV and Lab statistics.

    The descriptor concatenates six normalised 256-bin histograms (H, S, V,
    L, a, b) with the per-channel mean and standard deviation of the HSV
    and Lab images, then standardises the whole vector to zero mean and
    unit variance.

    Args:
        image: A file path, a ``PIL.Image.Image`` or a NumPy array. NumPy
            arrays are treated as OpenCV BGR images. The histogram ranges
            assume 8-bit data — TODO confirm for other dtypes.

    Returns:
        A 1-D NumPy array of length 6 * 256 + 12.

    Raises:
        TypeError: If ``image`` is none of the supported types.
    """
    if isinstance(image, str):
        image = Image.open(image).convert('RGB')
    if isinstance(image, np.ndarray):
        # Same channel swap as before, spelled with the constant that
        # names the actual direction (OpenCV BGR -> RGB).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif isinstance(image, Image.Image):
        image = np.array(image.convert('RGB'))
    else:
        raise TypeError("Unsupported image type: {}".format(type(image)))

    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)

    # For 8-bit images OpenCV stores the a/b channels offset by +128, i.e.
    # in 0..255. The previous [-128, 128] range therefore discarded every
    # pixel whose stored value was >= 128; all six channels now use the
    # full [0, 256) range.
    histograms = []
    for converted in (hsv_image, lab_image):
        for channel in range(3):
            hist = cv2.calcHist(
                [converted], [channel], None, [256], [0, 256]
            ).flatten()
            total = hist.sum()
            if total > 0:
                hist /= total  # normalise to a probability distribution
            histograms.append(hist)

    # Colour moments: per-channel mean and standard deviation.
    moments = [
        np.mean(hsv_image, axis=(0, 1)),
        np.std(hsv_image, axis=(0, 1)),
        np.mean(lab_image, axis=(0, 1)),
        np.std(lab_image, axis=(0, 1)),
    ]

    color_features = np.concatenate(histograms + moments)
    # Standardise; the epsilon guards against a zero standard deviation.
    color_features = (color_features - np.mean(color_features)) / (
        np.std(color_features) + 1e-6
    )
    return color_features
def compare_images(image1, image2):
    """Return the model-feature and colour-feature similarity of two images.

    Args:
        image1: First image, in any form accepted by ``get_features`` and
            ``get_color_features``.
        image2: Second image, same accepted forms.

    Returns:
        A tuple ``(model_similarity, color_similarity)`` of cosine
        similarities.
    """
    # Features from the network.
    deep_a = get_features(image1)
    deep_b = get_features(image2)
    # Colour descriptors.
    color_a = get_color_features(image1)
    color_b = get_color_features(image2)

    # cosine_similarity works on 2-D inputs, so wrap each vector in a list
    # and pick the single entry of the resulting 1x1 matrix.
    deep_similarity = cosine_similarity([deep_a], [deep_b])[0][0]
    color_similarity = cosine_similarity([color_a], [color_b])[0][0]
    return deep_similarity, color_similarity
def calculate_match_score(img1, img2):
    """Return a SIFT-based match score for two images.

    The score is the number of cross-checked brute-force matches divided
    by the total number of keypoints in both images, scaled by 1000. Pairs
    with 100 or fewer keypoints in total score 0.

    Args:
        img1: First image, passed to ``sift.detectAndCompute``.
        img2: Second image, passed to ``sift.detectAndCompute``.

    Returns:
        The scaled match ratio, or 0 when either image yields no
        descriptors or the pair has too few keypoints.
    """
    sift = cv2.SIFT_create()
    keypoints1, descriptors1 = sift.detectAndCompute(img1, None)
    keypoints2, descriptors2 = sift.detectAndCompute(img2, None)

    # detectAndCompute returns None descriptors for images without
    # keypoints; the matcher cannot handle that, so bail out early.
    if descriptors1 is None or descriptors2 is None:
        return 0

    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = bf.match(descriptors1, descriptors2)

    total_points = len(keypoints1) + len(keypoints2)
    # Only score pairs with more than 100 keypoints in total.
    if total_points > 100:
        match_score = len(matches) / total_points
    else:
        match_score = 0
    return match_score * 1000
def process_image_pair(template_img_path, des_img_path, save_folder):
    """Compare one candidate image with the template and copy it on a match.

    A candidate is copied into ``save_folder`` when its model similarity
    exceeds 0.8 and its colour similarity exceeds 0.998.

    Args:
        template_img_path: Path of the template image.
        des_img_path: Path of the candidate image.
        save_folder: Directory that receives matching candidates.

    Returns:
        ``{file_name: model_similarity}`` for a match, otherwise ``None``
        (unreadable candidate, candidate smaller than 50 * 65 pixels, or
        similarity below the thresholds).
    """
    des_img = cv2.imread(des_img_path)
    # cv2.imread returns None for files it cannot decode; skip them so one
    # stray file does not abort the whole run.
    if des_img is None:
        print(f"无法读取图像 {des_img_path}")
        return None

    # Cheap size filter first, before decoding the template.
    height, width = des_img.shape[:2]
    if height * width < 50 * 65:
        return None

    template_img = cv2.imread(template_img_path)
    similarity_reset_score, similarity_color_score = compare_images(
        template_img, des_img
    )
    if similarity_reset_score > 0.8 and similarity_color_score > 0.998:
        des_img_name = os.path.basename(des_img_path)
        shutil.copy(des_img_path, os.path.join(save_folder, des_img_name))
        # Plain float so the result prints cleanly when written to the log.
        return {des_img_name: float(similarity_reset_score)}
    return None
def template_match_folder(template_path, folder, max_workers=8):
    """Match one template against every file in ``folder`` using threads.

    Matching candidates are copied by ``process_image_pair`` into a
    sub-folder of ``G:\\fff`` named after the part of the template file
    name before the first ``_``.

    Args:
        template_path: Path of the template image.
        folder: Directory containing the candidate images.
        max_workers: Size of the thread pool.

    Returns:
        Dict merging the results of all matching candidates. Empty when
        the template cannot be read.
    """
    all_img_list = {}

    # Validate the template once up front instead of failing inside a
    # worker thread (the decoded image itself is not needed here).
    if cv2.imread(template_path) is None:
        print(f"无法读取模板图像 {template_path}")
        return all_img_list

    save_folder = os.path.join("G:\\fff", os.path.basename(template_path).split("_")[0])
    os.makedirs(save_folder, exist_ok=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                process_image_pair,
                template_path,
                os.path.join(folder, des_img_name),
                save_folder,
            )
            for des_img_name in os.listdir(folder)
        ]
        # Collect results as they finish; non-matches come back as None.
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            if result:
                all_img_list.update(result)
    return all_img_list
def template_folder_match_des_folder(template_folder, folder, max_workers=8):
    """Run every template in ``template_folder`` against ``folder``.

    The matches of each template are appended as one line to ``3.txt`` in
    the current working directory; progress is shown with ``tqdm``.

    Args:
        template_folder: Directory containing the template images.
        folder: Directory containing the candidate images.
        max_workers: Thread-pool size forwarded to ``template_match_folder``.
    """
    for entry in tqdm(os.listdir(template_folder)):
        matches = template_match_folder(
            os.path.join(template_folder, entry), folder, max_workers
        )
        # Append after each template so partial results survive a crash.
        with open("3.txt", "a", encoding="utf-8") as log_file:
            log_file.write(str(matches) + "\n")
# Example paths; adjust them to the local dataset layout.
template_folder = r"G:\dataset\M3FD\M3FD_Detection\templates"
folder = r"G:\dataset\M3FD\M3FD_Detection\cut_imgs"
# max_workers controls how many image pairs are processed in parallel.
template_folder_match_des_folder(
    template_folder=template_folder,
    folder=folder,
    max_workers=4,
)
效果:
汽车所有模版图
所有的汽车图
算法得到的结果图:
效果展示:
存在部分分类错误的情况:
优化建议:
黑车模版的匹配结果中存在混入白车的情况,可以从颜色特征方面进一步优化算法:
数据采用的是 M3FD 数据集中的车辆类别数据。