混合检索源码解析

RetrievalService.retrieve
代码目录:api/core/rag/datasource/retrieval_service.py
该函数用于根据指定的检索方法执行检索操作
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
weights: Optional[dict] = None):
_"""_
_ 根据指定的检索方法执行检索操作。_
_ :param retrival_method: 检索方法类型。 semantic_search full_text_search keyword_search,hybrid_search_
_ :param dataset_id: 数据集ID。_
_ :param query: 检索查询字符串。_
_ :param top_k: 返回结果的最大数量。_
_ :param score_threshold: 分数阈值,用于过滤结果。_
_ :param reranking_model: 重排序模型配置,用于对结果进行二次排序。_
_ :return: 检索结果列表。_
_ """_
_ _dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = [] # 存储所有检索结果
keyword_search_documents = []
embedding_search_documents = []
full_text_search_documents = []
hybrid_search_documents = []
threads = [] # 存储执行检索的线程
exceptions = [] # 存储执行过程中遇到的异常
# retrieval_model source with keyword
# 关键词检索
if retrival_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
'exceptions': exceptions,
})
threads.append(keyword_thread)
keyword_thread.start()
# 向量检索(混合检索中也会调用)
if RetrievalMethod.is_support_semantic_search(retrival_method):
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrival_method': retrival_method,
'exceptions': exceptions,
})
threads.append(embedding_thread)
embedding_thread.start()
# 文本检索(混合检索中也会调用)
if RetrievalMethod.is_support_fulltext_search(retrival_method):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrival_method': retrival_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents,
'exceptions': exceptions,
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
# 等待所有线程完成
for thread in threads:
thread.join()
# 如果执行过程中有异常,则合并异常信息并抛出
if exceptions:
exception_message = ';\n'.join(exceptions)
raise Exception(exception_message)
# 混合检索之后会执行向量和文本检索结果合并后的重排序
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
# print("混合搜索",all_documents)
return all_documents
假设本次选择的检索方法是 高质量-混合检索-权重则代码分别会开启两个线程走到 RetrievalMethod.is_support_semantic_search(retrival_method)
RetrievalMethod.is_support_fulltext_search(retrival_method)
执行向量以及全文检索查询,之后会走到 DataPostProcessor 中进行混合检索,然后根据是否开启 Rerank 模型进行重排序
DataPostProcessor
代码目录:api/core/rag/data_post_processor/data_post_processor.py
DataPostProcessor 定义了一个数据后处理接口,用于对搜索结果进行重排序和重新排序。构造函数初始化了重排序和重新排序运行器实例,而 invoke 方法则根据这些实例执行相应的操作,并返回处理后的文档列表。
class DataPostProcessor:
"""接口用于数据后处理文档。
"""
def __init__(self, tenant_id: str, reranking_mode: str,
reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
reorder_enabled: bool = False):
_# 初始化 DataPostProcessor 实例_
_# 创建重排序运行器实例_
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
_# 创建重新排序运行器实例_
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
_get_rerank_runner
根据提供的配置来创建一个重排序(reranking)运行器实例,调整搜索结果的顺序以提高相关性。
####### 基于权重的重排序:
- 如果
reranking_mode是WEIGHTED_SCORE且提供了weights,则创建一个WeightRerankRunner实例,并使用给定的weights配置。 weights字典包含了vector_setting和keyword_setting,这些设置分别用于向量和关键词的权重配置。
####### 基于模型的重排序:
- 如果
reranking_mode是RERANKING_MODEL且提供了reranking_model,则尝试获取对应的模型实例。 - 获取模型实例的过程可能抛出
InvokeAuthorizationError异常,此时返回None。 - 成功获取模型实例后,创建并返回一个
RerankModelRunner实例。
####### 详细代码
def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
# 根据提供的配置来创建一个重排序(reranking)运行器实例。重排序通常用于调整搜索结果的顺序以提高相关性。
# 判断是否使用基于权重的重排序模式
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
# 创建 WeightRerankRunner 实例
return WeightRerankRunner(
tenant_id, # 租户ID
Weights( # 创建 Weights 对象
vector_setting=VectorSetting( # 创建 VectorSetting 对象
vector_weight=weights['vector_setting']['vector_weight'], # 向量权重
embedding_provider_name=weights['vector_setting']['embedding_provider_name'], # 嵌入式提供商名称
embedding_model_name=weights['vector_setting']['embedding_model_name'], # 嵌入式模型名称
),
keyword_setting=KeywordSetting( # 创建 KeywordSetting 对象
keyword_weight=weights['keyword_setting']['keyword_weight'], # 关键词权重
)
)
)
# 判断是否使用基于模型的重排序模式
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
if reranking_model:
try:
# 创建 ModelManager 实例
model_manager = ModelManager()
# 获取模型实例
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, # 租户ID
provider=reranking_model['reranking_provider_name'], # 模型提供商名称
model_type=ModelType.RERANK, # 模型类型
model=reranking_model['reranking_model_name'] # 模型名称
)
except InvokeAuthorizationError:
return None
return RerankModelRunner(rerank_model_instance)
return None
# 如果既不是基于权重也不是基于模型的重排序模式,则返回 None
return None
invoke
根据返回的实例 WeightRerankRunner/ModelManager 调用对应的 run 方法 在构造函数初始化了重排序和重新排序运行器实例
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
#根据返回的实例WeightRerankRunner/ModelManager调用对应的run方法
if self.rerank_runner:
_# 执行重排序_
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
if self.reorder_runner:
_# 执行重新排序_
documents = self.reorder_runner.run(documents)
return documents
WeightRerankRunner
代码目录:api/core/rag/rerank/weight_rerank.py
基于权重的重排序功能,用于对搜索结果中的文档进行重新排序,以提高相关性。
run
执行重排序逻辑,包括去除重复文档、计算关键词得分、计算向量得分、合并得分、应用分数阈值、更新文档元数据中的得分、排序文档以及返回指定数量的文档。
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
_"""_
_ 执行基于权重的重排序_
_ :param query: 搜索查询_
_ :param documents: 文档列表_
_ :param score_threshold: 分数阈值_
_ :param top_n: 返回的文档数量上限_
_ :param user: unique user id if needed_
_ :return: 重排序后的文档列表_
_ """_
_ _docs = [] # 用于存储文档内容
doc_id = [] # 用于存储文档ID
unique_documents = [] # 用于存储去重后的文档
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id']) # 添加文档ID
docs.append(document.page_content) # 添加文档内容
unique_documents.append(document) # 添加文档
documents = unique_documents # 更新文档列表为去重后的列表
rerank_documents = [] # 用于存储重排序后的文档
# 计算关键词
query_scores = self._calculate_keyword_score(query, documents)
# 计算向量
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
# 合并
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
# 计算文档的最终
score = self.weights.vector_setting.vector_weight * query_vector_score + \
self.weights.keyword_setting.keyword_weight * query_score
# 应用分数阈值
if score_threshold and score < score_threshold:
continue
# 更新文档元数据中的得分
document.metadata['score'] = score
# 添加文档到重排序列表
rerank_documents.append(document)
# 排序文档
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
# 返回指定数量的文档
return rerank_documents[:top_n] if top_n else rerank_documents
####### _calculate_keyword_score
根据关键词计算文档与查询之间的余弦相似度得分。
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
_"""_
_ 计算文档与查询之间的余弦相似度得分。_
_ :param query: search query_
_ :param documents: documents for reranking_
_ :return: 每个文档的BM25得分列表_
_ """_
_ _keyword_table_handler = JiebaKeywordTableHandler() # 创建关键词处理实例
query_keywords = keyword_table_handler.extract_keywords(query, None) # 提取查询关键词
documents_keywords = [] # 用于存储文档关键词
for document in documents:
# # 提取文档关键词
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata['keywords'] = document_keywords # 将关键词存入文档元数据
documents_keywords.append(document_keywords) # 添加文档关键词到列表
# 统计查询关键词的词频(TF)
query_keyword_counts = Counter(query_keywords)
# 总文档数
total_documents = len(documents)
# 计算所有文档关键词的逆文档频率(IDF)
all_keywords = set() # 用于存储所有文档中的关键词
for document_keywords in documents_keywords:
all_keywords.update(document_keywords) # 更新所有关键词集合
keyword_idf = {} # 用于存储关键词的IDF值
for keyword in all_keywords:
# 计算包含特定关键词的文档数量
doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
# 计算IDF
keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
query_tfidf = {} # 用于存储查询关键词的TF-IDF值
for keyword, count in query_keyword_counts.items():
tf = count # 查询关键词的词频
idf = keyword_idf.get(keyword, 0) # 查询关键词的IDF
query_tfidf[keyword] = tf * idf # 计算查询关键词的TF-IDF值
# 计算所有文档的TF-IDF值
documents_tfidf = []
for document_keywords in documents_keywords:
document_keyword_counts = Counter(document_keywords) # 统计文档关键词的词频
document_tfidf = {} # 用于存储文档关键词的TF-IDF值
for keyword, count in document_keyword_counts.items():
tf = count # 文档关键词的词频
idf = keyword_idf.get(keyword, 0) # 文档关键词的IDF
document_tfidf[keyword] = tf * idf # 计算文档关键词的TF-IDF值
documents_tfidf.append(document_tfidf) # 添加文档TF-IDF值到列表
# 定义余弦相似度计算函数
def cosine_similarity(vec1, vec2):
# 计算两个向量的交集
intersection = set(vec1.keys()) & set(vec2.keys())
# 计算分子
numerator = sum(vec1[x] * vec2[x] for x in intersection)
# 计算分母
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
denominator = math.sqrt(sum1) * math.sqrt(sum2)
# 避免除以零的情况
if not denominator:
return 0.0
else:
return float(numerator) / denominator
# 初始化用于存储文档相似度得分的列表
similarities = []
# 计算每个文档与查询之间的相似度得分
for document_tfidf in documents_tfidf:
similarity = cosine_similarity(query_tfidf, document_tfidf)
similarities.append(similarity)
# for idx, similarity in enumerate(similarities):
# print(f"Document {idx + 1} similarity: {similarity}")
# 返回相似度得分列表
return similarities
####### _calculate_cosine
根据向量计算_查询与文档之间的余弦相似度得分。_
def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
vector_setting: VectorSetting) -> list[float]:
_"""_
_ 计算查询与文档之间的余弦相似度得分。_
_ :param tenant_id: 租户ID,用于获取特定租户的模型实例_
_ :param query: 搜索查询字符串_
_ :param documents: 包含多个文档的列表_
_ :param vector_setting: 向量设置对象,包含嵌入模型的信息_
_ :return: 每个文档与查询之间的余弦相似度得分列表_
_ """_
_ _# 初始化用于存储查询与文档之间余弦相似度得分的列表
query_vector_scores = []
# 创建模型管理器实例
model_manager = ModelManager()
# 获取指定租户的嵌入模型实例
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id, # 租户ID
provider=vector_setting.embedding_provider_name, # 嵌入模型提供商名称
model_type=ModelType.TEXT_EMBEDDING, # 模型类型为文本嵌入
model=vector_setting.embedding_model_name # 嵌入模型名称
)
# 创建缓存嵌入实例
cache_embedding = CacheEmbedding(embedding_model)
# 为查询生成嵌入向量
query_vector = cache_embedding.embed_query(query)
# 遍历每个文档
for document in documents:
# 如果文档元数据中已经存在得分,则直接使用
if 'score' in document.metadata:
query_vector_scores.append(document.metadata['score'])
else:
# 获取文档的嵌入向量
content_vector = document.metadata['vector']
# 将查询向量和文档向量转换为NumPy数组
vec1 = np.array(query_vector)
vec2 = np.array(document.metadata['vector'])
# 计算点积
dot_product = np.dot(vec1, vec2)
# 计算向量范数
norm_vec1 = np.linalg.norm(vec1)
norm_vec2 = np.linalg.norm(vec2)
# 计算余弦相似度
cosine_sim = dot_product / (norm_vec1 * norm_vec2)
# 将余弦相似度得分添加到列表中
query_vector_scores.append(cosine_sim)
return query_vector_scores
298

被折叠的 条评论
为什么被折叠?



