Dify混合检索源码解析

混合检索源码解析

在这里插入图片描述

RetrievalService.retrieve

代码目录:api/core/rag/datasource/retrieval_service.py

该函数用于根据指定的检索方法执行检索操作

def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
             top_k: int, score_threshold: Optional[float] = .0,
             reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
             weights: Optional[dict] = None):
    _"""_
_    根据指定的检索方法执行检索操作。_

_    :param retrival_method: 检索方法类型。 semantic_search full_text_search keyword_search,hybrid_search_
_    :param dataset_id: 数据集ID。_
_    :param query: 检索查询字符串。_
_    :param top_k: 返回结果的最大数量。_
_    :param score_threshold: 分数阈值,用于过滤结果。_
_    :param reranking_model: 重排序模型配置,用于对结果进行二次排序。_
_    :return: 检索结果列表。_
_    """_
_    _dataset = db.session.query(Dataset).filter(
        Dataset.id == dataset_id
    ).first()
    if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
        return []
    all_documents = [] # 存储所有检索结果
    keyword_search_documents = []
    embedding_search_documents = []
    full_text_search_documents = []
    hybrid_search_documents = []
    threads = [] # 存储执行检索的线程
    exceptions = []  # 存储执行过程中遇到的异常
    # retrieval_model source with keyword
    # 关键词检索
    if retrival_method == 'keyword_search':
        keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
            'flask_app': current_app._get_current_object(),
            'dataset_id': dataset_id,
            'query': query,
            'top_k': top_k,
            'all_documents': all_documents,
            'exceptions': exceptions,
        })
        threads.append(keyword_thread)
        keyword_thread.start()
    # 向量检索(混合检索中也会调用)

    if RetrievalMethod.is_support_semantic_search(retrival_method):
        embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
            'flask_app': current_app._get_current_object(),
            'dataset_id': dataset_id,
            'query': query,
            'top_k': top_k,
            'score_threshold': score_threshold,
            'reranking_model': reranking_model,
            'all_documents': all_documents,
            'retrival_method': retrival_method,
            'exceptions': exceptions,
        })
        threads.append(embedding_thread)
        embedding_thread.start()

    # 文本检索(混合检索中也会调用)
    if RetrievalMethod.is_support_fulltext_search(retrival_method):
        full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
            'flask_app': current_app._get_current_object(),
            'dataset_id': dataset_id,
            'query': query,
            'retrival_method': retrival_method,
            'score_threshold': score_threshold,
            'top_k': top_k,
            'reranking_model': reranking_model,
            'all_documents': all_documents,
            'exceptions': exceptions,
        })
        threads.append(full_text_index_thread)
        full_text_index_thread.start()
    # 等待所有线程完成
    for thread in threads:
        thread.join()
    # 如果执行过程中有异常,则合并异常信息并抛出
    if exceptions:
        exception_message = ';\n'.join(exceptions)
        raise Exception(exception_message)
    # 混合检索之后会执行向量和文本检索结果合并后的重排序
    if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
        data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
                                                reranking_model, weights, False)
        all_documents = data_post_processor.invoke(
            query=query,
            documents=all_documents,
            score_threshold=score_threshold,
            top_n=top_k
        )
        # print("混合搜索",all_documents)
    return all_documents

假设本次选择的检索方法是 高质量-混合检索-权重则代码分别会开启两个线程走到 RetrievalMethod.is_support_semantic_search(retrival_method)

RetrievalMethod.is_support_fulltext_search(retrival_method)

执行向量以及全文检索查询,之后会走到 DataPostProcessor 中进行混合检索,然后根据是否开启 Rerank 模型进行重排序

DataPostProcessor

代码目录:api/core/rag/data_post_processor/data_post_processor.py

DataPostProcessor 定义了一个数据后处理接口,用于对搜索结果进行重排序和重新排序。构造函数初始化了重排序和重新排序运行器实例,而 invoke 方法则根据这些实例执行相应的操作,并返回处理后的文档列表。

class DataPostProcessor:
    """接口用于数据后处理文档。
    """

    def __init__(self, tenant_id: str, reranking_mode: str,
                 reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
                 reorder_enabled: bool = False):
        _# 初始化 DataPostProcessor 实例_
        _# 创建重排序运行器实例_
        self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
        _# 创建重新排序运行器实例_
        self.reorder_runner = self._get_reorder_runner(reorder_enabled)
_get_rerank_runner

根据提供的配置来创建一个重排序(reranking)运行器实例,调整搜索结果的顺序以提高相关性。

####### 基于权重的重排序

  1. 如果 reranking_modeWEIGHTED_SCORE 且提供了 weights,则创建一个 WeightRerankRunner 实例,并使用给定的 weights 配置。
  2. weights 字典包含了 vector_settingkeyword_setting,这些设置分别用于向量和关键词的权重配置。

####### 基于模型的重排序

  1. 如果 reranking_modeRERANKING_MODEL 且提供了 reranking_model,则尝试获取对应的模型实例。
  2. 获取模型实例的过程可能抛出 InvokeAuthorizationError 异常,此时返回 None
  3. 成功获取模型实例后,创建并返回一个 RerankModelRunner 实例。

####### 详细代码

def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
                       weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
    # 根据提供的配置来创建一个重排序(reranking)运行器实例。重排序通常用于调整搜索结果的顺序以提高相关性。

    # 判断是否使用基于权重的重排序模式
    if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
        # 创建 WeightRerankRunner 实例
        return WeightRerankRunner(
            tenant_id,  # 租户ID
            Weights( # 创建 Weights 对象
                vector_setting=VectorSetting( # 创建 VectorSetting 对象
                    vector_weight=weights['vector_setting']['vector_weight'],  # 向量权重
                    embedding_provider_name=weights['vector_setting']['embedding_provider_name'], # 嵌入式提供商名称
                    embedding_model_name=weights['vector_setting']['embedding_model_name'],  # 嵌入式模型名称
                ),
                keyword_setting=KeywordSetting( # 创建 KeywordSetting 对象
                    keyword_weight=weights['keyword_setting']['keyword_weight'], # 关键词权重
                )
            )
        )
    # 判断是否使用基于模型的重排序模式
    elif reranking_mode == RerankMode.RERANKING_MODEL.value:
        if reranking_model:
            try:
                # 创建 ModelManager 实例
                model_manager = ModelManager()
                # 获取模型实例
                rerank_model_instance = model_manager.get_model_instance(
                    tenant_id=tenant_id, # 租户ID
                    provider=reranking_model['reranking_provider_name'], # 模型提供商名称
                    model_type=ModelType.RERANK,  # 模型类型
                    model=reranking_model['reranking_model_name'] # 模型名称
                )
            except InvokeAuthorizationError:
                return None
            return RerankModelRunner(rerank_model_instance)
        return None
    # 如果既不是基于权重也不是基于模型的重排序模式,则返回 None
    return None
invoke

根据返回的实例 WeightRerankRunner/ModelManager 调用对应的 run 方法 在构造函数初始化了重排序和重新排序运行器实例

def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
           top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
    #根据返回的实例WeightRerankRunner/ModelManager调用对应的run方法
    if self.rerank_runner:
        _# 执行重排序_
        documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
        
    if self.reorder_runner:
        _# 执行重新排序_
        documents = self.reorder_runner.run(documents)

    return documents
WeightRerankRunner

代码目录:api/core/rag/rerank/weight_rerank.py

基于权重的重排序功能,用于对搜索结果中的文档进行重新排序,以提高相关性。

run

执行重排序逻辑,包括去除重复文档、计算关键词得分、计算向量得分、合并得分、应用分数阈值、更新文档元数据中的得分、排序文档以及返回指定数量的文档。

def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
        top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
    _"""_
_   执行基于权重的重排序_

_    :param query:  搜索查询_
_    :param documents:  文档列表_
_    :param score_threshold:  分数阈值_
_    :param top_n: 返回的文档数量上限_
_    :param user: unique user id if needed_

_    :return:  重排序后的文档列表_
_    """_
_    _docs = [] # 用于存储文档内容
    doc_id = []  # 用于存储文档ID
    unique_documents = []  # 用于存储去重后的文档
    for document in documents:
        if document.metadata['doc_id'] not in doc_id:
            doc_id.append(document.metadata['doc_id'])  # 添加文档ID
            docs.append(document.page_content) # 添加文档内容
            unique_documents.append(document) # 添加文档

    documents = unique_documents  # 更新文档列表为去重后的列表

    rerank_documents = []   # 用于存储重排序后的文档
    # 计算关键词
    query_scores = self._calculate_keyword_score(query, documents)
    # 计算向量
    query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
    # 合并
    for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
        # 计算文档的最终
        score = self.weights.vector_setting.vector_weight * query_vector_score + \
                self.weights.keyword_setting.keyword_weight * query_score
        # 应用分数阈值
        if score_threshold and score < score_threshold:
            continue
        # 更新文档元数据中的得分
        document.metadata['score'] = score
        # 添加文档到重排序列表
        rerank_documents.append(document)
    # 排序文档
    rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
    # 返回指定数量的文档
    return rerank_documents[:top_n] if top_n else rerank_documents

####### _calculate_keyword_score

根据关键词计算文档与查询之间的余弦相似度得分。

def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
    _"""_
_      计算文档与查询之间的余弦相似度得分。_
_    :param query: search query_
_    :param documents: documents for reranking_

_    :return:   每个文档的BM25得分列表_
_    """_
_    _keyword_table_handler = JiebaKeywordTableHandler() # 创建关键词处理实例
    query_keywords = keyword_table_handler.extract_keywords(query, None) # 提取查询关键词
    documents_keywords = [] # 用于存储文档关键词
    for document in documents:
        # # 提取文档关键词
        document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
        document.metadata['keywords'] = document_keywords  # 将关键词存入文档元数据
        documents_keywords.append(document_keywords)  # 添加文档关键词到列表

    # 统计查询关键词的词频(TF)
    query_keyword_counts = Counter(query_keywords)
    # 总文档数
    total_documents = len(documents)

    #  计算所有文档关键词的逆文档频率(IDF)
    all_keywords = set()  # 用于存储所有文档中的关键词
    for document_keywords in documents_keywords:
        all_keywords.update(document_keywords) # 更新所有关键词集合

    keyword_idf = {}   # 用于存储关键词的IDF值
    for keyword in all_keywords:
        # 计算包含特定关键词的文档数量
        doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
        # 计算IDF
        keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1

    query_tfidf = {} # 用于存储查询关键词的TF-IDF值

    for keyword, count in query_keyword_counts.items():
        tf = count  # 查询关键词的词频
        idf = keyword_idf.get(keyword, 0)  # 查询关键词的IDF
        query_tfidf[keyword] = tf * idf # 计算查询关键词的TF-IDF值

    # 计算所有文档的TF-IDF值
    documents_tfidf = []
    for document_keywords in documents_keywords:
        document_keyword_counts = Counter(document_keywords) # 统计文档关键词的词频
        document_tfidf = {} # 用于存储文档关键词的TF-IDF值
        for keyword, count in document_keyword_counts.items():
            tf = count # 文档关键词的词频
            idf = keyword_idf.get(keyword, 0) # 文档关键词的IDF
            document_tfidf[keyword] = tf * idf  # 计算文档关键词的TF-IDF值
        documents_tfidf.append(document_tfidf) # 添加文档TF-IDF值到列表

    # 定义余弦相似度计算函数
    def cosine_similarity(vec1, vec2):
        # 计算两个向量的交集
        intersection = set(vec1.keys()) & set(vec2.keys())
        # 计算分子
        numerator = sum(vec1[x] * vec2[x] for x in intersection)
        # 计算分母
        sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
        sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
        denominator = math.sqrt(sum1) * math.sqrt(sum2)
        # 避免除以零的情况
        if not denominator:
            return 0.0
        else:
            return float(numerator) / denominator

    # 初始化用于存储文档相似度得分的列表
    similarities = []
    # 计算每个文档与查询之间的相似度得分
    for document_tfidf in documents_tfidf:
        similarity = cosine_similarity(query_tfidf, document_tfidf)
        similarities.append(similarity)

    # for idx, similarity in enumerate(similarities):
    #     print(f"Document {idx + 1} similarity: {similarity}")
    # 返回相似度得分列表
    return similarities

####### _calculate_cosine

根据向量计算_查询与文档之间的余弦相似度得分。_

def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
                      vector_setting: VectorSetting) -> list[float]:
    _"""_
_    计算查询与文档之间的余弦相似度得分。_

_    :param tenant_id: 租户ID,用于获取特定租户的模型实例_
_    :param query: 搜索查询字符串_
_    :param documents: 包含多个文档的列表_
_    :param vector_setting: 向量设置对象,包含嵌入模型的信息_
_    :return: 每个文档与查询之间的余弦相似度得分列表_
_    """_
_    _# 初始化用于存储查询与文档之间余弦相似度得分的列表
    query_vector_scores = []
    # 创建模型管理器实例
    model_manager = ModelManager()
    # 获取指定租户的嵌入模型实例
    embedding_model = model_manager.get_model_instance(
        tenant_id=tenant_id, # 租户ID
        provider=vector_setting.embedding_provider_name, # 嵌入模型提供商名称
        model_type=ModelType.TEXT_EMBEDDING, # 模型类型为文本嵌入
        model=vector_setting.embedding_model_name # 嵌入模型名称

    )
    # 创建缓存嵌入实例
    cache_embedding = CacheEmbedding(embedding_model)
    # 为查询生成嵌入向量
    query_vector = cache_embedding.embed_query(query)
    # 遍历每个文档
    for document in documents:
        # 如果文档元数据中已经存在得分,则直接使用
        if 'score' in document.metadata:
            query_vector_scores.append(document.metadata['score'])
        else:
            # 获取文档的嵌入向量
            content_vector = document.metadata['vector']
            # 将查询向量和文档向量转换为NumPy数组
            vec1 = np.array(query_vector)
            vec2 = np.array(document.metadata['vector'])

            # 计算点积
            dot_product = np.dot(vec1, vec2)

            # 计算向量范数
            norm_vec1 = np.linalg.norm(vec1)
            norm_vec2 = np.linalg.norm(vec2)

            # 计算余弦相似度
            cosine_sim = dot_product / (norm_vec1 * norm_vec2)
            # 将余弦相似度得分添加到列表中
            query_vector_scores.append(cosine_sim)

    return query_vector_scores
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值