背景
在处理大规模图像数据时,图像去重是一个非常重要的环节,尤其是在爬取网络数据、构建数据集或清理数据时。本文将介绍一种基于特征提取和相似度计算的图像去重方法,重点讲解如何使用特征向量和余弦相似度实现高效去重。
一、方法简介
基于特征的图像去重主要分为两步:
- 特征提取:通过预训练模型(如 ResNet50)提取每张图像的特征向量。
- 相似度计算与去重:计算图像特征之间的相似度,根据设定的阈值,保留唯一的图像。
本文的核心方法是 deduplicate_by_similarity,它通过余弦相似度检测特征向量之间的相似性,逐步排除重复图像。
二、特征提取
在去重之前,我们需要为每张图像提取一个固定长度的特征向量。这里使用了深度学习中常见的 ResNet50 预训练模型(去掉分类层),将输入图像转换为一个固定长度的特征向量。
特征提取的实现
以下是特征提取的核心代码:
def _extract_feature(self, image_path):
try:
image = Image.open(image_path).convert('RGB') # 加载图像并转换为 RGB 模式
image_tensor = self.transform(image).unsqueeze(0).to(self.device) # 图像预处理
with torch.no_grad(): # 关闭梯度计算
return self.model(image_tensor).squeeze().cpu().numpy() # 提取特征并转为 NumPy
except Exception as e:
print(f"Error processing {image_path}: {e}")
return None
特征提取流程说明
- 加载图像:使用 PIL 加载图像并转换为 RGB 模式。
- 预处理:将图像调整为模型输入尺寸(224×224),归一化后转换为张量。
- 特征提取:使用预训练模型提取特征,并移除分类层,得到一个高维特征向量。
提取完成后,所有图像的特征向量会存储为一个矩阵,形状为 (N, D),其中 N 是图像数量,D 是特征维度(ResNet50 的默认输出为 2048 维)。
三、基于相似度的去重方法
特征提取完成后,我们可以通过计算特征向量之间的相似度,来判断哪些图像是重复的。这里采用了 余弦相似度 方法,因为它对特征的尺度变化不敏感,计算简单且效果良好。
deduplicate_by_similarity 方法实现
以下是该方法的完整代码:
def deduplicate_by_similarity(self, features, threshold=0.95, batch_size=1000):
"""
基于相似度进行去重
:param features: 特征矩阵 (N x D)
:param threshold: 相似度阈值
:param batch_size: 批处理大小
:return: 保留的图像索引列表
"""
num_images = len(features)
to_keep = []
excluded = np.zeros(num_images, dtype=bool) # 标记已排除的图像
features /= np.linalg.norm(features, axis=1, keepdims=True) # 标准化特征向量
for i in tqdm(range(num_images), desc="Processing images", unit="image"):
if excluded[i]: # 跳过已标记为重复的图像
continue
to_keep.append(i) # 保留当前图像
# 找出未排除的图像索引
remaining_indices = np.where(~excluded[i + 1:])[0] + i + 1
if not len(remaining_indices): # 如果没有剩余图像,则跳过
continue
# 分批计算余弦相似度
for j in range(0, len(remaining_indices), batch_size):
batch_end = min(j + batch_size, len(remaining_indices))
batch_features = features[remaining_indices[j:batch_end]] # 批量特征
batch_similarities = np.dot(features[i:i+1], batch_features.T).squeeze() # 计算相似度
excluded[remaining_indices[j:batch_end]] |= batch_similarities > threshold # 标记相似度超过阈值的图像
return to_keep
方法详解
1. 特征标准化:
features /= np.linalg.norm(features, axis=1, keepdims=True)
通过标准化,将每个特征向量的长度归一化为 1,从而方便计算余弦相似度。
2. 逐步去重:
- 遍历每张图像的特征向量。
- 如果当前图像未被标记为重复,则将其保留,并计算它与剩余图像的相似度。
3. 批量计算相似度:
为了降低内存消耗,相似度计算分批进行:
batch_similarities = np.dot(features[i:i+1], batch_features.T).squeeze()
利用矩阵乘法快速计算当前图像与某一批次图像间的余弦相似度。
4. 去重逻辑:
如果相似度超过阈值(如 0.95),则将对应图像标记为重复:
excluded[remaining_indices[j:batch_end]] |= batch_similarities > threshold
5. 结果输出:
最终返回未被标记为重复的图像索引:
return to_keep
四、完整代码
如果你对完整代码感兴趣,可以参考以下实现:
import os
from concurrent.futures import ThreadPoolExecutor
from torchvision import models, transforms
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm
class ImageDeduplicator:
def __init__(self, model_name='resnet50', num_threads=8, gpu_id=1):
"""
初始化去重器
:param model_name: 使用的预训练模型名称(支持 resnet50 等 torchvision 模型)
:param num_threads: 多线程的线程数
:param gpu_id: 指定使用的 GPU ID,默认为 1。如果设置为 -1,则使用 CPU。
"""
self.device = self._set_device(gpu_id)
self.model = self._load_model(model_name).to(self.device).eval()
self.num_threads = num_threads
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def _set_device(self, gpu_id):
"""设置设备(GPU 或 CPU)"""
if gpu_id == -1 or not torch.cuda.is_available():
print("Using CPU for computation.")
return torch.device("cpu")
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
print(f"Using GPU: {gpu_id}")
return torch.device("cuda:0")
def _load_model(self, model_name):
"""加载预训练模型,并移除分类层"""
if model_name != 'resnet50':
raise ValueError(f"Unsupported model: {model_name}")
from torchvision.models import ResNet50_Weights
weights = ResNet50_Weights.DEFAULT
model = models.resnet50(weights=weights)
return torch.nn.Sequential(*list(model.children())[:-1])
def _extract_feature(self, image_path):
"""提取单张图像的特征向量"""
try:
image = Image.open(image_path).convert('RGB') # 加载图像并转换为 RGB 模式
image_tensor = self.transform(image).unsqueeze(0).to(self.device) # 图像预处理
with torch.no_grad(): # 关闭梯度计算
return self.model(image_tensor).squeeze().cpu().numpy() # 提取特征并转为 NumPy
except Exception as e:
print(f"Error processing {image_path}: {e}")
return None
def extract_features(self, image_paths):
"""使用多线程提取图像特征"""
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
results = list(tqdm(executor.map(self._extract_feature, image_paths), total=len(image_paths)))
return np.array([res for res in results if res is not None])
def deduplicate_by_similarity(self, features, threshold=0.95, batch_size=1000):
"""
基于相似度进行去重
:param features: 特征矩阵 (N x D)
:param threshold: 相似度阈值
:param batch_size: 批处理大小
:return: 保留的图像索引列表
"""
num_images = len(features)
to_keep = []
excluded = np.zeros(num_images, dtype=bool) # 标记已排除的图像
features /= np.linalg.norm(features, axis=1, keepdims=True) # 标准化特征向量
for i in tqdm(range(num_images), desc="Processing images", unit="image"):
if excluded[i]: # 跳过已标记为重复的图像
continue
to_keep.append(i) # 保留当前图像
# 找出未排除的图像索引
remaining_indices = np.where(~excluded[i + 1:])[0] + i + 1
if not len(remaining_indices): # 如果没有剩余图像,则跳过
continue
# 分批计算余弦相似度
for j in range(0, len(remaining_indices), batch_size):
batch_end = min(j + batch_size, len(remaining_indices))
batch_features = features[remaining_indices[j:batch_end]] # 批量特征
batch_similarities = np.dot(features[i:i+1], batch_features.T).squeeze() # 计算相似度
excluded[remaining_indices[j:batch_end]] |= batch_similarities > threshold # 标记相似度超过阈值的图像
return to_keep
def deduplicate(self, image_paths, threshold=0.95, batch_size=1000):
"""
对图像路径列表去重
:param image_paths: 图像路径列表
:param threshold: 相似度阈值
:param batch_size: 批处理大小
:return: 去重后的图像路径列表
"""
print("Extracting features...")
features = self.extract_features(image_paths)
print("Performing deduplication...")
indices_to_keep = self.deduplicate_by_similarity(features, threshold, batch_size)
return [image_paths[i] for i in indices_to_keep]
五、总结
完整代码提供了从特征提取到去重的完整实现,用户只需提供一组图像路径,即可高效完成去重任务。这种方法具有以下特点:
- 高效性:支持多线程提取特征,批量计算相似度。
- 灵活性:支持多种预训练模型,阈值参数可调。
- 易用性:代码封装良好,调用简单。
你可以根据自己的场景需求调整代码参数,例如更换预训练模型、调整相似度阈值等。
如果你在使用过程中有任何问题,欢迎在评论区留言交流!