修改 QueryFusionRetriever _simple_fusion方法(解决bug)

原始方法
from llama_index.core.retrievers import QueryFusionRetriever

def _simple_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply simple fusion."""
        # Use a dict to de-duplicate nodes
        all_nodes: Dict[str, NodeWithScore] = {}
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                hash = node_with_score.node.hash
                if hash in all_nodes:
                    max_score = max(node_with_score.score, all_nodes[hash].score)
                    all_nodes[hash].score = max_score
                else:
                    all_nodes[hash] = node_with_score

        return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)

改进
原始的 _simple_fusion 方法存在一个问题,即在比较分数时,如果分数为 None,会导致 TypeError。为了改进这个方法,我们需要确保在比较分数时,None 值被处理为默认值(例如 0)。

以下是改进后的 _simple_fusion 方法:

def _simple_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply simple fusion."""
        # Use a dict to de-duplicate nodes
        all_nodes: Dict[str, NodeWithScore] = {}
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                hash = node_with_score.node.hash
                if hash in all_nodes:
                    score1 = node_with_score.score if node_with_score.score is not None else 0
                    score2 = all_nodes[hash].score if all_nodes[hash].score is not None else 0
                    max_score = max(score1, score2)
                    all_nodes[hash].score = max_score
                else:
                    all_nodes[hash] = node_with_score

        return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)

关键改进点

  1. 处理 None 分数:在比较分数时,确保 None 值被处理为默认值(例如 0)。
  2. 排序:在返回结果之前,按分数降序排序。

详细解释

  1. 处理 None 分数

    • 在比较分数时,使用 score1 = node_with_score.score if node_with_score.score is not None else 0score2 = all_nodes[hash].score if all_nodes[hash].score is not None else 0 确保 None 值被处理为 0。
    • 使用 max_score = max(score1, score2) 计算最大分数。
  2. 排序

    • 使用 sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True) 按分数降序排序结果。

通过这些改进,你可以确保在比较分数时不会出现 TypeError,并且结果按分数降序排序。希望这些建议能帮助你解决问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

需要重新演唱

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值