图片查重从设计到实现(7) :使用 Milvus 实现高效图片查重功能

使用 Milvus 实现高效图片查重功能

本文将介绍如何利用 Milvus 向量数据库构建一个高效的图片查重系统,通过传入图片就能快速从已有数据中找出匹配度高的相似图片。

一.什么是图片查重?

图片查重指的是通过算法识别出内容相同或高度相似的图片,即使它们可能存在尺寸、格式、轻微编辑等差异。传统的基于文件名或元数据的查重方法效果极差,而基于内容的图片查重则能真正识别视觉上相似的图片。

二. 技术原理

基于 Milvus 的图片查重系统主要依赖以下关键技术:

  1. 图片特征提取:使用深度学习模型将图片转换为固定维度的特征向量,捕捉图片的视觉特征
  2. 向量相似度搜索:通过计算向量之间的距离(相似度)来判断图片的相似程度
  3. 高效向量数据库:Milvus 提供的高性能向量索引和搜索能力,支持亿级数据的毫秒级检索

核心流程:

  • 预处理:将所有图片转换为特征向量并存储到 Milvus
  • 查重阶段:对输入图片提取特征向量,在 Milvus 中搜索相似度高于阈值的向量,找到对应的图片

三 .实现步骤

1. 核心代码实现

下面是完整的图片查重系统实现,包含特征提取和 Milvus 操作:

2. 代码解析

核心组件
  1. 图片特征提取器(ImageFeatureExtractor)
class ImageFeatureExtractor:
    """图片特征提取器,将图片转换为特征向量"""
    
    def __init__(self):
        # 使用预训练的ResNet50模型
        self.model = models.resnet50(pretrained=True)
        # 移除最后一层全连接层,保留特征提取部分
        self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
        self.model.eval()  # 切换到评估模式
        
        # 确保使用适当的设备(GPU如果可用)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        
        # 图片预处理流程,与模型训练时保持一致
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        
        logger.info(f"特征提取器初始化完成,使用设备: {self.device}")
    
    def extract(self, image_path):
        """提取单张图片的特征向量"""
        try:
            # 打开图片并转换为RGB格式
            img = Image.open(image_path).convert('RGB')
            # 预处理
            img_tensor = self.transform(img).unsqueeze(0)  # 添加批次维度
            img_tensor = img_tensor.to(self.device)
            
            # 提取特征
            with torch.no_grad():  # 不计算梯度,加速推理
                features = self.model(img_tensor)
            
            # 展平为一维向量并归一化
            feature_vector = features.squeeze().cpu().numpy()
            normalized_vector = feature_vector / np.linalg.norm(feature_vector)
            
            return normalized_vector
            
        except Exception as e:
            logger.error(f"提取图片 {image_path} 特征失败: {str(e)}")
            return None
    
    def batch_extract(self, image_paths):
        """批量提取图片特征"""
        features = []
        valid_paths = []
        
        for path in image_paths:
            feat = self.extract(path)
            if feat is not None:
                features.append(feat)
                valid_paths.append(path)
        
        return valid_paths, features
    
  • 使用预训练的 ResNet50 模型提取图片特征
  • 对图片进行标准化处理(Resize、裁剪、归一化)
  • 支持单张和批量图片特征提取
  • 自动选择 GPU/CPU 设备加速处理

class MilvusImageChecker:
    """基于Milvus的图片查重工具"""
    
    def __init__(self, host=Config.MILVUS_HOST, port=Config.MILVUS_PORT, 
                 collection_name=Config.COLLECTION_NAME):
        self.host = host
        self.port = port
        self.collection_name = collection_name
        self.collection = None
        
        # 连接Milvus并初始化集合
        self.connect()
        self.init_collection()
    
    def connect(self):
        """连接到Milvus服务器"""
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port
            )
            logger.info(f"成功连接到Milvus服务器: {self.host}:{self.port}")
        except Exception as e:
            logger.error(f"连接Milvus服务器失败: {str(e)}")
            raise
    
    def init_collection(self):
        """初始化集合,如不存在则创建"""
        if utility.has_collection(self.collection_name):
            self.collection = Collection(self.collection_name)
            logger.info(f"已加载集合: {self.collection_name}")
            return
        
        # 定义集合字段
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="image_path", dtype=DataType.VARCHAR, max_length=512),
            FieldSchema(name="feature_vector", dtype=DataType.FLOAT_VECTOR, dim=Config.VECTOR_DIM),
            FieldSchema(name="upload_time", dtype=DataType.INT64, 
                       description="图片上传时间戳,用于去重时保留最早/最新版本")
        ]
        
        # 创建集合schema
        schema = CollectionSchema(
            fields, 
            "用于图片查重的集合,存储图片路径和特征向量",
            enable_dynamic_field=False
        )
        
        # 创建集合
        self.collection = Collection(self.collection_name, schema)
        logger.info(f"成功创建集合: {self.collection_name}")
        
        # 创建索引,优化查询性能
        index_params = {
            "index_type": "HNSW",  # 适用于高维向量的高效索引
            "metric_type": "IP",   # 内积,适用于归一化向量的相似度计算
            "params": {
                "M": 16,           # HNSW参数,影响索引质量和查询速度
                "efConstruction": 200  # 构建索引时的参数
            }
        }
        
        self.collection.create_index(
            field_name="feature_vector",
            index_params=index_params
        )
        logger.info("集合索引创建完成")
    
    def insert_images(self, image_paths, upload_timestamps=None):
        """
        插入图片特征到Milvus
        
        参数:
            image_paths: 图片路径列表
            upload_timestamps: 上传时间戳列表,用于去重时判断保留哪个版本
        """
        if not image_paths:
            logger.warning("没有图片路径可插入")
            return []
        
        # 处理时间戳(默认使用当前时间)
        if upload_timestamps is None:
            current_ts = int(os.path.getmtime(image_paths[0])) if image_paths else 0
            upload_timestamps = [current_ts] * len(image_paths)
        
        # 批量提取特征
        extractor = ImageFeatureExtractor()
        valid_paths, features = extractor.batch_extract(image_paths)
        
        if not valid_paths:
            logger.warning("没有有效图片可插入")
            return []
        
        # 准备插入数据
        data = [
            valid_paths,  # image_path字段
            [f.tolist() for f in features],  # feature_vector字段
            upload_timestamps[:len(valid_paths)]  # upload_time字段
        ]
        
        # 执行插入
        try:
            insert_result = self.collection.insert(data)
            self.collection.flush()  # 刷新到磁盘
            logger.info(f"成功插入 {len(valid_paths)} 张图片,ID范围: {insert_result.primary_keys}")
            return insert_result.primary_keys
        except Exception as e:
            logger.error(f"插入图片失败: {str(e)}")
            return []
    
    def check_duplicates(self, image_path, threshold=Config.DEFAULT_THRESHOLD, 
                        top_k=Config.DEFAULT_TOP_K):
        """
        检查指定图片是否存在重复或高度相似的图片
        
        参数:
            image_path: 待检查的图片路径
            threshold: 相似度阈值,高于此值认为是相似图片
            top_k: 返回的最大相似图片数量
            
        返回:
            相似图片列表,按相似度降序排列
        """
        # 提取查询图片特征
        extractor = ImageFeatureExtractor()
        query_vector = extractor.extract(image_path)
        
        if query_vector is None:
            logger.error("无法提取查询图片特征,查重失败")
            return []
        
        # 加载集合到内存(如果尚未加载)
        if not self.collection.is_loaded:
            self.collection.load()
            logger.info(f"集合 {self.collection_name} 已加载到内存")
        
        # 配置搜索参数
        search_params = {
            "metric_type": "IP",
            "params": {"ef": 64}  # 搜索时的参数,影响查询精度和速度
        }
        
        # 执行相似度搜索
        try:
            results = self.collection.search(
                data=[query_vector.tolist()],
                anns_field="feature_vector",
                param=search_params,
                limit=top_k,
                output_fields=["image_path", "upload_time"]
            )
            
            # 处理搜索结果,过滤低于阈值的结果
            duplicates = []
            for hits in results:
                for hit in hits:
                    similarity = hit.distance
                    if similarity >= threshold:
                        duplicates.append({
                            "image_path": hit.entity.get("image_path"),
                            "similarity": float(similarity),
                            "milvus_id": hit.id,
                            "upload_time": hit.entity.get("upload_time")
                        })
            
            # 按相似度降序排序
            duplicates.sort(key=lambda x: x["similarity"], reverse=True)
            logger.info(f"找到 {len(duplicates)} 张相似图片 (阈值: {threshold})")
            return duplicates
            
        except Exception as e:
            logger.error(f"查重搜索失败: {str(e)}")
            return []
    
    def delete_duplicates(self, duplicate_ids):
        """删除指定ID的重复图片记录"""
        if not duplicate_ids:
            return True
            
        try:
            self.collection.delete(f"id in {duplicate_ids}")
            self.collection.flush()
            logger.info(f"成功删除 {len(duplicate_ids)} 条重复记录")
            return True
        except Exception as e:
            logger.error(f"删除重复记录失败: {str(e)}")
            return False

  1. Milvus 查重工具(MilvusImageChecker)
    • 负责与 Milvus 服务器的连接和交互
    • 初始化集合并创建高效索引(使用 HNSW 索引)
    • 提供图片特征插入、重复图片查询和删除功能
    • 支持批量操作,提高处理效率
关键技术点
  • 特征向量归一化:确保内积(IP)可以直接作为余弦相似度使用
  • 合适的索引选择:使用 HNSW 索引平衡查询速度和精度
  • 相似度阈值:可根据业务需求调整,值越高表示要求越相似
  • 时间戳管理:记录图片上传时间,便于去重时选择保留最早或最新版本

五 .重点

  1. 模型选择

    • 追求精度:可使用更复杂的模型如 ResNet101 或 Vision Transformer
    • 追求速度:可使用轻量级模型如 MobileNet 或 EfficientNet
  2. 索引优化

    • 对于大规模数据,可调整 HNSW 索引的 MefConstruction 参数
    • 可尝试不同索引类型(如 IVF_FLAT、IVF_SQ8)找到性能平衡点
  3. 阈值调整

    • 对于严格查重(完全相同的图片),可将阈值设为 0.98 以上
    • 对于相似图片检索,可将阈值设为 0.85-0.95 之间
  4. 分布式部署

    • 对于超大规模图片库,可使用 Milvus 集群提高吞吐量和可靠性

六 应用场景

  • 内容管理系统:自动检测并去重上传的图片
  • 电商平台:识别盗图和相似商品图片
  • 版权保护:追踪未经授权使用的图片
  • 相册管理:自动整理相似照片,减少冗余

总结

基于 Milvus 的图片查重系统能够高效处理海量图片数据,通过特征向量和相似度搜索技术,实现了精准的重复图片识别。相比传统方法,它具有以下优势:

  • 识别真正视觉相似的图片,不受文件名或格式影响
  • 支持亿级图片的快速检索,毫秒级响应
  • 可灵活调整相似度阈值,适应不同业务需求
  • 易于扩展和集成到现有系统中
### Milvus 数据存储的 Python 实现 以下是关于如何使用 Python 集成 Milvus实现数据存储的具体教程和示例代码。 #### 1. 安装依赖库 为了能够与 Milvus 进行交互,需要先安装 `pymilvus` 库。可以通过以下命令完成安装: ```bash pip install pymilvus==2.0.0 ``` 此版本号应根据实际需求调整,确保兼容所使用Milvus 版本[^3]。 #### 2. 启动 Milvus 服务 如果尚未启动 Milvus 服务,可以使用 Docker 来快速部署 Milvus。运行以下命令即可启动 Milvus 的 CPU 版本实例: ```bash docker run -d --name milvus_cpu -p 19530:19530 -p 19121:19121 milvusdb/milvus:v2.0.0-cpu-d030822-7e21de ``` 该命令会创建并运行一个名为 `milvus_cpu` 的容器,并将其端口映射到主机上[^3]。 #### 3. 使用 Python SDK 存储数据 下面是一个完整的 Python 脚本示例,展示如何连接到 Milvus、创建集合以及插入向量数据。 ```python from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection # 建立与 Milvus 的连接 connections.connect(alias="default", host="localhost", port="19530") # 定义字段模式 fields = [ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=128) ] # 创建集合方案 schema = CollectionSchema(fields, "Example collection for storing embeddings") # 创建集合 collection_name = "example_collection" collection = Collection(name=collection_name, schema=schema) # 插入数据 data = [ [i for i in range(10)], # id 列表 [[float(j) for j in range(128)] for _ in range(10)] # embedding 向量列表 ] mr = collection.insert(data) # 加载集合以便执行查询操作 index_params = { "metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 128} } collection.create_index(field_name="embedding", index_params=index_params) collection.load() print("Data inserted successfully!") ``` 以上脚本实现了以下几个功能: - **建立连接**:通过 `connections.connect()` 方法指定目标地址来连接本地或远程的 Milvus 实例。 - **定义字段**:设置两个字段——一个是整数类型的主键 (`id`);另一个是浮点型数组形式的嵌入向量 (`embedding`)。 - **构建集合**:利用给定的字段描述符生成一个新的集合对象。 - **写入记录**:调用 `insert()` 函数批量导入模拟数据集。 - **配置索引**:为提高检索性能而设定合适的参数组合[^1]。 #### 注意事项 在实际开发过程中需要注意以下几点: - 确认所选维度大小 (dim 参数值) 符合业务逻辑中的特征表示长度; - 对于大规模生产环境而言,建议采用 GPU 支持版本提升处理能力; - 不同类型的应用场景可能涉及更多复杂的预处理流程,比如降维算法的选择等[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值