Abstract
The Retriever is the core component in LlamaIndex responsible for retrieving relevant information from an index, and it plays a critical role in the retrieval-augmented generation (RAG) pipeline. Through different Retriever implementations we can efficiently retrieve relevant data from a variety of index structures. This article takes a close look at how Retrievers work, their built-in types, their configuration options, and how to use them in real applications, helping developers better understand and apply this important component.
Main Text
1. Introduction
In earlier posts we covered LlamaIndex's index types and query engines in detail. Among these components, the Retriever is a key middle layer responsible for fetching the information relevant to a user's query from an index. The Retriever not only determines retrieval accuracy and efficiency, it also directly affects the quality of the final generated answer. Understanding how Retrievers work and how to use them is essential for building high-quality LLM applications.
2. Retriever Fundamentals
2.1 What Is a Retriever
A Retriever is the LlamaIndex component responsible for retrieving relevant information from an index. It receives a user query, applies a specific algorithm and strategy to find the most relevant documents or nodes in the index, and passes those results to downstream components (such as the ResponseSynthesizer) to generate the final answer.
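To make this concrete, here is a minimal sketch of where a retriever sits between the index and the response synthesizer. It assumes `documents` has already been loaded; the wiring through RetrieverQueryEngine and get_response_synthesizer follows the standard llama_index.core API:
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine

# Build an index and derive a retriever from it
index = VectorStoreIndex.from_documents(documents)
retriever = index.as_retriever(similarity_top_k=3)

# The retriever feeds retrieved nodes into the response synthesizer
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    response_synthesizer=get_response_synthesizer(),
)
response = query_engine.query("What is a retriever?")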
2.2 Core Characteristics of Retrievers
- Retrieval-focused: dedicated solely to the information-retrieval task
- Index-agnostic: works with different index types (see the sketch after this list)
- Configurable: supports multiple retrieval strategies and parameter settings
- Extensible: allows custom implementations for specific needs
- Performance-oriented: applies various optimizations to make retrieval efficient
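As a quick illustration of the index-agnostic interface, the following sketch (assuming `documents` is a list of loaded Document objects) builds two different index types and calls the same retrieve() method on each:
from llama_index.core import VectorStoreIndex, SummaryIndex

for index_cls in (VectorStoreIndex, SummaryIndex):
    # Different index types, identical retriever interface
    index = index_cls.from_documents(documents)
    retriever = index.as_retriever()
    nodes = retriever.retrieve("What are retrievers used for?")
    print(index_cls.__name__, "returned", len(nodes), "nodes")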
3. How Retrievers Work
3.1 Retriever Architecture
The Retriever occupies the following position in the overall LlamaIndex architecture:
- Query intake: receives query requests from the QueryEngine
- Index interaction: communicates with a specific index type
- Node retrieval: fetches the relevant nodes from the index
- Result return: hands the retrieved results back to the QueryEngine
3.2 Retriever Workflow
Within a single retrieve() call the flow mirrors the architecture above: the query string is wrapped in a QueryBundle, the retriever performs an index-specific lookup, candidate nodes are scored, and a ranked list of NodeWithScore objects is returned.
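The sketch below traces these stages in code; it assumes `index` has already been built:
from llama_index.core.schema import QueryBundle

retriever = index.as_retriever(similarity_top_k=3)

# 1. The query string is wrapped in a QueryBundle
query = QueryBundle(query_str="How does retrieval work?")

# 2-3. The retriever runs an index-specific lookup and scores candidate nodes
nodes = retriever.retrieve(query)

# 4. Results come back as a ranked list of NodeWithScore objects
for n in nodes:
    print(n.score, n.node.node_id)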
4. Built-in Retriever Types
4.1 VectorIndexRetriever
VectorIndexRetriever is the retriever paired with VectorStoreIndex; it retrieves nodes by vector similarity:
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import VectorStoreIndex

# Build the index (assumes `documents` has been loaded)
index = VectorStoreIndex.from_documents(documents)

# Create a VectorIndexRetriever
retriever = VectorIndexRetriever(
    index=index,
    similarity_top_k=5,  # return the 5 most similar nodes
)

# Run retrieval
nodes = retriever.retrieve("What are the applications of artificial intelligence?")
for node in nodes:
    print(f"Score: {node.score}")
    print(f"Content: {node.node.get_content()[:100]}...")
4.2 KeywordTableRetriever
KeywordTableRetriever pairs with KeywordTableIndex and retrieves by keyword matching; note that in recent LlamaIndex versions the LLM-free variant of this class is named KeywordTableSimpleRetriever:
from llama_index.core.retrievers import KeywordTableSimpleRetriever
from llama_index.core import KeywordTableIndex

# Build the index
index = KeywordTableIndex.from_documents(documents)

# Create the keyword retriever
retriever = KeywordTableSimpleRetriever(
    index=index,
    max_keywords_per_query=10,  # use at most 10 keywords per query
)

# Run retrieval
nodes = retriever.retrieve("machine learning algorithms")
4.3 ListIndexRetriever
ListIndexRetriever pairs with ListIndex. In recent LlamaIndex versions these have been renamed SummaryIndex and SummaryIndexRetriever (the old names remain as aliases):
from llama_index.core import SummaryIndex  # formerly ListIndex
from llama_index.core.retrievers import SummaryIndexRetriever  # formerly ListIndexRetriever

# Build the index
index = SummaryIndex.from_documents(documents)

# Create the retriever
# Note: the batching parameter (choice_batch_size) belongs to the LLM-based
# variant, SummaryIndexLLMRetriever, in recent versions
retriever = SummaryIndexRetriever(index=index)

# Run retrieval
nodes = retriever.retrieve("technology trends")
4.4 TreeIndexRetriever
TreeIndexRetriever pairs with TreeIndex; in recent versions the default leaf-selection retriever class is named TreeSelectLeafRetriever:
from llama_index.core.retrievers import TreeSelectLeafRetriever
from llama_index.core import TreeIndex

# Build the index
index = TreeIndex.from_documents(documents)

# Create the retriever
retriever = TreeSelectLeafRetriever(
    index=index,
    child_branch_factor=2,  # number of child branches to explore per level
)

# Run retrieval
nodes = retriever.retrieve("principles of deep learning")
5. Retriever Configuration Options
5.1 Similarity Parameters
# Configure vector retriever parameters
# (parameter availability varies with your LlamaIndex version and the
# underlying vector store)
retriever = index.as_retriever(
    # number of most-similar nodes to return
    similarity_top_k=10,
    # retrieval mode of the vector store
    vector_store_query_mode="hybrid",  # e.g. "default", "sparse", "hybrid"
    # weight between vector and keyword scores in hybrid search
    alpha=0.5,
)
# Note: a similarity threshold (similarity_cutoff, e.g. 0.7) that filters
# out low-scoring nodes is applied through a node postprocessor in recent
# versions; see the sketch below.
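Because the similarity cutoff is handled by a node postprocessor rather than the retriever itself in recent LlamaIndex versions, a hedged sketch of the postprocessor route (assuming `index` is already built) looks like this:
from llama_index.core.postprocessor import SimilarityPostprocessor

retriever = index.as_retriever(similarity_top_k=10)
nodes = retriever.retrieve("the latest AI techniques")

# Drop nodes scoring below the threshold
processor = SimilarityPostprocessor(similarity_cutoff=0.7)
filtered_nodes = processor.postprocess_nodes(nodes)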
5.2 Filter Configuration
from llama_index.core.vector_stores import MetadataFilters, MetadataFilter

# Build metadata filters
filters = MetadataFilters(
    filters=[
        MetadataFilter(
            key="category",
            value="technology",
            operator="=="
        ),
        MetadataFilter(
            key="date",
            value="2023-01-01",
            operator=">="
        )
    ],
    condition="and"  # "and" or "or"
)

# Use the filters in a retriever
retriever = index.as_retriever(
    filters=filters,
    similarity_top_k=5
)

# Run a filtered retrieval
nodes = retriever.retrieve("the latest AI techniques")
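The string operators above ("==", ">=") map onto the FilterOperator and FilterCondition enums. If you prefer the explicit, more type-safe enum form (same API), the filter above can also be written as:
from llama_index.core.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

filters = MetadataFilters(
    filters=[
        MetadataFilter(key="category", value="technology", operator=FilterOperator.EQ),
        MetadataFilter(key="date", value="2023-01-01", operator=FilterOperator.GTE),
    ],
    condition=FilterCondition.AND,
)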
6. Custom Retrievers
6.1 Subclassing BaseRetriever
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from typing import List

class CustomRetriever(BaseRetriever):
    """A custom retriever with a pluggable scoring function."""

    def __init__(self, index, custom_scorer=None):
        self.index = index
        self.custom_scorer = custom_scorer or self._default_scorer()
        super().__init__()

    def _default_scorer(self):
        """Default scoring function: simple keyword-overlap score."""
        def scorer(query, node):
            query_words = set(query.lower().split())
            node_words = set(node.text.lower().split())
            overlap = len(query_words.intersection(node_words))
            return overlap / len(query_words) if query_words else 0
        return scorer

    def _retrieve(self, query_bundle) -> List[NodeWithScore]:
        """Retrieval logic."""
        # Fetch all nodes from the docstore
        all_nodes = list(self.index.docstore.docs.values())
        # Score each node with the custom scorer
        scored_nodes = []
        for node in all_nodes:
            score = self.custom_scorer(query_bundle.query_str, node)
            if score > 0:  # keep only nodes with some relevance
                scored_nodes.append(
                    NodeWithScore(node=node, score=score)
                )
        # Sort by score and return the top K
        scored_nodes.sort(key=lambda x: x.score, reverse=True)
        return scored_nodes[:10]

    async def _aretrieve(self, query_bundle) -> List[NodeWithScore]:
        """Async retrieval: fall back to the sync implementation."""
        return self._retrieve(query_bundle)

# Using the custom retriever
custom_retriever = CustomRetriever(index)
nodes = custom_retriever.retrieve("custom retrieval example")
6.2 Multi-Index Retriever
class MultiIndexRetriever(BaseRetriever):
    """Retrieve from several indexes and merge the results."""

    def __init__(self, indexes, weights=None):
        self.indexes = indexes
        self.weights = weights or [1.0] * len(indexes)
        super().__init__()

    def _retrieve(self, query_bundle):
        """Retrieve from every index."""
        all_results = []
        for i, index in enumerate(self.indexes):
            retriever = index.as_retriever(similarity_top_k=5)
            nodes = retriever.retrieve(query_bundle.query_str)
            # Apply the per-index weight
            weighted_nodes = [
                NodeWithScore(node=node.node, score=node.score * self.weights[i])
                for node in nodes
            ]
            all_results.extend(weighted_nodes)
        # Merge and deduplicate results
        merged_results = self._merge_results(all_results)
        # Sort by score
        merged_results.sort(key=lambda x: x.score, reverse=True)
        return merged_results[:10]  # return the top 10 results

    def _merge_results(self, results):
        """Merge retrieval results, deduplicating by node ID."""
        unique_results = {}
        for result in results:
            node_id = result.node.node_id
            if node_id not in unique_results:
                unique_results[node_id] = result
            else:
                # If the node already exists, keep the higher score
                existing = unique_results[node_id]
                merged_score = max(existing.score, result.score)
                unique_results[node_id] = NodeWithScore(
                    node=existing.node,
                    score=merged_score
                )
        return list(unique_results.values())

# Using the multi-index retriever
indexes = [index1, index2, index3]
multi_retriever = MultiIndexRetriever(indexes, weights=[0.5, 0.3, 0.2])
nodes = multi_retriever.retrieve("cross-domain query")
7. Practical Application Examples
7.1 Intelligent Search Engine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
import numpy as np

class IntelligentSearchEngine:
    """An intelligent search engine with adaptive, personalized retrieval."""

    def __init__(self, indexes):
        self.indexes = indexes
        self.search_history = {}
        self.user_profiles = {}

    def create_adaptive_retriever(self, user_id, query):
        """Create an adaptive retriever for this user and query."""
        # Look up the user profile
        user_profile = self.user_profiles.get(user_id, {})
        # Analyze the query type
        query_type = self._analyze_query_type(query)
        # Select indexes and parameters from the profile and query type
        selected_indexes, weights, params = self._select_configuration(
            user_profile, query_type
        )
        # Build the retriever
        retriever = WeightedMultiIndexRetriever(
            indexes=selected_indexes,
            weights=weights,
            **params
        )
        return retriever

    def _analyze_query_type(self, query):
        """Classify the query type."""
        query_lower = query.lower()
        # The keyword lists below target Chinese queries
        # ("how/how to/steps", "compare/contrast/difference", "latest/recent")
        if any(word in query_lower for word in ["如何", "怎样", "步骤"]):
            return "instructional"
        elif any(word in query_lower for word in ["比较", "对比", "区别"]):
            return "comparative"
        elif any(word in query_lower for word in ["最新", "最近", "2023", "2024"]):
            return "temporal"
        else:
            return "general"

    def _select_configuration(self, user_profile, query_type):
        """Select indexes, weights, and retrieval parameters."""
        # Pick indexes based on the user profile
        # (assumes each index object carries a custom .category attribute)
        preferred_categories = user_profile.get("preferred_categories", ["general"])
        selected_indexes = [
            idx for idx in self.indexes
            if idx.category in preferred_categories
        ]
        # Adjust weights for the query type
        weights = self._calculate_weights(query_type, selected_indexes)
        # Set retrieval parameters
        params = self._set_retrieval_params(query_type)
        return selected_indexes, weights, params

    def _calculate_weights(self, query_type, indexes):
        """Compute per-index weights."""
        weights = []
        for index in indexes:
            base_weight = 1.0
            # Adjust the weight by query type (assumes custom
            # .temporal_relevance / .instructional_content flags on indexes)
            if query_type == "temporal" and index.temporal_relevance:
                base_weight *= 1.5
            elif query_type == "instructional" and index.instructional_content:
                base_weight *= 1.3
            weights.append(base_weight)
        # Normalize the weights
        total_weight = sum(weights)
        if total_weight > 0:
            weights = [w / total_weight for w in weights]
        return weights

    def _set_retrieval_params(self, query_type):
        """Set retrieval parameters for the query type."""
        params = {
            "similarity_top_k": 10,
            "similarity_cutoff": 0.6
        }
        if query_type == "temporal":
            params["similarity_top_k"] = 15
            params["similarity_cutoff"] = 0.5
        elif query_type == "instructional":
            # Consumed downstream; ignored at the retriever level
            params["response_mode"] = "step_by_step"
        return params

    def search(self, user_id, query):
        """Run a search."""
        # Build the adaptive retriever
        retriever = self.create_adaptive_retriever(user_id, query)
        # Retrieve
        nodes = retriever.retrieve(query)
        # Update the search history
        self._update_search_history(user_id, query, nodes)
        # Personalized re-ranking
        personalized_nodes = self._personalize_results(user_id, nodes)
        return personalized_nodes

    def _update_search_history(self, user_id, query, results):
        """Update the user's search history."""
        if user_id not in self.search_history:
            self.search_history[user_id] = []
        self.search_history[user_id].append({
            "query": query,
            "results": [node.node.node_id for node in results],
            "timestamp": np.datetime64('now')
        })
        # Keep the history to a reasonable size
        if len(self.search_history[user_id]) > 100:
            self.search_history[user_id] = self.search_history[user_id][-50:]

    def _personalize_results(self, user_id, nodes):
        """Re-rank results for the user."""
        user_profile = self.user_profiles.get(user_id, {})
        preferred_topics = user_profile.get("preferred_topics", [])
        # Compute a personalized score per node
        personalized_nodes = []
        for node in nodes:
            base_score = node.score
            personalization_boost = self._calculate_personalization_boost(
                node, preferred_topics
            )
            personalized_score = base_score * (1 + personalization_boost)
            personalized_nodes.append(
                NodeWithScore(node=node.node, score=personalized_score)
            )
        # Re-sort
        personalized_nodes.sort(key=lambda x: x.score, reverse=True)
        return personalized_nodes

    def _calculate_personalization_boost(self, node, preferred_topics):
        """Compute the personalization boost."""
        boost = 0.0
        # Check whether the node's topics match the user's preferences
        node_topics = node.node.metadata.get("topics", [])
        overlap = len(set(node_topics).intersection(set(preferred_topics)))
        if overlap > 0:
            boost += 0.2 * overlap
        # Check keyword matches in the content
        node_text = node.node.text.lower()
        for topic in preferred_topics:
            if topic.lower() in node_text:
                boost += 0.1
        return min(boost, 1.0)  # cap the boost
class WeightedMultiIndexRetriever(BaseRetriever):
    """Weighted multi-index retriever."""

    def __init__(self, indexes, weights, similarity_top_k=10,
                 similarity_cutoff=0.0, **kwargs):
        # **kwargs absorbs parameters (e.g. response_mode) that are
        # consumed downstream rather than at the retriever level
        self.indexes = indexes
        self.weights = weights
        self.similarity_top_k = similarity_top_k
        self.similarity_cutoff = similarity_cutoff
        super().__init__()

    def _retrieve(self, query_bundle):
        """Retrieval implementation."""
        all_results = []
        for i, index in enumerate(self.indexes):
            # Build a retriever for each index
            retriever = index.as_retriever(
                similarity_top_k=self.similarity_top_k,
                similarity_cutoff=self.similarity_cutoff
            )
            # Retrieve
            nodes = retriever.retrieve(query_bundle.query_str)
            # Apply the weight
            weighted_nodes = [
                NodeWithScore(node=node.node, score=node.score * self.weights[i])
                for node in nodes
            ]
            all_results.extend(weighted_nodes)
        # Merge and deduplicate
        merged_results = self._merge_and_deduplicate(all_results)
        # Sort and return
        merged_results.sort(key=lambda x: x.score, reverse=True)
        return merged_results[:self.similarity_top_k]

    def _merge_and_deduplicate(self, results):
        """Merge results, keeping the highest score per node."""
        node_dict = {}
        for result in results:
            node_id = result.node.node_id
            if node_id not in node_dict:
                node_dict[node_id] = result
            else:
                # Merge scores by taking the maximum
                existing = node_dict[node_id]
                merged_score = max(existing.score, result.score)
                node_dict[node_id] = NodeWithScore(
                    node=existing.node,
                    score=merged_score
                )
        return list(node_dict.values())

# Usage example
# Assuming several indexes already exist:
# indexes = [tech_index, business_index, news_index]
# search_engine = IntelligentSearchEngine(indexes)

# Set a user profile
# search_engine.user_profiles["user_001"] = {
#     "preferred_categories": ["technology", "business"],
#     "preferred_topics": ["AI", "machine learning", "innovation"]
# }

# Run a search
# results = search_engine.search("user_001", "the latest developments in AI")
7.2 Academic Literature Retrieval System
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from datetime import datetime

class AcademicLiteratureRetriever(BaseRetriever):
    """Academic literature retriever."""

    def __init__(self, index, citation_graph=None):
        self.index = index
        self.citation_graph = citation_graph
        super().__init__()

    def _retrieve(self, query_bundle):
        """Retrieve academic literature."""
        # 1. Content-based retrieval
        base_retriever = self.index.as_retriever(
            similarity_top_k=20,
            similarity_cutoff=0.5
        )
        content_results = base_retriever.retrieve(query_bundle.query_str)
        # 2. Metadata-based filtering and enhancement
        enhanced_results = self._enhance_with_metadata(content_results)
        # 3. Citation-based ranking
        citation_ranked_results = self._rank_by_citations(enhanced_results)
        # 4. Recency adjustment
        temporal_adjusted_results = self._adjust_for_recency(citation_ranked_results)
        # 5. Final ranking
        final_results = self._final_ranking(temporal_adjusted_results)
        return final_results[:10]

    def _enhance_with_metadata(self, results):
        """Enhance results using metadata."""
        enhanced_results = []
        for result in results:
            node = result.node
            metadata = node.metadata
            # Extract academic metadata
            publication_year = metadata.get("year", 0)
            citation_count = metadata.get("citation_count", 0)
            venue = metadata.get("venue", "")
            authors = metadata.get("authors", [])
            # Compute an academic-impact score
            impact_score = self._calculate_academic_impact(
                publication_year, citation_count, venue
            )
            # Build the boosted node
            enhanced_score = result.score * (1 + impact_score)
            enhanced_results.append(
                NodeWithScore(node=node, score=enhanced_score)
            )
        return enhanced_results

    def _calculate_academic_impact(self, year, citations, venue):
        """Compute academic impact."""
        # Year factor (newer papers weigh more)
        current_year = datetime.now().year
        year_factor = max(0.5, 1.0 - (current_year - year) * 0.1)
        # Citation factor
        citation_factor = min(1.0, citations / 1000.0)
        # Venue factor
        venue_factor = self._get_venue_factor(venue)
        # Combined impact score
        impact_score = (year_factor * 0.3 + citation_factor * 0.5 + venue_factor * 0.2)
        return impact_score

    def _get_venue_factor(self, venue):
        """Venue impact factor."""
        top_venues = {
            "Nature", "Science", "Cell", "NeurIPS", "ICML", "CVPR", "ACL"
        }
        if venue in top_venues:
            return 1.0
        elif "Transactions" in venue or "Journal" in venue:
            return 0.8
        else:
            return 0.5

    def _rank_by_citations(self, results):
        """Rank by citation relationships."""
        if not self.citation_graph:
            return results
        ranked_results = []
        for result in results:
            node = result.node
            node_id = node.node_id
            # Look up citation information
            citation_info = self.citation_graph.get(node_id, {})
            incoming_citations = citation_info.get("incoming", 0)
            # Boost the score by citation count
            citation_boost = min(1.0, incoming_citations / 100.0)
            boosted_score = result.score * (1 + citation_boost)
            ranked_results.append(
                NodeWithScore(node=node, score=boosted_score)
            )
        return ranked_results

    def _adjust_for_recency(self, results):
        """Adjust scores for recency."""
        adjusted_results = []
        current_year = datetime.now().year
        for result in results:
            node = result.node
            publication_year = node.metadata.get("year", current_year)
            # Compute the recency factor
            years_old = current_year - publication_year
            recency_factor = max(0.5, 1.0 - years_old * 0.1)
            # Adjust the score
            adjusted_score = result.score * recency_factor
            adjusted_results.append(
                NodeWithScore(node=node, score=adjusted_score)
            )
        return adjusted_results

    def _final_ranking(self, results):
        """Final ranking."""
        # A more sophisticated ranker (e.g. a learning-to-rank model)
        # could be plugged in here
        results.sort(key=lambda x: x.score, reverse=True)
        return results

# Usage example
# academic_retriever = AcademicLiteratureRetriever(
#     index=academic_index,
#     citation_graph=citation_data
# )
# results = academic_retriever.retrieve("applications of deep learning in NLP")
7.3 Enterprise Knowledge Base Retrieval System
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from datetime import datetime
import jieba
import jieba.analyse

class EnterpriseKnowledgeRetriever(BaseRetriever):
    """Enterprise knowledge-base retriever."""

    def __init__(self, index, department_permissions=None):
        self.index = index
        self.department_permissions = department_permissions or {}
        self.query_expander = QueryExpander()
        super().__init__()

    def _retrieve(self, query_bundle):
        """Enterprise knowledge retrieval."""
        # 1. Query expansion
        expanded_queries = self.query_expander.expand(query_bundle.query_str)
        # 2. Multi-pass retrieval
        all_results = []
        for expanded_query in expanded_queries:
            results = self._single_retrieve(expanded_query)
            all_results.extend(results)
        # 3. Merge and deduplicate
        merged_results = self._merge_results(all_results)
        # 4. Permission filtering
        filtered_results = self._apply_permissions(merged_results)
        # 5. Final ranking
        final_results = self._rank_results(filtered_results)
        return final_results[:15]

    def _single_retrieve(self, query):
        """A single retrieval pass."""
        retriever = self.index.as_retriever(
            similarity_top_k=10,
            similarity_cutoff=0.3
        )
        return retriever.retrieve(query)

    def _merge_results(self, results):
        """Merge results, averaging duplicate scores."""
        node_dict = {}
        for result in results:
            node_id = result.node.node_id
            if node_id not in node_dict:
                node_dict[node_id] = result
            else:
                # Merge scores by averaging
                existing = node_dict[node_id]
                merged_score = (existing.score + result.score) / 2
                node_dict[node_id] = NodeWithScore(
                    node=existing.node,
                    score=merged_score
                )
        return list(node_dict.values())

    def _apply_permissions(self, results):
        """Apply permission filtering."""
        # Simplified here; a real deployment needs a proper permission system
        return results  # assume every user has access

    def _rank_results(self, results):
        """Rank results on several factors."""
        ranked_results = []
        for result in results:
            node = result.node
            metadata = node.metadata
            # Base relevance score
            relevance_score = result.score
            # Recency score
            recency_score = self._calculate_recency_score(metadata)
            # Authority score
            authority_score = self._calculate_authority_score(metadata)
            # Usage-frequency score
            frequency_score = self._calculate_frequency_score(metadata)
            # Weighted combination
            final_score = (
                relevance_score * 0.6 +
                recency_score * 0.2 +
                authority_score * 0.1 +
                frequency_score * 0.1
            )
            ranked_results.append(
                NodeWithScore(node=node, score=final_score)
            )
        # Sort
        ranked_results.sort(key=lambda x: x.score, reverse=True)
        return ranked_results

    def _calculate_recency_score(self, metadata):
        """Recency score."""
        updated_date = metadata.get("updated_date")
        if not updated_date:
            return 0.5
        try:
            update_time = datetime.fromisoformat(updated_date)
            days_since_update = (datetime.now() - update_time).days
            # Newer documents score higher
            return max(0.1, 1.0 - days_since_update / 365.0)
        except (ValueError, TypeError):
            return 0.5

    def _calculate_authority_score(self, metadata):
        """Authority score based on author and department."""
        author = metadata.get("author", "")
        department = metadata.get("department", "")
        # Authority indicators in author/department fields
        # (Chinese titles: "chief", "director", "expert", "senior")
        authority_indicators = [
            "首席", "总监", "专家", "高级", "CTO", "CFO"
        ]
        authority_score = 0.5  # default score
        for indicator in authority_indicators:
            if indicator in author or indicator in department:
                authority_score += 0.2
        return min(1.0, authority_score)

    def _calculate_frequency_score(self, metadata):
        """Usage-frequency score."""
        access_count = metadata.get("access_count", 0)
        # Score by access count
        return min(1.0, access_count / 1000.0)

class QueryExpander:
    """Query expander using a synonym dictionary and keyword extraction."""

    def __init__(self):
        # Initialize a (Chinese) synonym dictionary:
        # computer, software, network, data, security
        self.synonyms = {
            "电脑": ["计算机", "PC", "台式机", "笔记本"],
            "软件": ["程序", "应用", "应用程序", "系统"],
            "网络": ["互联网", "因特网", "联网", "在线"],
            "数据": ["资料", "信息", "数据库", "数据集"],
            "安全": ["安保", "防护", "保护", "保密"]
        }

    def expand(self, query):
        """Expand the query."""
        expanded_queries = [query]  # include the original query
        # Chinese word segmentation
        words = jieba.lcut(query)
        # Check each word for synonyms
        for word in words:
            if word in self.synonyms:
                # Create a new query per synonym
                for synonym in self.synonyms[word]:
                    expanded_query = query.replace(word, synonym)
                    expanded_queries.append(expanded_query)
        # Extract keywords via TF-IDF and build a combined query
        keywords = jieba.analyse.extract_tags(query, topK=3)
        if len(keywords) > 1:
            combined_query = " ".join(keywords)
            expanded_queries.append(combined_query)
        # Deduplicate
        return list(set(expanded_queries))

# Usage example
# enterprise_retriever = EnterpriseKnowledgeRetriever(
#     index=company_knowledge_index
# )
# results = enterprise_retriever.retrieve("how to configure the company VPN")
8. Performance Optimization Strategies
8.1 Caching
import hashlib

class CachedRetriever(BaseRetriever):
    """A retriever with a simple in-memory cache."""

    def __init__(self, base_retriever, cache_size=1000):
        self.base_retriever = base_retriever
        self.cache_size = cache_size
        self._cache = {}
        super().__init__()

    def _retrieve(self, query_bundle):
        """Retrieve with caching."""
        # Hash the query string to form a cache key
        query_hash = hashlib.md5(query_bundle.query_str.encode()).hexdigest()
        # Check the cache
        if query_hash in self._cache:
            print("Returning cached results")
            return self._cache[query_hash]
        # Run the actual retrieval
        results = self.base_retriever.retrieve(query_bundle)
        # Store in the cache
        self._cache[query_hash] = results
        # Bound the cache size (FIFO eviction: drop the oldest entry)
        if len(self._cache) > self.cache_size:
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
        return results

# Using the cached retriever
# cached_retriever = CachedRetriever(
#     base_retriever=original_retriever,
#     cache_size=500
# )
8.2 Parallel Retrieval
from concurrent.futures import ThreadPoolExecutor

class ParallelRetriever(BaseRetriever):
    """Run several retrievers concurrently and merge their results."""

    def __init__(self, retrievers):
        self.retrievers = retrievers
        self.executor = ThreadPoolExecutor(max_workers=len(retrievers))
        super().__init__()

    def _retrieve(self, query_bundle):
        """Parallel retrieval."""
        # Submit every retriever to the thread pool
        futures = [
            self.executor.submit(retriever.retrieve, query_bundle)
            for retriever in self.retrievers
        ]
        # Collect and merge the results
        merged_results = []
        for future in futures:
            merged_results.extend(future.result())
        # Deduplicate and sort
        unique_results = self._deduplicate(merged_results)
        unique_results.sort(key=lambda x: x.score, reverse=True)
        return unique_results[:20]  # return the top 20 results

    def _deduplicate(self, results):
        """Drop duplicate nodes by node ID."""
        seen = set()
        unique_results = []
        for result in results:
            if result.node.node_id not in seen:
                seen.add(result.node.node_id)
                unique_results.append(result)
        return unique_results

# Using the parallel retriever
# parallel_retriever = ParallelRetriever([retriever1, retriever2, retriever3])
# results = parallel_retriever.retrieve("parallel retrieval test")
9. Troubleshooting and Best Practices
9.1 Common Issues and Solutions
- Low relevance of retrieval results:
# Solution: tune the retrieval parameters and strategy
class ImprovedRetriever(BaseRetriever):
    def __init__(self, index):
        self.index = index
        # QueryProcessor is assumed to be defined elsewhere
        # (e.g. a custom query-cleaning helper)
        self.query_processor = QueryProcessor()
        super().__init__()

    def _retrieve(self, query_bundle):
        # 1. Preprocess the query
        processed_query = self.query_processor.process(query_bundle.query_str)
        # 2. Use a stricter similarity threshold
        retriever = self.index.as_retriever(
            similarity_top_k=15,
            similarity_cutoff=0.7  # raise the threshold
        )
        # 3. Multi-pass retrieval and filtering
        initial_results = retriever.retrieve(processed_query)
        filtered_results = self._filter_low_quality(initial_results)
        reranked_results = self._rerank_results(filtered_results)
        return reranked_results[:10]

    def _filter_low_quality(self, results):
        """Filter out low-quality results."""
        filtered = []
        for result in results:
            # Filter on text quality and metadata
            if self._is_high_quality(result.node):
                filtered.append(result)
        return filtered

    def _is_high_quality(self, node):
        """Judge node quality."""
        # Check text length
        if len(node.text.strip()) < 50:
            return False
        # Check metadata completeness
        required_metadata = ["title", "source", "date"]
        for field in required_metadata:
            if field not in node.metadata:
                return False
        return True

    def _rerank_results(self, results):
        """Rerank results."""
        # Recompute scores from several factors
        reranked = []
        for result in results:
            new_score = self._calculate_comprehensive_score(result)
            reranked.append(
                NodeWithScore(node=result.node, score=new_score)
            )
        reranked.sort(key=lambda x: x.score, reverse=True)
        return reranked

    def _calculate_comprehensive_score(self, result):
        """Combine base, quality, and diversity scores."""
        base_score = result.score
        quality_bonus = self._calculate_quality_bonus(result.node)
        diversity_bonus = self._calculate_diversity_bonus(result.node)
        return base_score + quality_bonus + diversity_bonus

    def _calculate_quality_bonus(self, node):
        """Quality bonus score."""
        bonus = 0.0
        # Bonus for authoritative sources
        authoritative_sources = ["official documentation", "academic paper", "industry report"]
        if node.metadata.get("source") in authoritative_sources:
            bonus += 0.1
        # Bonus for recent updates
        bonus += self._calculate_recency_bonus(node.metadata)
        return bonus

    def _calculate_recency_bonus(self, metadata):
        """Recency bonus (simplified implementation)."""
        return 0.05

    def _calculate_diversity_bonus(self, node):
        """Diversity bonus: encourage varied content types."""
        return 0.02
- Slow retrieval speed:
# Solution: implement hierarchical (coarse-to-fine) retrieval
class HierarchicalRetriever(BaseRetriever):
    def __init__(self, index):
        self.index = index
        self.coarse_retriever = self._create_coarse_retriever()
        self.fine_retriever = self._create_fine_retriever()
        super().__init__()

    def _create_coarse_retriever(self):
        """Coarse-pass retriever: large K, low threshold."""
        return self.index.as_retriever(
            similarity_top_k=50,    # larger K for the coarse pass
            similarity_cutoff=0.5   # lower threshold
        )

    def _create_fine_retriever(self):
        """Fine-pass retriever: small K, high threshold."""
        return self.index.as_retriever(
            similarity_top_k=10,    # smaller K for the fine pass
            similarity_cutoff=0.7   # higher threshold
        )

    def _retrieve(self, query_bundle):
        # 1. Coarse pass: quickly screen candidate documents
        coarse_results = self.coarse_retriever.retrieve(query_bundle.query_str)
        if not coarse_results:
            return []
        # 2. Fine pass: dig deeper within the coarse candidates
        # Collect the document IDs from the coarse pass
        # (in a full implementation these would restrict the fine pass,
        # e.g. via metadata filters)
        doc_ids = [result.node.ref_doc_id for result in coarse_results]
        fine_results = self.fine_retriever.retrieve(query_bundle.query_str)
        # 3. Return the fine-pass results
        return fine_results[:10]
9.2 Best-Practice Recommendations
- Choose the right retriever type:
def select_appropriate_retriever(index, index_type, requirements):
    """Pick a suitable retriever for the index type and requirements."""
    if index_type == "vector":
        return index.as_retriever(
            similarity_top_k=requirements.get("top_k", 10),
            similarity_cutoff=requirements.get("cutoff", 0.7),
            vector_store_query_mode=requirements.get("mode", "default")
        )
    elif index_type == "keyword":
        return index.as_retriever(
            max_keywords_per_query=requirements.get("max_keywords", 10)
        )
    elif index_type == "tree":
        return index.as_retriever(
            child_branch_factor=requirements.get("branch_factor", 2)
        )
    else:
        # Custom implementation
        return CustomRetriever(index, **requirements)

# Usage example
# retriever = select_appropriate_retriever(
#     index, index_type="vector",
#     requirements={"top_k": 15, "cutoff": 0.6, "mode": "hybrid"}
# )
- Monitoring and logging:
import logging
from datetime import datetime

class MonitoredRetriever(BaseRetriever):
    """A retriever wrapper that logs and tracks retrieval metrics."""

    def __init__(self, base_retriever):
        self.base_retriever = base_retriever
        self.logger = logging.getLogger(__name__)
        self.stats = {"total_retrievals": 0, "avg_retrieval_time": 0}
        super().__init__()

    def _retrieve(self, query_bundle):
        """Retrieve with monitoring."""
        start_time = datetime.now()
        self.stats["total_retrievals"] += 1
        try:
            results = self.base_retriever.retrieve(query_bundle)
            duration = (datetime.now() - start_time).total_seconds()
            self._update_stats(duration)
            self.logger.info(
                f"Retrieval OK | query: {query_bundle.query_str[:30]}... | "
                f"results: {len(results)} | time: {duration:.3f}s"
            )
            return results
        except Exception as e:
            duration = (datetime.now() - start_time).total_seconds()
            self.logger.error(
                f"Retrieval failed | query: {query_bundle.query_str[:30]}... | "
                f"time: {duration:.3f}s | error: {str(e)}"
            )
            raise

    def _update_stats(self, duration):
        """Update the running average retrieval time."""
        total = self.stats["total_retrievals"]
        current_avg = self.stats["avg_retrieval_time"]
        self.stats["avg_retrieval_time"] = (current_avg * (total - 1) + duration) / total

    def get_stats(self):
        """Return collected statistics."""
        return self.stats

# Configure logging
logging.basicConfig(level=logging.INFO)

# Using the monitored retriever
# monitored_retriever = MonitoredRetriever(base_retriever)
# results = monitored_retriever.retrieve("monitored retrieval example")
10. Advanced Capabilities
10.1 Semantic Retrieval Enhancement
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
import jieba

class SemanticEnhancedRetriever(BaseRetriever):
    """Semantically enhanced retriever."""

    def __init__(self, index, llm):
        self.index = index
        self.llm = llm
        self.semantic_processor = SemanticProcessor(llm)
        super().__init__()

    def _retrieve(self, query_bundle):
        """Semantically enhanced retrieval."""
        # 1. Understand the query intent
        query_intent = self.semantic_processor.extract_intent(query_bundle.query_str)
        # 2. Generate semantically related query variants
        semantic_queries = self.semantic_processor.generate_variants(
            query_bundle.query_str, query_intent
        )
        # 3. Retrieve for each semantic query
        all_results = []
        for semantic_query in semantic_queries:
            retriever = self.index.as_retriever(similarity_top_k=5)
            results = retriever.retrieve(semantic_query)
            all_results.extend(results)
        # 4. Semantic re-ranking
        reranked_results = self.semantic_processor.rerank(
            all_results, query_bundle.query_str
        )
        return reranked_results[:10]

class SemanticProcessor:
    """Semantic processor backed by an LLM."""

    def __init__(self, llm):
        self.llm = llm

    def extract_intent(self, query):
        """Extract the query intent."""
        prompt = f"""
Analyze the intent of the following query:
Query: {query}
Choose the best-matching intent from these categories:
- information_seeking
- problem_solving
- comparison
- instruction
- creative
Intent:
"""
        response = self.llm.complete(prompt)
        return response.text.strip()

    def generate_variants(self, query, intent):
        """Generate query variants."""
        prompt = f"""
Based on the query and intent below, generate 3 semantically related query variants:
Original query: {query}
Intent: {intent}
Requirements:
1. Preserve the core meaning of the original query
2. Use different wording
3. Consider different angles and levels of detail
Variant 1:
Variant 2:
Variant 3:
"""
        response = self.llm.complete(prompt)
        variants = [line.strip() for line in response.text.split('\n') if line.strip()]
        return [query] + variants  # include the original query

    def rerank(self, results, original_query):
        """Semantic re-ranking."""
        # Compute a semantic-relevance score per result
        scored_results = []
        for result in results:
            semantic_score = self._calculate_semantic_similarity(
                original_query, result.node.text
            )
            combined_score = (result.score + semantic_score) / 2
            scored_results.append(
                NodeWithScore(node=result.node, score=combined_score)
            )
        # Sort by the combined score
        scored_results.sort(key=lambda x: x.score, reverse=True)
        return scored_results

    def _calculate_semantic_similarity(self, query, text):
        """Compute semantic similarity."""
        # A dedicated semantic-similarity model could be used here;
        # simplified implementation: keyword overlap (jieba segments Chinese text)
        query_words = set(jieba.lcut(query.lower()))
        text_words = set(jieba.lcut(text.lower()))
        if not query_words:
            return 0.0
        overlap = len(query_words.intersection(text_words))
        return overlap / len(query_words)

# Using the semantically enhanced retriever
# semantic_retriever = SemanticEnhancedRetriever(index, llm)
# results = semantic_retriever.retrieve("applications of AI in healthcare")
10.2 Multimodal Retriever
class MultimodalRetriever(BaseRetriever):
    """Multimodal retriever."""

    def __init__(self, text_index, image_index=None, multimodal_llm=None):
        self.text_index = text_index
        self.image_index = image_index
        self.multimodal_llm = multimodal_llm
        super().__init__()

    def _retrieve(self, query_bundle):
        """Multimodal retrieval."""
        all_results = []
        # 1. Text retrieval
        text_retriever = self.text_index.as_retriever(similarity_top_k=10)
        text_results = text_retriever.retrieve(query_bundle.query_str)
        all_results.extend(text_results)
        # 2. Image retrieval (if an image index is available)
        if self.image_index:
            image_retriever = self.image_index.as_retriever(similarity_top_k=5)
            image_results = image_retriever.retrieve(query_bundle.query_str)
            all_results.extend(image_results)
        # 3. Multimodal fusion ranking
        if self.multimodal_llm:
            fused_results = self._multimodal_fusion(all_results, query_bundle.query_str)
            return fused_results[:15]
        else:
            # Simple merge
            all_results.sort(key=lambda x: x.score, reverse=True)
            return all_results[:15]

    def _multimodal_fusion(self, results, query):
        """Multimodal fusion."""
        # Re-rank results with a multimodal LLM
        # (requires a multimodal-capable model such as GPT-4V)
        fused_results = []
        for result in results:
            # Build a multimodal prompt
            if hasattr(result.node, 'image'):  # image node
                prompt = f"Query: {query}\nAssess the relevance of the image content:"
                # Real multimodal scoring would happen here
                multimodal_score = result.score  # simplified
            else:  # text node
                prompt = f"Query: {query}\nText: {result.node.text[:200]}\nRelevance score:"
                # Evaluate with the multimodal LLM
                multimodal_score = result.score  # simplified
            fused_results.append(
                NodeWithScore(node=result.node, score=multimodal_score)
            )
        # Sort
        fused_results.sort(key=lambda x: x.score, reverse=True)
        return fused_results

# Using the multimodal retriever (conceptual)
# multimodal_retriever = MultimodalRetriever(text_index, image_index)
# results = multimodal_retriever.retrieve("show the history of AI development")
Summary
As the core information-retrieval component of LlamaIndex, the Retriever plays a critical role in the RAG pipeline. In this article we examined how Retrievers work, their built-in types, their configuration options, and how to use them in real applications.
The main strengths of Retrievers include:
- Specialized design: focused on information retrieval, providing efficient search capability
- Index agnosticism: works with all kinds of indexes
- High configurability: rich parameter options for different needs
- Good extensibility: custom implementations for specific scenarios
- Performance optimization: multiple techniques to speed up retrieval
In practice, choose the Retriever type and configuration to match the scenario:
- Vector retrieval: use VectorIndexRetriever with VectorStoreIndex
- Keyword retrieval: use KeywordTableRetriever with KeywordTableIndex
- Complex retrieval needs: implement a custom Retriever or combine several strategies
- Enterprise applications: combine permission control, caching, and other enterprise-grade features
Used well, Retrievers let us build efficient, accurate retrieval systems that supply high-quality context for response generation. As retrieval technology advances, Retrievers will play a role in ever more domains as a core building block of intelligent applications.