原始方法
from llama_index.core.retrievers import QueryFusionRetriever
def _simple_fusion(
self, results: Dict[Tuple[str, int], List[NodeWithScore]]
) -> List[NodeWithScore]:
"""Apply simple fusion."""
# Use a dict to de-duplicate nodes
all_nodes: Dict[str, NodeWithScore] = {}
for nodes_with_scores in results.values():
for node_with_score in nodes_with_scores:
hash = node_with_score.node.hash
if hash in all_nodes:
max_score = max(node_with_score.score, all_nodes[hash].score)
all_nodes[hash].score = max_score
else:
all_nodes[hash] = node_with_score
return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
改进
原始的 _simple_fusion
方法存在一个问题,即在比较分数时,如果分数为 None
,会导致 TypeError
。为了改进这个方法,我们需要确保在比较分数时,None
值被处理为默认值(例如 0)。
以下是改进后的 _simple_fusion
方法:
def _simple_fusion(
self, results: Dict[Tuple[str, int], List[NodeWithScore]]
) -> List[NodeWithScore]:
"""Apply simple fusion."""
# Use a dict to de-duplicate nodes
all_nodes: Dict[str, NodeWithScore] = {}
for nodes_with_scores in results.values():
for node_with_score in nodes_with_scores:
hash = node_with_score.node.hash
if hash in all_nodes:
score1 = node_with_score.score if node_with_score.score is not None else 0
score2 = all_nodes[hash].score if all_nodes[hash].score is not None else 0
max_score = max(score1, score2)
all_nodes[hash].score = max_score
else:
all_nodes[hash] = node_with_score
return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
关键改进点
- 处理
None
分数:在比较分数时,确保None
值被处理为默认值(例如 0)。 - 排序:在返回结果之前,按分数降序排序。
详细解释
-
处理
None
分数:- 在比较分数时,使用
score1 = node_with_score.score if node_with_score.score is not None else 0
和score2 = all_nodes[hash].score if all_nodes[hash].score is not None else 0
确保None
值被处理为 0。 - 使用
max_score = max(score1, score2)
计算最大分数。
- 在比较分数时,使用
-
排序:
- 使用
sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
按分数降序排序结果。
- 使用
通过这些改进,你可以确保在比较分数时不会出现 TypeError
,并且结果按分数降序排序。希望这些建议能帮助你解决问题。