import asyncio
from enum import Enum
from typing import Dict, List, Optional, Tuple, cast
from llama_index.core.async_utils import run_async_tasks
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.llms.utils import LLMType, resolve_llm
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import IndexNode, NodeWithScore, QueryBundle
from llama_index.core.settings import Settings
# Default prompt used to ask the LLM for alternate search queries.
# Template variables: {num_queries} — how many queries to generate;
# {query} — the original user query. Output is expected one query per line.
QUERY_GEN_PROMPT = (
    "You are a helpful assistant that generates multiple search queries based on a "
    "single input query. Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)
class FUSION_MODES(str, Enum):
    """Enum for different fusion modes."""

    # NOTE(review): value is spelled "reciprocal_rerank", not "reciprocal_rank".
    # Presumably kept for backward compatibility with existing configs — do not
    # "fix" the string without checking persisted usages.
    RECIPROCAL_RANK = "reciprocal_rerank"
    # Score-based fusion modes; exact combination logic lives elsewhere in the
    # retriever implementation (not visible in this chunk).
    RELATIVE_SCORE = "relative_score"
    DIST_BASED_SCORE = "dist_based_score"
    # Simple concatenation/deduplication of results, no re-scoring.
    SIMPLE = "simple"
class QueryFusionRetriever(BaseRetriever):
def __init__(
    self,
    retrievers: List[BaseRetriever],
    llm: Optional[LLMType] = None,
    query_gen_prompt: Optional[str] = None,
    mode: FUSION_MODES = FUSION_MODES.SIMPLE,
    similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
    num_queries: int = 4,
    use_async: bool = True,
    verbose: bool = False,
    callback_manager: Optional[CallbackManager] = None,
    objects: Optional[List[IndexNode]] = None,
    object_map: Optional[dict] = None,
    retriever_weights: Optional[List[float]] = None,
) -> None:
    """Fuse results from multiple retrievers, optionally expanding the query.

    Args:
        retrievers: Non-empty list of retrievers to fan the query out to.
        llm: LLM used to generate query variations; falls back to the
            globally configured ``Settings.llm`` when not provided.
        query_gen_prompt: Prompt template for query generation; must accept
            ``{num_queries}`` and ``{query}``. Defaults to QUERY_GEN_PROMPT.
        mode: Fusion strategy used to combine per-retriever results.
        similarity_top_k: Number of fused results to return.
        num_queries: Total number of queries (including the original).
        use_async: Run sub-retrievals concurrently when True.
        verbose: Forwarded to the base retriever.
        callback_manager: Forwarded to the base retriever and the LLM.
        objects / object_map: Forwarded to the base retriever.
        retriever_weights: Optional per-retriever weights; must have one
            entry per retriever. Normalized to sum to 1.0. Defaults to
            equal weights.

    Raises:
        ValueError: If ``retrievers`` is empty, if ``retriever_weights``
            length mismatches ``retrievers``, or if the weights do not sum
            to a positive value.
    """
    # Fail fast with clear errors instead of the ZeroDivisionError the
    # weight normalization below would otherwise raise.
    if not retrievers:
        raise ValueError("`retrievers` must contain at least one retriever.")
    if retriever_weights is not None and len(retriever_weights) != len(retrievers):
        raise ValueError(
            f"`retriever_weights` must have one weight per retriever: got "
            f"{len(retriever_weights)} weights for {len(retrievers)} retrievers."
        )

    self.num_queries = num_queries
    self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT
    self.similarity_top_k = similarity_top_k
    self.mode = mode
    self.use_async = use_async
    self._retrievers = retrievers
    if retriever_weights is None:
        # Default: weight all retrievers equally.
        self._retriever_weights = [1.0 / len(retrievers)] * len(retrievers)
    else:
        # Normalize user-supplied weights so they sum to 1.0.
        total_weight = sum(retriever_weights)
        if total_weight <= 0:
            raise ValueError("`retriever_weights` must sum to a positive value.")
        self._retriever_weights = [w / total_weight for w in retriever_weights]
    # Resolve the provided LLM, or fall back to the global default.
    self._llm = (
        resolve_llm(llm, callback_manager=callback_manager) if llm else Settings.llm
    )
    super().__init__(
        callback_manager=callback_manager,
        object_map=object_map,
        objects=objects,
        verbose=verbose,
    )
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
return {