import asyncio
from enum import Enum
from typing import Dict, List, Optional, Tuple, cast
from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.llms.utils import LLMType, resolve_llm
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import IndexNode, NodeWithScore, QueryBundle
from llama_index.core.settings import Settings
# Default prompt used to ask the LLM for extra search queries when
# `num_queries > 1`.  Template placeholders:
#   {num_queries} — how many queries the LLM should produce
#   {query}       — the user's original query string
# The LLM is expected to answer with one query per line.
QUERY_GEN_PROMPT = (
    "You are a helpful assistant that generates multiple search queries based on a "
    "single input query. Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)
class FUSION_MODES(str, Enum):
    """Strategies for fusing result lists from multiple retrievers.

    Members:
        RECIPROCAL_RANK: reciprocal rank fusion.
        RELATIVE_SCORE: relative score fusion.
        DIST_BASED_SCORE: distance-based score fusion.
        SIMPLE: plain re-ordering of results by their original scores.
    """

    RECIPROCAL_RANK = "reciprocal_rerank"
    RELATIVE_SCORE = "relative_score"
    DIST_BASED_SCORE = "dist_based_score"
    SIMPLE = "simple"
class QueryFusionRetriever(BaseRetriever):
def __init__(
    self,
    retrievers: List[BaseRetriever],
    llm: Optional[LLMType] = None,
    query_gen_prompt: Optional[str] = None,
    mode: FUSION_MODES = FUSION_MODES.SIMPLE,
    similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
    num_queries: int = 4,
    use_async: bool = True,
    verbose: bool = False,
    callback_manager: Optional[CallbackManager] = None,
    objects: Optional[List[IndexNode]] = None,
    object_map: Optional[dict] = None,
    retriever_weights: Optional[List[float]] = None,
) -> None:
    """Initialize the fusion retriever.

    Args:
        retrievers: underlying retrievers whose results are fused.
        llm: LLM (or LLM spec) used to generate extra queries; falls back
            to ``Settings.llm`` when not given.
        query_gen_prompt: custom query-generation prompt; defaults to
            ``QUERY_GEN_PROMPT``.
        mode: fusion strategy (see ``FUSION_MODES``).
        similarity_top_k: number of fused results to return.
        num_queries: total number of queries used for retrieval.
        use_async: run per-retriever retrieval concurrently when True.
        verbose: forwarded to ``BaseRetriever``.
        callback_manager: forwarded to the LLM and ``BaseRetriever``.
        objects / object_map: forwarded to ``BaseRetriever``.
        retriever_weights: per-retriever weights; normalized to sum to 1,
            or uniform when omitted.
    """
    # Query-generation configuration.
    self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT
    self.num_queries = num_queries
    self.mode = mode
    self.similarity_top_k = similarity_top_k
    self.use_async = use_async

    self._retrievers = retrievers
    if retriever_weights is None:
        # No explicit weights: weight every retriever equally.
        uniform = 1.0 / len(retrievers)
        self._retriever_weights = [uniform for _ in retrievers]
    else:
        # Normalize the given weights so they sum to 1.
        total = sum(retriever_weights)
        self._retriever_weights = [weight / total for weight in retriever_weights]

    # Resolve the query-generation LLM, defaulting to the global Settings.
    if llm:
        self._llm = resolve_llm(llm, callback_manager=callback_manager)
    else:
        self._llm = Settings.llm

    super().__init__(
        callback_manager=callback_manager,
        object_map=object_map,
        objects=objects,
        verbose=verbose,
    )
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
return {
# NOTE(review): the original extraction is truncated here — the rest of
# `_get_prompts` (and the remainder of the class) was replaced by
# web-scrape residue (a CSDN article title and publication timestamp)
# and is not recoverable from this view.