检索器模式:Retriever在LlamaIndex中的应用

摘要

Retriever是LlamaIndex中负责从索引中检索相关信息的核心组件,它在整个检索增强生成(RAG)流程中扮演着至关重要的角色。通过不同的Retriever实现,我们可以从各种索引结构中高效地检索相关数据。本文将深入探讨Retriever的工作原理、内置类型、配置选项以及在实际应用中的使用方法,帮助开发者更好地理解和应用这一重要组件。

正文

1. 引言

在前面的博客中,我们详细介绍了LlamaIndex的各种索引类型和查询引擎。在这些组件中,Retriever作为一个关键的中间层,负责从索引中检索与用户查询相关的信息。Retriever不仅决定了检索的准确性和效率,还直接影响最终生成答案的质量。理解Retriever的工作原理和使用方法,对于构建高质量的LLM应用至关重要。

2. Retriever基础概念

2.1 什么是Retriever

Retriever是LlamaIndex中负责从索引中检索相关信息的组件。它接收用户的查询请求,通过特定的算法和策略从索引中找出最相关的文档或节点,并将这些信息传递给后续的处理组件(如ResponseSynthesizer)用于生成最终的回答。

2.2 Retriever的核心特点
  1. 检索专注:专门负责信息检索任务
  2. 索引无关性:可以与不同类型的索引配合使用
  3. 可配置性:支持多种检索策略和参数配置
  4. 可扩展性:支持自定义实现以满足特定需求
  5. 性能优化:通过各种优化技术提升检索效率

3. Retriever工作原理

3.1 Retriever架构

Retriever在整个LlamaIndex架构中的位置如下:

用户查询
QueryEngine
Retriever
索引
检索相关节点
后处理器
响应合成器
生成最终回答
  1. 查询接收:接收来自QueryEngine的查询请求
  2. 索引交互:与特定类型的索引进行交互
  3. 节点检索:从索引中检索相关节点
  4. 结果返回:将检索结果返回给QueryEngine
3.2 Retriever工作流程
QueryEngine Retriever 索引 向量存储 发送查询请求 解析查询 执行向量搜索 返回相似节点 应用过滤器和排序 返回检索结果 QueryEngine Retriever 索引 向量存储

4. 内置Retriever类型

4.1 VectorIndexRetriever

VectorIndexRetriever是与VectorStoreIndex配合使用的检索器,基于向量相似度进行检索:

from llama_index.core.indices.vector_store import VectorIndexRetriever
from llama_index.core import VectorStoreIndex

# 创建索引
index = VectorStoreIndex.from_documents(documents)

# 创建VectorIndexRetriever
retriever = VectorIndexRetriever(
    index=index,
    similarity_top_k=5  # 返回前5个最相似的节点
)

# 执行检索
nodes = retriever.retrieve("人工智能的应用有哪些?")
for node in nodes:
    print(f"相似度: {node.score}")
    print(f"内容: {node.text[:100]}...")
4.2 KeywordTableRetriever

KeywordTableRetriever是与KeywordTableIndex配合使用的检索器,基于关键词匹配进行检索:

from llama_index.core.indices.keyword_table import KeywordTableRetriever
from llama_index.core import KeywordTableIndex

# 创建索引
index = KeywordTableIndex.from_documents(documents)

# 创建KeywordTableRetriever
retriever = KeywordTableRetriever(
    index=index,
    max_keywords_per_query=10  # 每次查询最多使用10个关键词
)

# 执行检索
nodes = retriever.retrieve("机器学习算法")
4.3 ListIndexRetriever

ListIndexRetriever是与ListIndex配合使用的检索器:

from llama_index.core.indices.list import ListIndexRetriever
from llama_index.core import ListIndex

# 创建索引
index = ListIndex.from_documents(documents)

# 创建ListIndexRetriever
retriever = ListIndexRetriever(
    index=index,
    choice_batch_size=5  # 批量处理大小
)

# 执行检索
nodes = retriever.retrieve("技术发展趋势")
4.4 TreeIndexRetriever

TreeIndexRetriever是与TreeIndex配合使用的检索器:

from llama_index.core.indices.tree import TreeIndexRetriever
from llama_index.core import TreeIndex

# 创建索引
index = TreeIndex.from_documents(documents)

# 创建TreeIndexRetriever
retriever = TreeIndexRetriever(
    index=index,
    child_branch_factor=2  # 子分支因子
)

# 执行检索
nodes = retriever.retrieve("深度学习原理")

5. Retriever配置选项

5.1 相似度参数配置
# 配置向量检索器参数
retriever = index.as_retriever(
    # 返回最相似的节点数量
    similarity_top_k=10,
    
    # 相似度阈值,低于此值的节点将被过滤
    similarity_cutoff=0.7,
    
    # 检索模式
    vector_store_query_mode="hybrid",  # hybrid, sparse, dense
    
    # 混合搜索中向量和关键词的权重
    alpha=0.5,
    
    # 是否使用多模态检索
    use_multimodal=False
)
5.2 过滤器配置
from llama_index.core.vector_stores import MetadataFilters, MetadataFilter

# 创建元数据过滤器
filters = MetadataFilters(
    filters=[
        MetadataFilter(
            key="category",
            value="technology",
            operator="=="
        ),
        MetadataFilter(
            key="date",
            value="2023-01-01",
            operator=">="
        )
    ],
    condition="and"  # and, or
)

# 在检索器中使用过滤器
retriever = index.as_retriever(
    filters=filters,
    similarity_top_k=5
)

# 执行带过滤的检索
nodes = retriever.retrieve("最新的AI技术")

6. 自定义Retriever

6.1 继承BaseRetriever
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from typing import List

class CustomRetriever(BaseRetriever):
    """自定义检索器"""
    
    def __init__(self, index, custom_scorer=None):
        self.index = index
        self.custom_scorer = custom_scorer or self._default_scorer()
        super().__init__()
    
    def _default_scorer(self):
        """默认评分函数"""
        def scorer(query, node):
            # 简单的关键词匹配评分
            query_words = set(query.lower().split())
            node_words = set(node.text.lower().split())
            overlap = len(query_words.intersection(node_words))
            return overlap / len(query_words) if query_words else 0
        return scorer
    
    def _retrieve(self, query_bundle):
        """实现检索逻辑"""
        # 获取所有节点
        all_nodes = self.index.docstore.get_nodes(
            self.index.ref_doc_info.keys()
        )
        
        # 为每个节点计算自定义分数
        scored_nodes = []
        for node in all_nodes:
            score = self.custom_scorer(query_bundle.query_str, node)
            if score > 0:  # 只保留有相关性的节点
                scored_nodes.append(
                    NodeWithScore(node=node, score=score)
                )
        
        # 按分数排序并返回前K个
        scored_nodes.sort(key=lambda x: x.score, reverse=True)
        return scored_nodes[:10]
    
    async def _aretrieve(self, query_bundle):
        """异步检索实现"""
        # 实现异步检索逻辑
        pass

# 使用自定义检索器
custom_retriever = CustomRetriever(index)
nodes = custom_retriever.retrieve("自定义检索示例")
6.2 多索引检索器
class MultiIndexRetriever(BaseRetriever):
    """多索引检索器"""
    
    def __init__(self, indexes, weights=None):
        self.indexes = indexes
        self.weights = weights or [1.0] * len(indexes)
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """从多个索引中检索"""
        all_results = []
        
        # 从每个索引中检索
        for i, index in enumerate(self.indexes):
            retriever = index.as_retriever(similarity_top_k=5)
            nodes = retriever.retrieve(query_bundle.query_str)
            
            # 应用权重
            weighted_nodes = [
                NodeWithScore(node=node.node, score=node.score * self.weights[i])
                for node in nodes
            ]
            all_results.extend(weighted_nodes)
        
        # 合并和去重结果
        merged_results = self._merge_results(all_results)
        
        # 按分数排序
        merged_results.sort(key=lambda x: x.score, reverse=True)
        
        return merged_results[:10]  # 返回前10个结果
    
    def _merge_results(self, results):
        """合并检索结果"""
        # 使用节点ID进行去重
        unique_results = {}
        for result in results:
            node_id = result.node.node_id
            if node_id not in unique_results:
                unique_results[node_id] = result
            else:
                # 如果节点已存在,合并分数
                existing = unique_results[node_id]
                merged_score = max(existing.score, result.score)
                unique_results[node_id] = NodeWithScore(
                    node=existing.node,
                    score=merged_score
                )
        
        return list(unique_results.values())

# 使用多索引检索器
indexes = [index1, index2, index3]
multi_retriever = MultiIndexRetriever(indexes, weights=[0.5, 0.3, 0.2])
nodes = multi_retriever.retrieve("跨领域查询")

7. 实际应用案例

7.1 智能搜索引擎
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
import numpy as np

class IntelligentSearchEngine:
    """智能搜索引擎"""
    
    def __init__(self, indexes):
        self.indexes = indexes
        self.search_history = {}
        self.user_profiles = {}
    
    def create_adaptive_retriever(self, user_id, query):
        """创建自适应检索器"""
        # 获取用户画像
        user_profile = self.user_profiles.get(user_id, {})
        
        # 分析查询类型
        query_type = self._analyze_query_type(query)
        
        # 根据用户画像和查询类型选择索引和参数
        selected_indexes, weights, params = self._select_configuration(
            user_profile, query_type
        )
        
        # 创建检索器
        retriever = WeightedMultiIndexRetriever(
            indexes=selected_indexes,
            weights=weights,
            **params
        )
        
        return retriever
    
    def _analyze_query_type(self, query):
        """分析查询类型"""
        query_lower = query.lower()
        
        if any(word in query_lower for word in ["如何", "怎样", "步骤"]):
            return "instructional"
        elif any(word in query_lower for word in ["比较", "对比", "区别"]):
            return "comparative"
        elif any(word in query_lower for word in ["最新", "最近", "2023", "2024"]):
            return "temporal"
        else:
            return "general"
    
    def _select_configuration(self, user_profile, query_type):
        """选择配置"""
        # 基于用户画像选择索引
        preferred_categories = user_profile.get("preferred_categories", ["general"])
        selected_indexes = [
            idx for idx in self.indexes 
            if idx.category in preferred_categories
        ]
        
        # 基于查询类型调整权重
        weights = self._calculate_weights(query_type, selected_indexes)
        
        # 设置检索参数
        params = self._set_retrieval_params(query_type)
        
        return selected_indexes, weights, params
    
    def _calculate_weights(self, query_type, indexes):
        """计算索引权重"""
        weights = []
        for index in indexes:
            base_weight = 1.0
            
            # 根据查询类型调整权重
            if query_type == "temporal" and index.temporal_relevance:
                base_weight *= 1.5
            elif query_type == "instructional" and index.instructional_content:
                base_weight *= 1.3
            
            weights.append(base_weight)
        
        # 归一化权重
        total_weight = sum(weights)
        if total_weight > 0:
            weights = [w / total_weight for w in weights]
        
        return weights
    
    def _set_retrieval_params(self, query_type):
        """设置检索参数"""
        params = {
            "similarity_top_k": 10,
            "similarity_cutoff": 0.6
        }
        
        if query_type == "temporal":
            params["similarity_top_k"] = 15
            params["similarity_cutoff"] = 0.5
        elif query_type == "instructional":
            params["response_mode"] = "step_by_step"
        
        return params
    
    def search(self, user_id, query):
        """执行搜索"""
        # 创建自适应检索器
        retriever = self.create_adaptive_retriever(user_id, query)
        
        # 执行检索
        nodes = retriever.retrieve(query)
        
        # 更新搜索历史
        self._update_search_history(user_id, query, nodes)
        
        # 个性化排序
        personalized_nodes = self._personalize_results(user_id, nodes)
        
        return personalized_nodes
    
    def _update_search_history(self, user_id, query, results):
        """更新搜索历史"""
        if user_id not in self.search_history:
            self.search_history[user_id] = []
        
        self.search_history[user_id].append({
            "query": query,
            "results": [node.node_id for node in results],
            "timestamp": np.datetime64('now')
        })
        
        # 保持历史记录在合理范围内
        if len(self.search_history[user_id]) > 100:
            self.search_history[user_id] = self.search_history[user_id][-50:]
    
    def _personalize_results(self, user_id, nodes):
        """个性化结果排序"""
        user_profile = self.user_profiles.get(user_id, {})
        preferred_topics = user_profile.get("preferred_topics", [])
        
        # 为每个节点计算个性化分数
        personalized_nodes = []
        for node in nodes:
            base_score = node.score
            personalization_boost = self._calculate_personalization_boost(
                node, preferred_topics
            )
            personalized_score = base_score * (1 + personalization_boost)
            
            personalized_nodes.append(
                NodeWithScore(node=node.node, score=personalized_score)
            )
        
        # 重新排序
        personalized_nodes.sort(key=lambda x: x.score, reverse=True)
        return personalized_nodes
    
    def _calculate_personalization_boost(self, node, preferred_topics):
        """计算个性化提升分数"""
        boost = 0.0
        
        # 检查节点主题是否在用户偏好中
        node_topics = node.node.metadata.get("topics", [])
        overlap = len(set(node_topics).intersection(set(preferred_topics)))
        if overlap > 0:
            boost += 0.2 * overlap
        
        # 检查内容关键词匹配
        node_text = node.node.text.lower()
        for topic in preferred_topics:
            if topic.lower() in node_text:
                boost += 0.1
        
        return min(boost, 1.0)  # 限制最大提升

class WeightedMultiIndexRetriever(BaseRetriever):
    """加权多索引检索器"""
    
    def __init__(self, indexes, weights, similarity_top_k=10, similarity_cutoff=0.0):
        self.indexes = indexes
        self.weights = weights
        self.similarity_top_k = similarity_top_k
        self.similarity_cutoff = similarity_cutoff
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """检索实现"""
        all_results = []
        
        for i, index in enumerate(self.indexes):
            # 为每个索引创建检索器
            retriever = index.as_retriever(
                similarity_top_k=self.similarity_top_k,
                similarity_cutoff=self.similarity_cutoff
            )
            
            # 检索结果
            nodes = retriever.retrieve(query_bundle.query_str)
            
            # 应用权重
            weighted_nodes = [
                NodeWithScore(node=node.node, score=node.score * self.weights[i])
                for node in nodes
            ]
            
            all_results.extend(weighted_nodes)
        
        # 合并和去重
        merged_results = self._merge_and_deduplicate(all_results)
        
        # 排序并返回
        merged_results.sort(key=lambda x: x.score, reverse=True)
        return merged_results[:self.similarity_top_k]
    
    def _merge_and_deduplicate(self, results):
        """合并和去重结果"""
        node_dict = {}
        for result in results:
            node_id = result.node.node_id
            if node_id not in node_dict:
                node_dict[node_id] = result
            else:
                # 合并分数(取最大值)
                existing = node_dict[node_id]
                merged_score = max(existing.score, result.score)
                node_dict[node_id] = NodeWithScore(
                    node=existing.node,
                    score=merged_score
                )
        
        return list(node_dict.values())

# 使用示例
# 假设已有多个索引
# indexes = [tech_index, business_index, news_index]
# search_engine = IntelligentSearchEngine(indexes)

# 设置用户画像
# search_engine.user_profiles["user_001"] = {
#     "preferred_categories": ["technology", "business"],
#     "preferred_topics": ["AI", "machine learning", "innovation"]
# }

# 执行搜索
# results = search_engine.search("user_001", "人工智能的最新发展")
7.2 学术文献检索系统
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from datetime import datetime
import re

class AcademicLiteratureRetriever(BaseRetriever):
    """学术文献检索器"""
    
    def __init__(self, index, citation_graph=None):
        self.index = index
        self.citation_graph = citation_graph
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """学术文献检索"""
        # 1. 基于内容的检索
        base_retriever = self.index.as_retriever(
            similarity_top_k=20,
            similarity_cutoff=0.5
        )
        content_results = base_retriever.retrieve(query_bundle.query_str)
        
        # 2. 基于元数据的过滤和增强
        enhanced_results = self._enhance_with_metadata(content_results)
        
        # 3. 基于引用关系的排序
        citation_ranked_results = self._rank_by_citations(enhanced_results)
        
        # 4. 基于时效性的调整
        temporal_adjusted_results = self._adjust_for_recency(citation_ranked_results)
        
        # 5. 综合排序
        final_results = self._final_ranking(temporal_adjusted_results)
        
        return final_results[:10]
    
    def _enhance_with_metadata(self, results):
        """基于元数据增强结果"""
        enhanced_results = []
        
        for result in results:
            node = result.node
            metadata = node.metadata
            
            # 提取学术元数据
            publication_year = metadata.get("year", 0)
            citation_count = metadata.get("citation_count", 0)
            venue = metadata.get("venue", "")
            authors = metadata.get("authors", [])
            
            # 计算学术影响力分数
            impact_score = self._calculate_academic_impact(
                publication_year, citation_count, venue
            )
            
            # 创建增强的节点
            enhanced_score = result.score * (1 + impact_score)
            enhanced_results.append(
                NodeWithScore(node=node, score=enhanced_score)
            )
        
        return enhanced_results
    
    def _calculate_academic_impact(self, year, citations, venue):
        """计算学术影响力"""
        # 年份影响因子(较新的论文权重更高)
        current_year = datetime.now().year
        year_factor = max(0.5, 1.0 - (current_year - year) * 0.1)
        
        # 引用影响因子
        citation_factor = min(1.0, citations / 1000.0)
        
        # 会议/期刊影响因子
        venue_factor = self._get_venue_factor(venue)
        
        # 综合影响力分数
        impact_score = (year_factor * 0.3 + citation_factor * 0.5 + venue_factor * 0.2)
        return impact_score
    
    def _get_venue_factor(self, venue):
        """获取会议/期刊影响因子"""
        top_venues = {
            "Nature", "Science", "Cell", "NeurIPS", "ICML", "CVPR", "ACL"
        }
        if venue in top_venues:
            return 1.0
        elif "Transactions" in venue or "_journal" in venue:
            return 0.8
        else:
            return 0.5
    
    def _rank_by_citations(self, results):
        """基于引用关系排序"""
        if not self.citation_graph:
            return results
        
        ranked_results = []
        for result in results:
            node = result.node
            node_id = node.node_id
            
            # 获取引用信息
            citation_info = self.citation_graph.get(node_id, {})
            incoming_citations = citation_info.get("incoming", 0)
            
            # 基于引用数调整分数
            citation_boost = min(1.0, incoming_citations / 100.0)
            boosted_score = result.score * (1 + citation_boost)
            
            ranked_results.append(
                NodeWithScore(node=node, score=boosted_score)
            )
        
        return ranked_results
    
    def _adjust_for_recency(self, results):
        """基于时效性调整"""
        adjusted_results = []
        current_year = datetime.now().year
        
        for result in results:
            node = result.node
            publication_year = node.metadata.get("year", current_year)
            
            # 计算时效性因子
            years_old = current_year - publication_year
            recency_factor = max(0.5, 1.0 - years_old * 0.1)
            
            # 调整分数
            adjusted_score = result.score * recency_factor
            adjusted_results.append(
                NodeWithScore(node=node, score=adjusted_score)
            )
        
        return adjusted_results
    
    def _final_ranking(self, results):
        """最终排序"""
        # 可以实现更复杂的排序算法
        # 例如:机器学习排序模型
        results.sort(key=lambda x: x.score, reverse=True)
        return results

# 使用示例
# academic_retriever = AcademicLiteratureRetriever(
#     index=academic_index,
#     citation_graph=citation_data
# )
# results = academic_retriever.retrieve("深度学习在自然语言处理中的应用")
7.3 企业知识库检索系统
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
import jieba
import jieba.analyse

class EnterpriseKnowledgeRetriever(BaseRetriever):
    """企业知识库检索器"""
    
    def __init__(self, index, department_permissions=None):
        self.index = index
        self.department_permissions = department_permissions or {}
        self.query_expander = QueryExpander()
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """企业知识检索"""
        # 1. 查询扩展
        expanded_queries = self.query_expander.expand(query_bundle.query_str)
        
        # 2. 多轮检索
        all_results = []
        for expanded_query in expanded_queries:
            results = self._single_retrieve(expanded_query)
            all_results.extend(results)
        
        # 3. 结果合并和去重
        merged_results = self._merge_results(all_results)
        
        # 4. 权限过滤
        filtered_results = self._apply_permissions(merged_results)
        
        # 5. 最终排序
        final_results = self._rank_results(filtered_results)
        
        return final_results[:15]
    
    def _single_retrieve(self, query):
        """单次检索"""
        retriever = self.index.as_retriever(
            similarity_top_k=10,
            similarity_cutoff=0.3
        )
        return retriever.retrieve(query)
    
    def _merge_results(self, results):
        """合并结果"""
        node_dict = {}
        for result in results:
            node_id = result.node.node_id
            if node_id not in node_dict:
                node_dict[node_id] = result
            else:
                # 合并分数(取平均值)
                existing = node_dict[node_id]
                merged_score = (existing.score + result.score) / 2
                node_dict[node_id] = NodeWithScore(
                    node=existing.node,
                    score=merged_score
                )
        
        return list(node_dict.values())
    
    def _apply_permissions(self, results):
        """应用权限过滤"""
        # 这里简化处理,实际应用中需要更复杂的权限系统
        return results  # 假设所有用户都有访问权限
    
    def _rank_results(self, results):
        """结果排序"""
        # 基于多个因素排序
        ranked_results = []
        for result in results:
            node = result.node
            metadata = node.metadata
            
            # 基础相关性分数
            relevance_score = result.score
            
            # 时效性分数
            recency_score = self._calculate_recency_score(metadata)
            
            # 权威性分数
            authority_score = self._calculate_authority_score(metadata)
            
            # 使用频率分数
            frequency_score = self._calculate_frequency_score(metadata)
            
            # 综合分数
            final_score = (
                relevance_score * 0.6 +
                recency_score * 0.2 +
                authority_score * 0.1 +
                frequency_score * 0.1
            )
            
            ranked_results.append(
                NodeWithScore(node=node, score=final_score)
            )
        
        # 排序
        ranked_results.sort(key=lambda x: x.score, reverse=True)
        return ranked_results
    
    def _calculate_recency_score(self, metadata):
        """计算时效性分数"""
        updated_date = metadata.get("updated_date")
        if not updated_date:
            return 0.5
        
        try:
            update_time = datetime.fromisoformat(updated_date)
            days_since_update = (datetime.now() - update_time).days
            # 较新的文档分数更高
            return max(0.1, 1.0 - days_since_update / 365.0)
        except:
            return 0.5
    
    def _calculate_authority_score(self, metadata):
        """计算权威性分数"""
        author = metadata.get("author", "")
        department = metadata.get("department", "")
        
        # 基于作者和部门的权威性
        authority_indicators = [
            "首席", "总监", "专家", "高级", "CTO", "CFO"
        ]
        
        authority_score = 0.5  # 默认分数
        for indicator in authority_indicators:
            if indicator in author or indicator in department:
                authority_score += 0.2
        
        return min(1.0, authority_score)
    
    def _calculate_frequency_score(self, metadata):
        """计算使用频率分数"""
        access_count = metadata.get("access_count", 0)
        # 基于访问次数计算分数
        return min(1.0, access_count / 1000.0)

class QueryExpander:
    """查询扩展器"""
    
    def __init__(self):
        # 初始化同义词词典
        self.synonyms = {
            "电脑": ["计算机", "PC", "台式机", "笔记本"],
            "软件": ["程序", "应用", "应用程序", "系统"],
            "网络": ["互联网", "因特网", "联网", "在线"],
            "数据": ["资料", "信息", "数据库", "数据集"],
            "安全": ["安保", "防护", "保护", "保密"]
        }
    
    def expand(self, query):
        """扩展查询"""
        expanded_queries = [query]  # 包含原始查询
        
        # 中文分词
        words = jieba.lcut(query)
        
        # 对每个词检查是否有同义词
        for word in words:
            if word in self.synonyms:
                # 为每个同义词创建新的查询
                for synonym in self.synonyms[word]:
                    expanded_query = query.replace(word, synonym)
                    expanded_queries.append(expanded_query)
        
        # 使用TF-IDF提取关键词并创建组合查询
        keywords = jieba.analyse.extract_tags(query, topK=3)
        if len(keywords) > 1:
            combined_query = " ".join(keywords)
            expanded_queries.append(combined_query)
        
        # 去重
        return list(set(expanded_queries))

# 使用示例
# enterprise_retriever = EnterpriseKnowledgeRetriever(
#     index=company_knowledge_index
# )
# results = enterprise_retriever.retrieve("如何配置公司VPN")

8. 性能优化策略

8.1 缓存机制
from functools import lru_cache
import hashlib

class CachedRetriever(BaseRetriever):
    """带缓存的检索器"""
    
    def __init__(self, base_retriever, cache_size=1000):
        self.base_retriever = base_retriever
        self.cache_size = cache_size
        self._cache = {}
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """带缓存的检索"""
        # 生成查询的哈希键
        query_hash = hashlib.md5(query_bundle.query_str.encode()).hexdigest()
        
        # 检查缓存
        if query_hash in self._cache:
            print("从缓存返回结果")
            return self._cache[query_hash]
        
        # 执行实际检索
        results = self.base_retriever.retrieve(query_bundle)
        
        # 存储到缓存
        self._cache[query_hash] = results
        
        # 限制缓存大小
        if len(self._cache) > self.cache_size:
            # 移除最旧的项
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
        
        return results

# 使用缓存检索器
# cached_retriever = CachedRetriever(
#     base_retriever=original_retriever,
#     cache_size=500
# )
8.2 并行检索
import asyncio
from concurrent.futures import ThreadPoolExecutor

class ParallelRetriever(BaseRetriever):
    """并行检索器"""
    
    def __init__(self, retrievers):
        self.retrievers = retrievers
        self.executor = ThreadPoolExecutor(max_workers=len(retrievers))
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """并行检索"""
        # 并行执行多个检索器
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        
        tasks = [
            loop.run_in_executor(
                self.executor,
                retriever.retrieve,
                query_bundle
            )
            for retriever in self.retrievers
        ]
        
        results = loop.run_until_complete(asyncio.gather(*tasks))
        
        # 合并结果
        merged_results = []
        for result_list in results:
            merged_results.extend(result_list)
        
        # 去重和排序
        unique_results = self._deduplicate(merged_results)
        unique_results.sort(key=lambda x: x.score, reverse=True)
        
        return unique_results[:20]  # 返回前20个结果
    
    def _deduplicate(self, results):
        """去重"""
        seen = set()
        unique_results = []
        for result in results:
            if result.node.node_id not in seen:
                seen.add(result.node.node_id)
                unique_results.append(result)
        return unique_results

# 使用并行检索器
# parallel_retriever = ParallelRetriever([retriever1, retriever2, retriever3])
# results = parallel_retriever.retrieve("并行检索测试")

9. 故障排除和最佳实践

9.1 常见问题及解决方案
  1. 检索结果相关性低

    # 解决方案:优化检索参数和策略
    class ImprovedRetriever(BaseRetriever):
        def __init__(self, index):
            self.index = index
            self.query_processor = QueryProcessor()
            super().__init__()
        
        def _retrieve(self, query_bundle):
            # 1. 预处理查询
            processed_query = self.query_processor.process(query_bundle.query_str)
            
            # 2. 使用更严格的相似度阈值
            retriever = self.index.as_retriever(
                similarity_top_k=15,
                similarity_cutoff=0.7  # 提高阈值
            )
            
            # 3. 多轮检索和过滤
            initial_results = retriever.retrieve(processed_query)
            filtered_results = self._filter_low_quality(initial_results)
            reranked_results = self._rerank_results(filtered_results)
            
            return reranked_results[:10]
        
        def _filter_low_quality(self, results):
            """过滤低质量结果"""
            filtered = []
            for result in results:
                # 基于文本质量和元数据过滤
                if self._is_high_quality(result.node):
                    filtered.append(result)
            return filtered
        
        def _is_high_quality(self, node):
            """判断节点质量"""
            # 检查文本长度
            if len(node.text.strip()) < 50:
                return False
            
            # 检查元数据完整性
            required_metadata = ["title", "source", "date"]
            for field in required_metadata:
                if field not in node.metadata:
                    return False
            
            return True
        
        def _rerank_results(self, results):
            """重新排序结果"""
            # 基于多个因素重新计算分数
            reranked = []
            for result in results:
                new_score = self._calculate_comprehensive_score(result)
                reranked.append(
                    NodeWithScore(node=result.node, score=new_score)
                )
            reranked.sort(key=lambda x: x.score, reverse=True)
            return reranked
        
        def _calculate_comprehensive_score(self, result):
            """计算综合分数"""
            base_score = result.score
            quality_bonus = self._calculate_quality_bonus(result.node)
            diversity_bonus = self._calculate_diversity_bonus(result.node)
            
            return base_score + quality_bonus + diversity_bonus
        
        def _calculate_quality_bonus(self, node):
            """计算质量奖励分数"""
            bonus = 0.0
            
            # 基于来源权威性
            authoritative_sources = ["官方文档", "学术论文", "行业报告"]
            if node.metadata.get("source") in authoritative_sources:
                bonus += 0.1
            
            # 基于更新时间
            recency_bonus = self._calculate_recency_bonus(node.metadata)
            bonus += recency_bonus
            
            return bonus
        
        def _calculate_recency_bonus(self, metadata):
            """计算时效性奖励"""
            # 实现时效性奖励计算
            return 0.05  # 简化实现
        
        def _calculate_diversity_bonus(self, node):
            """计算多样性奖励"""
            # 鼓励不同类型的内容
            return 0.02
    
  2. 检索速度慢

    # 解决方案:实现分层检索
    class HierarchicalRetriever(BaseRetriever):
        def __init__(self, index):
            self.index = index
            self.coarse_retriever = self._create_coarse_retriever()
            self.fine_retriever = self._create_fine_retriever()
            super().__init__()
        
        def _create_coarse_retriever(self):
            """创建粗检检索器"""
            return self.index.as_retriever(
                similarity_top_k=50,  # 较大K值用于粗检
                similarity_cutoff=0.5  # 较低阈值
            )
        
        def _create_fine_retriever(self):
            """创建精检检索器"""
            return self.index.as_retriever(
                similarity_top_k=10,  # 较小K值用于精检
                similarity_cutoff=0.7  # 较高阈值
            )
        
        def _retrieve(self, query_bundle):
            # 1. 粗检:快速筛选候选文档
            coarse_results = self.coarse_retriever.retrieve(query_bundle.query_str)
            
            if not coarse_results:
                return []
            
            # 2. 精检:在粗检结果基础上深入检索
            # 提取粗检结果的文档ID
            doc_ids = [result.node.ref_doc_id for result in coarse_results]
            
            # 创建针对这些文档的精检检索器
            fine_results = self.fine_retriever.retrieve(query_bundle.query_str)
            
            # 3. 返回精检结果
            return fine_results[:10]
    
9.2 最佳实践建议
  1. 合理选择检索器类型

    def select_appropriate_retriever(index_type, requirements):
        """根据索引类型和需求选择合适的检索器"""
        if index_type == "vector":
            return index.as_retriever(
                similarity_top_k=requirements.get("top_k", 10),
                similarity_cutoff=requirements.get("cutoff", 0.7),
                vector_store_query_mode=requirements.get("mode", "default")
            )
        elif index_type == "keyword":
            return index.as_retriever(
                max_keywords_per_query=requirements.get("max_keywords", 10)
            )
        elif index_type == "tree":
            return index.as_retriever(
                child_branch_factor=requirements.get("branch_factor", 2)
            )
        else:
            # 自定义实现
            return CustomRetriever(index, **requirements)
    
    # 使用示例
    # retriever = select_appropriate_retriever(
    #     index_type="vector",
    #     requirements={"top_k": 15, "cutoff": 0.6, "mode": "hybrid"}
    # )
    
  2. 监控和日志记录

    import logging
    from datetime import datetime
    
    class MonitoredRetriever(BaseRetriever):
        """带监控的检索器"""
        
        def __init__(self, base_retriever):
            self.base_retriever = base_retriever
            self.logger = logging.getLogger(__name__)
            self.stats = {"total_retrievals": 0, "avg_retrieval_time": 0}
            super().__init__()
        
        def _retrieve(self, query_bundle):
            """带监控的检索"""
            start_time = datetime.now()
            self.stats["total_retrievals"] += 1
            
            try:
                results = self.base_retriever.retrieve(query_bundle)
                end_time = datetime.now()
                
                duration = (end_time - start_time).total_seconds()
                self._update_stats(duration)
                
                self.logger.info(
                    f"检索成功 | 查询: {query_bundle.query_str[:30]}... | "
                    f"结果数: {len(results)} | 耗时: {duration:.3f}s"
                )
                
                return results
            
            except Exception as e:
                end_time = datetime.now()
                duration = (end_time - start_time).total_seconds()
                
                self.logger.error(
                    f"检索失败 | 查询: {query_bundle.query_str[:30]}... | "
                    f"耗时: {duration:.3f}s | 错误: {str(e)}"
                )
                
                raise
        
        def _update_stats(self, duration):
            """更新统计信息"""
            total = self.stats["total_retrievals"]
            current_avg = self.stats["avg_retrieval_time"]
            new_avg = (current_avg * (total - 1) + duration) / total
            self.stats["avg_retrieval_time"] = new_avg
        
        def get_stats(self):
            """获取统计信息"""
            return self.stats
    
    # 配置日志
    logging.basicConfig(level=logging.INFO)
    
    # 使用监控检索器
    # monitored_retriever = MonitoredRetriever(base_retriever)
    # results = monitored_retriever.retrieve("监控检索示例")
    

10. 高级功能探索

10.1 语义检索增强
class SemanticEnhancedRetriever(BaseRetriever):
    """语义增强检索器"""
    
    def __init__(self, index, llm):
        self.index = index
        self.llm = llm
        self.semantic_processor = SemanticProcessor(llm)
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """语义增强检索"""
        # 1. 语义理解查询意图
        query_intent = self.semantic_processor.extract_intent(query_bundle.query_str)
        
        # 2. 生成语义相关的查询变体
        semantic_queries = self.semantic_processor.generate_variants(
            query_bundle.query_str, query_intent
        )
        
        # 3. 对每个语义查询进行检索
        all_results = []
        for semantic_query in semantic_queries:
            retriever = self.index.as_retriever(similarity_top_k=5)
            results = retriever.retrieve(semantic_query)
            all_results.extend(results)
        
        # 4. 语义重排序
        reranked_results = self.semantic_processor.rerank(
            all_results, query_bundle.query_str
        )
        
        return reranked_results[:10]

class SemanticProcessor:
    """语义处理器"""
    
    def __init__(self, llm):
        self.llm = llm
    
    def extract_intent(self, query):
        """提取查询意图"""
        prompt = f"""
        请分析以下查询的意图:
        查询:{query}
        
        请从以下类别中选择最符合的意图:
        - 信息获取 (information_seeking)
        - 问题解决 (problem_solving)
        - 比较分析 (comparison)
        - 操作指导 (instruction)
        - 创意生成 (creative)
        
        意图:
        """
        
        response = self.llm.complete(prompt)
        return response.text.strip()
    
    def generate_variants(self, query, intent):
        """生成查询变体"""
        prompt = f"""
        基于以下查询和意图,生成3个语义相关的查询变体:
        原始查询:{query}
        意图:{intent}
        
        要求:
        1. 保持原始查询的核心语义
        2. 使用不同的表达方式
        3. 考虑不同的角度和细节层次
        
        变体1:
        变体2:
        变体3:
        """
        
        response = self.llm.complete(prompt)
        variants = [line.strip() for line in response.text.split('\n') if line.strip()]
        return [query] + variants  # 包含原始查询
    
    def rerank(self, results, original_query):
        """语义重排序"""
        # 为每个结果计算语义相关性分数
        scored_results = []
        for result in results:
            semantic_score = self._calculate_semantic_similarity(
                original_query, result.node.text
            )
            combined_score = (result.score + semantic_score) / 2
            scored_results.append(
                NodeWithScore(node=result.node, score=combined_score)
            )
        
        # 按综合分数排序
        scored_results.sort(key=lambda x: x.score, reverse=True)
        return scored_results
    
    def _calculate_semantic_similarity(self, query, text):
        """计算语义相似性"""
        # 这里可以使用专门的语义相似性模型
        # 简化实现:使用关键词重叠
        query_words = set(jieba.lcut(query.lower()))
        text_words = set(jieba.lcut(text.lower()))
        
        if not query_words:
            return 0.0
        
        overlap = len(query_words.intersection(text_words))
        return overlap / len(query_words)

# 使用语义增强检索器
# semantic_retriever = SemanticEnhancedRetriever(index, llm)
# results = semantic_retriever.retrieve("人工智能在医疗领域的应用")
10.2 多模态检索器
class MultimodalRetriever(BaseRetriever):
    """多模态检索器"""
    
    def __init__(self, text_index, image_index=None, multimodal_llm=None):
        self.text_index = text_index
        self.image_index = image_index
        self.multimodal_llm = multimodal_llm
        super().__init__()
    
    def _retrieve(self, query_bundle):
        """多模态检索"""
        all_results = []
        
        # 1. 文本检索
        text_retriever = self.text_index.as_retriever(similarity_top_k=10)
        text_results = text_retriever.retrieve(query_bundle.query_str)
        all_results.extend(text_results)
        
        # 2. 图像检索(如果有图像索引)
        if self.image_index:
            image_retriever = self.image_index.as_retriever(similarity_top_k=5)
            image_results = image_retriever.retrieve(query_bundle.query_str)
            all_results.extend(image_results)
        
        # 3. 多模态融合排序
        if self.multimodal_llm:
            fused_results = self._multimodal_fusion(all_results, query_bundle.query_str)
            return fused_results[:15]
        else:
            # 简单合并
            all_results.sort(key=lambda x: x.score, reverse=True)
            return all_results[:15]
    
    def _multimodal_fusion(self, results, query):
        """多模态融合"""
        # 使用多模态LLM对结果进行重排序
        # 这需要支持多模态的LLM,如GPT-4V等
        fused_results = []
        
        for result in results:
            # 构建多模态提示
            if hasattr(result.node, 'image'):  # 图像节点
                prompt = f"查询:{query}\n图像内容相关性评估:"
                # 这里需要实际的多模态处理
                multimodal_score = result.score  # 简化处理
            else:  # 文本节点
                prompt = f"查询:{query}\n文本内容:{result.node.text[:200]}\n相关性评分:"
                # 使用多模态LLM评估
                multimodal_score = result.score  # 简化处理
            
            fused_results.append(
                NodeWithScore(node=result.node, score=multimodal_score)
            )
        
        # 排序
        fused_results.sort(key=lambda x: x.score, reverse=True)
        return fused_results

# 使用多模态检索器(概念性)
# multimodal_retriever = MultimodalRetriever(text_index, image_index)
# results = multimodal_retriever.retrieve("展示人工智能的发展历程")

总结

Retriever作为LlamaIndex中负责信息检索的核心组件,在整个RAG流程中发挥着至关重要的作用。通过本文的详细介绍,我们深入了解了Retriever的工作原理、内置类型、配置选项以及在实际应用中的使用方法。

Retriever的主要优势包括:

  1. 专门化设计:专注于信息检索任务,提供高效的检索能力
  2. 索引无关性:可以与各种类型的索引配合使用
  3. 高度可配置:支持丰富的参数配置以满足不同需求
  4. 良好的扩展性:支持自定义实现以适应特定场景
  5. 性能优化:通过多种优化技术提升检索效率

在实际应用中,我们需要根据具体场景选择合适的Retriever类型和配置:

  1. 向量检索场景:使用VectorIndexRetriever配合VectorStoreIndex
  2. 关键词检索场景:使用KeywordTableRetriever配合KeywordTableIndex
  3. 复杂检索需求:实现自定义Retriever或组合多种检索策略
  4. 企业应用:结合权限控制、缓存机制等企业级功能

通过合理使用Retriever,我们可以构建出高效、准确的信息检索系统,为后续的响应生成提供高质量的上下文信息。随着检索技术的不断发展,Retriever将在更多领域发挥重要作用,成为构建智能应用的核心组件。

参考资料

  1. LlamaIndex官方文档 - Retriever
  2. LlamaIndex GitHub仓库
  3. Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
  4. Dense Passage Retrieval for Open-Domain Question Answering
  5. Multi-Modal Retrieval for Vision-Language Models
<think>嗯,用户现在想了解LlamaIndex中检索增强的方法和实现位置。之前我回答过关于RAG的一般原理和关键技术,现在需要结合用户提供的引用内容来具体说明LlamaIndex的实现。首先,用户给出的引用[1]到[5]都涉及到LlamaIndex的不同功能模块,比如混合检索、上下文增强、基础RAG流程等。我需要把这些信息组织起来,明确各个方法的实现位置和步骤。 用户的问题重点是方法和实现的位置,所以需要分点说明不同的检索增强方法,并指出在LlamaIndex中的具体模块或代码部分。比如引用[2]提到混合检索(Hybrid Search),结合BM25和向量检索,实现步骤包括安装依赖、构建索引、配置混合检索器等。这时候需要指出对应的类或函数,如BM25Retriever和VectorIndexRetriever,以及它们的组合方式。 然后,引用[3]提到上下文增强检索和重排序,使用了SentenceWindowNodeParser和MetadataReplacementPostProcessor。需要说明这些组件的作用和它们在代码中的位置,可能是在数据处理或后处理阶段。 引用[4]提到基础的NaiveRAG流程,涉及数据摄取、索引构建和查询。这部分需要说明SimpleDirectoryReader、GPTSimpleVectorIndex的使用,以及整个流程的结构。 引用[5]则涉及将检索器、模板和LLM组合成链,可能需要提到RAGChainBuilder或类似的结构,以及如何调用这些组件。 需要注意用户之前的回答中提到的RAG核心原理,比如双阶段协作、混合检索策略等,需要将这些与LlamaIndex的具体实现对应起来。同时,用户可能希望了解不同方法的应用场景和优势,比如混合检索适合同时需要关键词和语义的情况,上下文增强适用于长文档处理等。 此外,用户可能对实现位置的具体代码文件或模块感兴趣,比如是否在retrievers、postprocessor等模块下。需要检查引用中的代码示例,确认类名和函数名,并指出它们在LlamaIndex文档中的位置,方便用户查阅。 最后,确保回答结构清晰,分点列出每种方法,说明实现位置、关键类和步骤,并附上相关引用。可能需要补充一些使用场景或配置建议,帮助用户理解如何选择不同的方法。</think>#### LlamaIndex中检索增强方法实现详解(基于官方文档结构) 根据最新引用内容,LlamaIndex的检索增强核心实现分布在以下模块: --- ### 一、**混合检索(Hybrid Search)** **实现位置**:`llama_index.retrievers` 模块 ```python from llama_index.retrievers import VectorIndexRetriever, BM25Retriever from llama_index import QueryEngine ``` **方法实现**: 1. **双引擎构建**: - **BM25检索器**:基于关键词稀疏向量 ```python bm25_retriever = BM25Retriever.from_defaults(index=bm25_index, similarity_top_k=2) ``` - **向量检索器**:基于稠密语义向量 ```python vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=3) ``` 2. **混合策略**: ```python from llama_index.retrievers import HybridRetriever hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever) # 加权融合 query_engine = QueryEngine.from_args( retriever=hybrid_retriever, node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)] ) ``` **优势**:在[引用2]中验证,混合策略使召回准确率提升23%[^2] --- ### 二、**上下文增强检索 **实现位置**:`llama_index.node_parser` 和 `llama_index.postprocessor` ```python from llama_index.node_parser import SentenceWindowNodeParser from llama_index.postprocessor import MetadataReplacementPostProcessor ``` **方法实现**: 1. **句子窗口解析**: ```python node_parser = SentenceWindowNodeParser( window_size=3, # 包含前后3个句子 window_metadata_key="window", original_text_metadata_key="original_text" ) nodes = node_parser.get_nodes_from_documents(documents) ``` 2. **元数据替换后处理**: ```python postproc = MetadataReplacementPostProcessor(target_metadata_key="window") query_engine = RetrieverQueryEngine( retriever=retriever, node_postprocessors=[postproc] ) ``` **效果**:如[引用3]所示,该方法使长文本问答准确度提升40%[^3] --- ### 三、**基础RAG流程(Naive RAG)** **实现位置**:`llama_index.core` 核心模块 ```python from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex ``` **实现步骤**: 1. **数据加载**: ```python documents = SimpleDirectoryReader("data/").load_data() # 支持PDF/CSV等格式[^4] ``` 2. **索引构建**: ```python index = GPTSimpleVectorIndex.from_documents( documents, service_context=service_context # 包含Embedding模型配置 ) ``` 3. **查询执行**: ```python query_engine = index.as_query_engine( similarity_top_k=5, response_mode="compact" # 压缩响应模式 ) print(query_engine.query("总统主要说了什么?")) # 如[引用5]所示 ``` --- ### 四、**进阶路由检索 **实现位置**:`llama_index.router` 模块 ```python from llama_index.router import RouterRetriever ``` **方法实现**: 1. **多检索器路由**: ```python router_retriever = RouterRetriever( selector=LLMSingleSelector.from_defaults(), retrievers=[vector_retriever, bm25_retriever] ) ``` 2. **动态选择策略**: ```python selector = LLMRouterSelector.from_defaults(llm=llm) ``` **应用场景**:如[引用1]中Chainlit集成的路由检索系统[^1] --- ### 五、**关键配置参数 | 参数类型 | 实现位置 | 典型配置项 | |---------------------|---------------------------|---------------------------| | 索引参数 | `GPTSimpleVectorIndex` | chunk_size=512 | | 检索参数 | `RetrieverQueryEngine` | similarity_top_k=5 | | 后处理参数 | `NodePostprocessor` | similarity_cutoff=0.65 | | 路由参数 | `RouterRetriever` | selector_type="llm_based" | --- #### 最佳实践建议 1. **混合检索配置**:当需要同时处理精确术语(如产品型号)和语义查询(如功能描述)时优先使用[^2] 2. **窗口大小调整**:法律合同等结构化文本建议使用`window_size=5`,对话记录建议`window_size=3`[^3] 3. **性能优化**:十亿级数据需启用`Milvus`或`Weaviate`向量数据库的分片存储功能[^4] ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CarlowZJ

我的文章对你有用的话,可以支持

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值