Adaptive RAG
自适应 RAG(Adaptive RAG)
概述
自适应 RAG是一种RAG策略,结合了(1)查询分析与(2)主动/自我纠正的RAG方法。在论文中,他们报告了通过查询分析来路由到以下几种策略:
- 不进行检索(No Retrieval)
- 一次性 RAG(Single-shot RAG)
- 迭代 RAG(Iterative RAG)
我们将使用LangGraph基于这些概念进行扩展和实现。在我们的实现中,我们将路由到以下两种策略:
- 网络搜索(Web search):用于与近期事件相关的问题。
- 自我纠正 RAG(Self-corrective RAG):用于与我们索引相关的问题。
系统架构图
系统的图形化表示如下所示:
设置环境(Setup)
首先,下载所需的包并设置必要的API密钥。
1. 安装必要的包
在Jupyter Notebook或终端中运行以下命令安装所需的包:
pip install -U langchain_community tiktoken langchain-openai langchain-cohere langchainhub chromadb langchain langgraph tavily-python
2. 设置API密钥
接下来,设置OpenAI、Cohere和Tavily的API密钥。以下代码将提示您输入API密钥并将其存储在环境变量中:
import getpass
import os
def _set_env(var: str):
if var not in os.environ:
os.environ[var] = getpass.getpass(f"{var}: ")
_set_env("OPENAI_API_KEY")
_set_env("COHERE_API_KEY")
_set_env("TAVILY_API_KEY")
3. 设置LangSmith用于LangGraph开发
LangSmith是一个用于调试、测试和监控LangGraph项目的工具。通过注册LangSmith,您可以使用跟踪数据来优化LangGraph应用程序的性能。详细的注册和使用方法请参考LangSmith入门指南。
创建索引(Create Index)
1. 构建索引
我们首先需要构建一个文档索引,以便后续的检索和生成过程。
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
# 设置嵌入模型
embd = OpenAIEmbeddings()
# 要索引的文档URL
urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
# 使用WebBaseLoader加载文档
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# 使用RecursiveCharacterTextSplitter拆分文档
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=500, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
# 将拆分后的文档添加到向量存储中
vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name="rag-chroma",
embedding=embd,
)
retriever = vectorstore.as_retriever()
解释:
- WebBaseLoader:从指定的URL递归加载网页内容。
- RecursiveCharacterTextSplitter:将长文档拆分成较小的块,以便LLM更高效地处理。
- Chroma:使用向量存储(vectorstore)管理文档的嵌入向量,并提供高效的相似度检索。
- retriever:将向量存储作为检索器,供LLM调用以获取相关文档。
LLMs配置
使用Pydantic与LangChain
此部分使用Pydantic v2的BaseModel
,需要langchain-core >= 0.3
。使用langchain-core < 0.3
将导致因混合使用Pydantic v1和v2而出错。
1. 路由器(Router)
路由器负责分析用户查询,并决定将查询路由到哪种检索策略(网络搜索或自我纠正RAG)。
from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
# 数据模型
class RouteQuery(BaseModel):
"""将用户查询路由到最相关的数据源。"""
datasource: Literal["vectorstore", "web_search"] = Field(
...,
description="根据用户问题选择将其路由到web search或vectorstore。",
)
# 初始化LLM并绑定结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_router = llm.with_structured_output(RouteQuery)
# 定义系统消息模板
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. Otherwise, use web-search."""
route_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{question}"),
]
)
# 构建路由链
question_router = route_prompt | structured_llm_router
# 示例调用
print(
question_router.invoke(
{"question": "Who will the Bears draft first in the NFL draft?"}
)
)
print(question_router.invoke({"question": "What are the types of agent memory?"}))
输出示例:
datasource='web_search'
datasource='vectorstore'
解释:
- RouteQuery:定义了路由器的输出结构,包括
datasource
字段,指示将查询路由到vectorstore
还是web_search
。 - ChatPromptTemplate:定义了与LLM交互的模板,包括系统消息和用户消息的占位符。
- question_router:结合了提示模板和LLM的路由链。
- 示例调用:根据用户问题,LLM决定将查询路由到
web_search
或vectorstore
。
2. 检索评分器(Retrieval Grader)
检索评分器用于评估检索到的文档是否与用户问题相关。
# 数据模型
class GradeDocuments(BaseModel):
"""评估检索到的文档相关性的二进制评分。"""
binary_score: str = Field(
description="文档是否与问题相关,'yes'或'no'"
)
# 初始化LLM并绑定结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# 定义系统消息模板
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
]
)
# 构建评分链
retrieval_grader = grade_prompt | structured_llm_grader
# 示例调用
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
binary_score='no'
输出示例:
GradeDocuments(binary_score='no')
解释:
- GradeDocuments:定义了评分器的输出结构,包括
binary_score
字段,值为"yes"
或"no"
。 - grade_prompt:定义了用于评估文档相关性的提示模板。
- retrieval_grader:结合了提示模板和LLM的评分链。
- 示例调用:评估特定文档是否与用户问题相关。
3. 生成回答节点(Generate)
生成回答节点基于检索到的文档生成最终回答。
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
# 获取提示模板
prompt = hub.pull("rlm/rag-prompt")
# 初始化LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
# 后处理函数:格式化文档
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# 构建RAG链
rag_chain = prompt | llm | StrOutputParser()
# 运行RAG链
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)
输出示例:
The design of generative agents combines LLM with memory, planning, and reflection mechanisms to enable agents to behave based on past experience and interact with other agents. Memory stream is a long-term memory module that records agents' experiences in natural language. The retrieval model surfaces context to inform the agent's behavior based on relevance, recency, and importance.
解释:
- hub.pull(“rlm/rag-prompt”):从LangChain Hub拉取预定义的RAG提示模板。
- rag_chain:结合提示模板和LLM,创建一个RAG链。
- generation:基于上下文和用户问题生成的回答。
4. 幻觉评分器(Hallucination Grader)
幻觉评分器用于评估生成的回答是否基于检索到的事实,防止模型生成虚假信息。
# 数据模型
class GradeHallucinations(BaseModel):
"""评估生成回答是否基于事实的二进制评分。"""
binary_score: str = Field(
description="回答是否基于事实,'yes'或'no'"
)
# 初始化LLM并绑定结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
# 定义系统消息模板
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
]
)
# 构建幻觉评分链
hallucination_grader = hallucination_prompt | structured_llm_grader
# 示例调用
hallucination_grader.invoke({"documents": docs, "generation": generation})
# 输出示例
GradeHallucinations(binary_score='yes')
解释:
- GradeHallucinations:定义了幻觉评分器的输出结构,包括
binary_score
字段,值为"yes"
或"no"
。 - hallucination_prompt:定义了用于评估幻觉的提示模板。
- hallucination_grader:结合了提示模板和LLM的幻觉评分链。
- 示例调用:评估生成的回答是否基于检索到的事实。
5. 答案评分器(Answer Grader)
答案评分器用于评估生成的回答是否有效地回答了用户的问题。
# 数据模型
class GradeAnswer(BaseModel):
"""评估回答是否解决了问题的二进制评分。"""
binary_score: str = Field(
description="回答是否解决了问题,'yes'或'no'"
)
# 初始化LLM并绑定结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)
# 定义系统消息模板
system = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
]
)
# 构建答案评分链
answer_grader = answer_prompt | structured_llm_grader
# 示例调用
answer_grader.invoke({"question": question, "generation": generation})
# 输出示例
GradeAnswer(binary_score='yes')
解释:
- GradeAnswer:定义了答案评分器的输出结构,包括
binary_score
字段,值为"yes"
或"no"
。 - answer_prompt:定义了用于评估答案的提示模板。
- answer_grader:结合了提示模板和LLM的答案评分链。
- 示例调用:评估生成的回答是否有效地解决了用户的问题。
6. 问题重写器(Question Re-writer)
问题重写器用于优化用户的问题,以提高检索效果。
# 初始化LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
# 定义系统消息模板
system = """You a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
),
]
)
# 构建问题重写链
question_rewriter = re_write_prompt | llm | StrOutputParser()
# 示例调用
question_rewriter.invoke({"question": question})
# 输出示例
"What is the role of memory in an agent's functioning?"
解释:
- re_write_prompt:定义了用于重写问题的提示模板。
- question_rewriter:结合了提示模板和LLM的问题重写链。
- 示例调用:优化用户的原始问题,以提高检索效果。
7. 网络搜索工具(Web Search Tool)
网络搜索工具用于处理与近期事件相关的问题,通过网络搜索获取最新信息。
from langchain_community.tools.tavily_search import TavilySearchResults
# 初始化网络搜索工具
web_search_tool = TavilySearchResults(k=3)
解释:
- TavilySearchResults:定义了网络搜索工具,设置返回结果的数量为3。
- web_search_tool:网络搜索工具实例,供后续调用以获取相关信息。
构建图(Construct the Graph)
1. 定义图状态(Define Graph State)
首先,定义图的状态结构,包含问题、生成的回答和相关文档列表。
from typing import List
from typing_extensions import TypedDict
class GraphState(TypedDict):
"""
表示图的状态。
属性:
question: 用户问题
generation: LLM生成的回答
documents: 文档列表
"""
question: str
generation: str
documents: List[str]
解释:
- GraphState:定义了图的状态结构,包括用户问题(
question
)、生成的回答(generation
)和相关文档列表(documents
)。
2. 定义图流程(Define Graph Flow)
构建图的逻辑流程,包括检索、生成、评分和重写等节点。
from langchain.schema import Document
def retrieve(state):
"""
检索文档
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含检索到的文档
"""
print("---RETRIEVE---")
question = state["question"]
# 调用检索器
documents = retriever.invoke(question)
return {"documents": documents, "question": question}
def generate(state):
"""
生成回答
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含生成的回答
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
# RAG生成
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
"""
确定检索到的文档是否与问题相关
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含筛选后的相关文档
"""
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]
# 评分每个文档
filtered_docs = []
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
continue
return {"documents": filtered_docs, "question": question}
def transform_query(state):
"""
转换查询,生成更好的问题
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含重新表述的问题
"""
print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]
# 重写问题
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}
def web_search(state):
"""
基于重新表述的问题进行网络搜索
Args:
state (dict): 当前图的状态
Returns:
dict: 更新状态,包含网络搜索结果
"""
print("---WEB SEARCH---")
question = state["question"]
# 网络搜索
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
return {"documents": web_results, "question": question}
解释:
- retrieve:根据用户问题调用检索器,获取相关文档。
- generate:基于检索到的文档生成回答。
- grade_documents:评估每个检索到的文档是否与用户问题相关,并筛选出相关文档。
- transform_query:优化用户的问题,以提高检索效果。
- web_search:针对与近期事件相关的问题,进行网络搜索以获取最新信息。
3. 定义边(Edges)
定义节点之间的连接关系,决定流程的执行顺序。
def route_question(state):
"""
路由问题到网络搜索或RAG
Args:
state (dict): 当前图的状态
Returns:
str: 下一步调用的节点
"""
print("---ROUTE QUESTION---")
question = state["question"]
source = question_router.invoke({"question": question})
if source.datasource == "web_search":
print("---ROUTE QUESTION TO WEB SEARCH---")
return "web_search"
elif source.datasource == "vectorstore":
print("---ROUTE QUESTION TO RAG---")
return "retrieve"
def decide_to_generate(state):
"""
决定是否生成回答,或重新生成问题
Args:
state (dict): 当前图的状态
Returns:
str: 决策结果,决定下一步调用的节点
"""
print("---ASSESS GRADED DOCUMENTS---")
filtered_documents = state["documents"]
if not filtered_documents:
# 所有文档均不相关,重新转换查询
print(
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
)
return "transform_query"
else:
# 有相关文档,生成回答
print("---DECISION: GENERATE---")
return "generate"
def grade_generation_v_documents_and_question(state):
"""
确定生成的回答是否基于文档且回答了问题
Args:
state (dict): 当前图的状态
Returns:
str: 决策结果,决定下一步调用的节点
"""
print("---CHECK HALLUCINATIONS---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
# 检查幻觉
if grade == "yes":
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
# 检查回答是否解决了问题
print("---GRADE GENERATION vs QUESTION---")
score = answer_grader.invoke({"question": question, "generation": generation})
grade = score.binary_score
if grade == "yes":
print("---DECISION: GENERATION ADDRESSES QUESTION---")
return "useful"
else:
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
return "transform_query"
else:
print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
return "generate"
解释:
- route_question:根据路由器的决策,将问题路由到
web_search
或retrieve
节点。 - decide_to_generate:根据评估结果,决定是否生成回答或重新转换查询。
- grade_generation_v_documents_and_question:评估生成的回答是否基于检索到的文档且有效回答了问题,决定是否结束流程或重新生成回答。
4. 编译图(Compile Graph)
使用StateGraph
将所有节点和边连接起来,并编译图。
from langgraph.graph import END, StateGraph, START
workflow = StateGraph(GraphState)
# 定义节点
workflow.add_node("web_search", web_search) # 网络搜索节点
workflow.add_node("retrieve", retrieve) # 检索节点
workflow.add_node("grade_documents", grade_documents) # 评估文档相关性节点
workflow.add_node("generate", generate) # 生成回答节点
workflow.add_node("transform_query", transform_query) # 转换查询节点
# 定义边
workflow.add_conditional_edges(
START,
route_question,
{
"web_search": "web_search",
"vectorstore": "retrieve",
},
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
"generate",
grade_generation_v_documents_and_question,
{
"not supported": "generate",
"useful": END,
"not useful": "transform_query",
},
)
# 编译图
app = workflow.compile()
解释:
- workflow.add_node:将各个节点添加到图中。
- “web_search”:网络搜索节点,处理与近期事件相关的问题。
- “retrieve”:检索节点,从向量存储中获取相关文档。
- “grade_documents”:评估文档相关性节点,筛选相关文档。
- “generate”:生成回答节点,基于相关文档生成最终回答。
- “transform_query”:转换查询节点,优化用户问题以提高检索效果。
- workflow.add_conditional_edges:根据条件决定节点之间的路由。
- START节点:根据
route_question
函数的输出,决定进入web_search
或retrieve
节点。 - "grade_documents"节点:根据
decide_to_generate
函数的输出,决定进入transform_query
或generate
节点。 - "generate"节点:根据
grade_generation_v_documents_and_question
函数的输出,决定是否结束流程(END
)或重新生成回答(generate
),或重新转换查询(transform_query
)。
- START节点:根据
- workflow.add_edge:定义节点之间的直接连接。
使用图(Use the Graph)
1. 导入必要模块
from pprint import pprint
2. 运行
定义输入并通过图进行处理。
# 示例调用1
inputs = {
"question": "What player at the Bears expected to draft first in the 2024 NFL draft?"
}
for output in app.stream(inputs):
for key, value in output.items():
# 节点
pprint(f"Node '{key}':")
# 可选:打印每个节点的完整状态
# pprint.pprint(value["keys"], indent=2, width=80, depth=None)
pprint("\n---\n")
# 最终生成的回答
pprint(value["generation"])
# 示例调用2
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
for key, value in output.items():
# 节点
pprint(f"Node '{key}':")
# 可选:打印每个节点的完整状态
# pprint.pprint(value["keys"], indent=2, width=80, depth=None)
pprint("\n---\n")
# 最终生成的回答
pprint(value["generation"])
输出示例1:
---ROUTE QUESTION---
---ROUTE QUESTION TO WEB SEARCH---
---WEB SEARCH---
Node 'web_search':
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
Node 'generate':
'\n---\n'
('It is expected that the Chicago Bears could have the opportunity to draft the first defensive player in the 2024 NFL draft. The Bears have the first overall pick in the draft, giving them a prime position to select top talent. The top wide receiver Marvin Harrison Jr. from Ohio State is also mentioned as a potential pick for the Cardinals.')
Trace:
https://smith.langchain.com/public/7e3aa7e5-c51f-45c2-bc66-b34f17ff2263/r
输出示例2:
---ROUTE QUESTION---
---ROUTE QUESTION TO RAG---
---RETRIEVE---
Node 'retrieve':
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
Node 'grade_documents':
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
Node 'generate':
'\n---\n'
('The types of agent memory include Sensory Memory, Short-Term Memory (STM) or Working Memory, and Long-Term Memory (LTM) with subtypes of Explicit / declarative memory and Implicit / procedural memory. Sensory memory retains sensory information briefly, STM stores information for cognitive tasks, and LTM stores information for a long time with different types of memories.')
Trace:
https://smith.langchain.com/public/fdf0a180-6d15-4d09-bb92-f84f2105ca51/r
解释:
-
示例调用1:
- 用户问题被路由到
web_search
节点。 - 通过网络搜索获取相关信息。
- 生成回答后,评估回答是否基于检索到的事实且有效回答了问题。
- 最终生成的回答展示。
- 用户问题被路由到
-
示例调用2:
- 用户问题被路由到
retrieve
节点,从向量存储中检索相关文档。 - 评估每个文档的相关性,并筛选出相关文档。
- 基于相关文档生成回答。
- 评估回答的准确性和相关性。
- 最终生成的回答展示。
- 用户问题被路由到
评估(Eval)
在本节中,我们将评估使用LangGraph实现的自适应RAG系统与基线方法(Context Stuffing)的性能对比。
1. 导入必要模块
import langsmith
from langsmith.schemas import Example, Run
from langsmith.evaluation import evaluate
2. 克隆公共数据集
克隆一个公共的LCEL问题数据集,用于评估:
client = langsmith.Client()
# 克隆数据集到您的租户
try:
public_dataset = (
"https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d"
)
client.clone_public_dataset(public_dataset)
except:
print("Please setup LangSmith")
解释:
- clone_public_dataset:将公共数据集克隆到您的LangSmith租户中,以便进行评估。
- public_dataset:指定要克隆的数据集URL。
3. 定义自定义评估器
创建两个评估器,用于检查生成的回答是否正确导入和执行。
def check_import(run: Run, example: Example) -> dict:
"""检查导入语句是否正确"""
imports = run.outputs.get("imports")
try:
exec(imports)
return {"key": "import_check", "score": 1}
except Exception:
return {"key": "import_check", "score": 0}
def check_execution(run: Run, example: Example) -> dict:
"""检查代码块是否能正确执行"""
imports = run.outputs.get("imports")
code = run.outputs.get("code")
try:
exec(imports + "\n" + code)
return {"key": "code_execution_check", "score": 1}
except Exception:
return {"key": "code_execution_check", "score": 0}
解释:
- check_import:尝试执行导入语句,如果成功,返回分数1;否则,返回分数0。
- check_execution:尝试执行导入语句和代码块,如果成功,返回分数1;否则,返回分数0。
4. 定义预测函数
定义两个预测函数,分别用于基线方法和自适应RAG方法。
def predict_base_case(example: dict):
"""基线方法:Context Stuffing"""
solution = code_gen_chain.invoke(
{"context": concatenated_content, "messages": [("user", example["question"])]}
)
return {"imports": solution.imports, "code": solution.code}
def predict_langgraph(example: dict):
"""自适应RAG方法"""
graph = app.invoke(
{"messages": [("user", example["question"])], "iterations": 0, "error": ""}
)
solution = graph["generation"]
return {"imports": solution.imports, "code": solution.code}
解释:
- predict_base_case:使用基线方法(Context Stuffing)生成回答。
- predict_langgraph:使用自适应RAG方法生成回答。
5. 运行评估
使用LangSmith的evaluate
函数,分别评估基线方法和自适应RAG方法的性能。
# 评估器列表
code_evaluator = [check_import, check_execution]
# 数据集名称
dataset_name = "lcel-teacher-eval"
# 运行基线方法的评估
try:
experiment_results_ = evaluate(
predict_base_case,
data=dataset_name,
evaluators=code_evaluator,
experiment_prefix=f"test-without-langgraph-{llm.model}",
max_concurrency=2,
metadata={
"llm": llm.model,
},
)
except:
print("Please setup LangSmith")
# 运行自适应RAG方法的评估
try:
experiment_results = evaluate(
predict_langgraph,
data=dataset_name,
evaluators=code_evaluator,
experiment_prefix=f"test-with-langgraph-{llm.model}-{flag}",
max_concurrency=2,
metadata={
"llm": llm.model,
"feedback": flag,
},
)
except:
print("Please setup LangSmith")
解释:
- evaluate:运行评估,传入预测函数、数据集、评估器列表和其他配置参数。
- predict_base_case:基线方法的预测函数。
- predict_langgraph:自适应RAG方法的预测函数。
- code_evaluator:评估器列表,用于检查回答的导入和执行情况。
- experiment_prefix:定义实验的前缀,便于区分不同的实验结果。
- metadata:附加的元数据,用于记录LLM类型和反馈标志。
6. 结果
根据评估结果,自适应RAG方法的表现优于基线方法,特别是在添加重试机制后性能有所提升。然而,反思机制并未带来预期的改进,反而在某些情况下导致性能下降。此外,使用GPT-4模型的性能优于Claude3模型。
结果摘要:
- 自适应RAG优于基线方法(LangGraph outperforms base case):添加重试机制显著提高了性能。
- 反思机制未带来改进(Reflection did not help):在重试前进行反思反而导致性能下降,相比之下,直接将错误反馈给LLM更为有效。
- GPT-4优于Claude3(GPT-4 outperforms Claude3):GPT-4模型在执行工具调用时的错误率较低,表现优于Claude3模型。
您可以通过访问以下链接查看详细的评估结果:
总结
通过本节的讲解,您已经学习了如何使用LangGraph实现一个自适应RAG系统。这个系统能够根据用户查询动态选择检索策略,结合RAG和自我纠正机制,生成准确且相关的回答。具体来说,您已经掌握了以下内容:
- 系统设置:安装必要的包,配置API密钥,并设置LangSmith进行开发和监控。
- 索引创建:使用
WebBaseLoader
和Chroma
创建检索工具,索引并检索相关文档。 - LLM配置:
- 使用OpenAI的GPT-3.5和GPT-4进行路由、评分和生成。
- 定义Pydantic模型来结构化存储生成的回答和评分结果。
- 构建路由器、评分器和生成链。
- 状态管理:定义图的状态结构,包括用户问题、生成的回答和相关文档列表。
- 图定义:
- 定义检索、生成、评分和重写的节点。
- 定义条件边路由,决定流程的执行顺序。
- 评估:
- 使用LangSmith的评估功能,比较自适应RAG方法与基线方法的性能。
- 通过自定义评估器检查回答的准确性和相关性。
下一步建议:
- 扩展功能:可以进一步扩展系统,如增加更多的单元测试,集成更多的工具或优化重试和反思机制。
- 优化路由逻辑:根据评估结果,优化路由器的决策逻辑,提高系统的鲁棒性和生成回答的质量。
- 多模型集成:结合不同的LLM模型,探索多模型协作的可能性,进一步提升回答的准确性和效率。
- 部署和监控:将系统部署到生产环境中,并使用LangSmith进行持续的监控和优化,确保系统稳定运行。
汇总
# adaptive_rag.py
import getpass
import os
from typing import List
from typing_extensions import TypedDict
from pprint import pprint
# LangChain 和 LangGraph 相关导入
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from langgraph.graph import END, StateGraph, START
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
from typing import Literal
# 环境设置函数
def _set_env(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"{var}: ")
_set_env("OPENAI_API_KEY")
_set_env("COHERE_API_KEY")
_set_env("TAVILY_API_KEY")
# 定义图的状态
class GraphState(TypedDict):
"""
表示图的状态。
属性:
question: 用户问题
generation: LLM生成的回答
web_search: 是否进行网络搜索
documents: 文档列表
"""
question: str
generation: str
web_search: str
documents: List[str]
# 初始化检索器
def initialize_retriever():
urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
# 加载文档
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# 分割文档
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=500, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
# 添加到向量数据库
vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name="rag-chroma",
embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
return retriever
# 定义评分器的数据模型和函数
# 1. 路由器(Router)
class RouteQuery(BaseModel):
"""将用户查询路由到最相关的数据源。"""
datasource: Literal["vectorstore", "web_search"] = Field(
...,
description="根据用户问题选择将其路由到网络搜索或向量存储。",
)
def initialize_question_router():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_router = llm.with_structured_output(RouteQuery)
system = """你是一个专家,负责将用户的问题路由到向量存储或网络搜索。
向量存储包含与代理、提示工程和对抗性攻击相关的文档。
对于这些主题的问题,请使用向量存储。否则,请使用网络搜索。"""
route_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{question}"),
]
)
question_router = route_prompt | structured_llm_router
return question_router
# 2. 检索评分器(Retrieval Grader)
class GradeDocuments(BaseModel):
"""评估检索文档相关性的二元评分。"""
binary_score: str = Field(
description="文档是否与问题相关,'yes' 或 'no'"
)
def initialize_retrieval_grader():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
system = """你是一个评分员,负责评估检索到的文档是否与用户的问题相关。
这不需要是严格的测试,目标是过滤掉错误的检索结果。
如果文档包含与用户问题相关的关键词或语义含义,请将其评分为相关。
请给出二元评分“yes”或“no”,以指示文档是否与问题相关。"""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "检索到的文档:\n\n {document} \n\n 用户问题:{question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
return retrieval_grader
# 3. 幻觉评分器(Hallucination Grader)
class GradeHallucinations(BaseModel):
"""评估回答中是否存在幻觉的二元评分。"""
binary_score: str = Field(
description="回答是否基于事实支持,'yes' 或 'no'"
)
def initialize_hallucination_grader():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
system = """你是一个评分员,负责评估LLM生成的回答是否基于一组检索到的事实。
请给出二元评分“yes”或“no”。“yes”表示回答是基于这些事实支持的。"""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "事实集:\n\n {documents} \n\n LLM生成的回答:{generation}"),
]
)
hallucination_grader = hallucination_prompt | structured_llm_grader
return hallucination_grader
# 4. 回答评分器(Answer Grader)
class GradeAnswer(BaseModel):
"""评估回答是否解决问题的二元评分。"""
binary_score: str = Field(
description="回答是否解决了问题,'yes' 或 'no'"
)
def initialize_answer_grader():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)
system = """你是一个评分员,负责评估一个回答是否解决了用户的问题。
请给出二元评分“yes”或“no”。“yes”表示回答解决了问题。"""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "用户问题:\n\n {question} \n\n LLM生成的回答:{generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
return answer_grader
# 5. 问题重写器(Question Re-writer)
def initialize_question_rewriter():
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
system = """你是一个问题重写器,负责将输入的问题转换为更好的版本,以优化向量存储检索或网络搜索的效果。
请查看输入的问题,并尝试推理其潜在的语义意图或含义。"""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
(
"human",
"这是初始问题:\n\n {question} \n 请制定一个改进后的问题。",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
return question_rewriter
# 6. 生成回答链(Generate)
def initialize_rag_chain():
prompt = hub.pull("rlm/rag-prompt")
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = prompt | llm | StrOutputParser()
return rag_chain
# 定义节点函数
def retrieve(state, retriever):
"""
检索文档
参数:
state (dict): 当前图的状态
retriever: 检索器对象
返回:
dict: 更新后的状态,包含检索到的文档
"""
print("---检索---")
question = state["question"]
# 检索
documents = retriever.invoke(question)
return {"documents": documents, "question": question}
def generate(state, rag_chain):
"""
生成回答
参数:
state (dict): 当前图的状态
rag_chain: RAG生成链对象
返回:
dict: 更新后的状态,包含生成的回答
"""
print("---生成回答---")
question = state["question"]
documents = state["documents"]
# RAG生成
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state, retrieval_grader):
"""
评估检索到的文档是否相关
参数:
state (dict): 当前图的状态
retrieval_grader: 检索评分器对象
返回:
dict: 更新后的状态,包含过滤后的相关文档和是否进行网络搜索的标志
"""
print("---检查文档与问题的相关性---")
question = state["question"]
documents = state["documents"]
# 评分每个文档
filtered_docs = []
web_search = "No"
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---评分:文档相关---")
filtered_docs.append(d)
else:
print("---评分:文档不相关---")
web_search = "Yes"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}
def transform_query(state, question_rewriter):
"""
优化问题
参数:
state (dict): 当前图的状态
question_rewriter: 问题重写器对象
返回:
dict: 更新后的状态,包含优化后的问题
"""
print("---优化问题---")
question = state["question"]
documents = state["documents"]
# 重写问题
better_question = question_rewriter.invoke({"question": question})
print(f"优化后的问题:{better_question}")
return {"documents": documents, "question": better_question}
def web_search(state, web_search_tool):
"""
基于优化后的问题进行网络搜索
参数:
state (dict): 当前图的状态
web_search_tool: 网络搜索工具对象
返回:
dict: 更新后的状态,包含追加的网络搜索结果
"""
print("---网络搜索---")
question = state["question"]
# 网络搜索
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_result_doc = Document(page_content=web_results)
return {"documents": [web_result_doc], "question": question}
def route_question(state, question_router):
"""
路由问题到网络搜索或向量存储
参数:
state (dict): 当前图的状态
question_router: 路由器对象
返回:
str: 下一个节点的名称
"""
print("---路由问题---")
question = state["question"]
source = question_router.invoke({"question": question})
if source.datasource == "web_search":
print("---路由问题到网络搜索---")
return "web_search"
elif source.datasource == "vectorstore":
print("---路由问题到向量存储---")
return "retrieve"
def decide_to_generate(state):
"""
决定是否生成回答或重新优化问题
参数:
state (dict): 当前图的状态
返回:
str: 下一个节点的名称
"""
print("---评估已评分的文档---")
web_search = state["web_search"]
if web_search == "Yes":
# 所有文档均不相关,重新优化问题
print("---决策:所有文档与问题不相关,优化问题---")
return "transform_query"
else:
# 有相关文档,生成回答
print("---决策:生成回答---")
return "generate"
def grade_generation_v_documents_and_question(state, hallucination_grader, answer_grader):
"""
评估生成的回答是否基于文档并解决了问题
参数:
state (dict): 当前图的状态
hallucination_grader: 幻觉评分器对象
answer_grader: 回答评分器对象
返回:
str: 下一个节点的名称
"""
print("---检查幻觉---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
# 检查是否存在幻觉
if grade == "yes":
print("---决策:生成的回答基于文档---")
# 检查回答是否解决了问题
print("---评分生成的回答是否解决问题---")
score = answer_grader.invoke({"question": question, "generation": generation})
grade = score.binary_score
if grade == "yes":
print("---决策:生成的回答解决了问题---")
return "useful"
else:
print("---决策:生成的回答未解决问题---")
return "not useful"
else:
print("---决策:生成的回答未基于文档,重试---")
return "not supported"
# 构建并编译图
def build_workflow(
retrieve_fn,
grade_documents_fn,
generate_fn,
transform_query_fn,
decide_to_generate_fn,
grade_generation_fn,
route_question_fn,
web_search_fn,
retriever,
rag_chain,
retrieval_grader,
hallucination_grader,
answer_grader,
question_rewriter,
question_router,
web_search_tool
):
workflow = StateGraph(GraphState)
# 定义节点
workflow.add_node("web_search", lambda state: web_search_fn(state, web_search_tool))
workflow.add_node("retrieve", lambda state: retrieve_fn(state, retriever))
workflow.add_node("grade_documents", lambda state: grade_documents_fn(state, retrieval_grader))
workflow.add_node("generate", lambda state: generate_fn(state, rag_chain))
workflow.add_node("transform_query", lambda state: transform_query_fn(state, question_rewriter))
# 构建边
workflow.add_conditional_edges(
START,
lambda state: route_question_fn(state, question_router),
{
"web_search": "web_search",
"vectorstore": "retrieve",
},
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate_fn,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
"generate",
grade_generation_fn,
{
"not supported": "generate",
"useful": END,
"not useful": "transform_query",
},
)
# 编译图
app = workflow.compile()
return app
# 运行图
def run_workflow(app, inputs):
for output in app.stream(inputs):
for key, value in output.items():
# 打印每个节点的状态
pprint(f"节点 '{key}':")
# 可选:打印每个节点的详细状态
# pprint.pprint(value["keys"], indent=2, width=80, depth=None)
pprint("\n---\n")
# 打印最终生成的回答
pprint(value.get("generation", "没有生成回答"))
def main():
# 初始化组件
retriever = initialize_retriever()
question_router = initialize_question_router()
retrieval_grader = initialize_retrieval_grader()
hallucination_grader = initialize_hallucination_grader()
answer_grader = initialize_answer_grader()
question_rewriter = initialize_question_rewriter()
rag_chain = initialize_rag_chain()
web_search_tool = TavilySearchResults(k=3)
# 构建工作流
app = build_workflow(
retrieve_fn=retrieve,
grade_documents_fn=grade_documents,
generate_fn=generate,
transform_query_fn=transform_query,
decide_to_generate_fn=decide_to_generate,
grade_generation_fn=grade_generation_v_documents_and_question,
route_question_fn=route_question,
web_search_fn=web_search,
retriever=retriever,
rag_chain=rag_chain,
retrieval_grader=retrieval_grader,
hallucination_grader=hallucination_grader,
answer_grader=answer_grader,
question_rewriter=question_rewriter,
question_router=question_router,
web_search_tool=web_search_tool
)
# 运行第一个示例
print("=== 示例 1 ===")
inputs1 = {
"question": "What player at the Bears expected to draft first in the 2024 NFL draft?"
}
run_workflow(app, inputs1)
# 运行第二个示例
print("\n=== 示例 2 ===")
inputs2 = {"question": "What are the types of agent memory?"}
run_workflow(app, inputs2)
if __name__ == "__main__":
main()
代码说明
- 环境设置:
- 使用
getpass
获取OPENAI_API_KEY
、COHERE_API_KEY
和TAVILY_API_KEY
并设置为环境变量。 - 请确保在运行脚本时输入有效的 API 密钥。
- 使用
- 检索器初始化:
- 从指定的 URL 加载文档。
- 使用
RecursiveCharacterTextSplitter
将文档分割成较小的片段。 - 将文档片段存储到 Chroma 向量数据库中,以便高效检索。
- 评分器初始化:
- 路由器(Router):决定将用户的问题路由到向量存储还是网络搜索。
- 检索评分器(Retrieval Grader):评估检索到的文档是否与问题相关。
- 幻觉评分器(Hallucination Grader):评估生成的回答是否基于检索到的文档,避免幻觉。
- 回答评分器(Answer Grader):评估生成的回答是否解决了用户的问题。
- 问题重写器初始化:
- 优化用户输入的问题,以便更好地进行网络搜索或向量存储检索。
- 生成链初始化(Generate):
- 使用 LangChain 的
hub.pull("rlm/rag-prompt")
获取预定义的 RAG 提示语。 - 配置 LLM 生成回答。
- 使用 LangChain 的
- 节点函数定义:
- retrieve:检索相关文档。
- generate:基于检索到的文档生成回答。
- grade_documents:评估检索到的文档是否相关。
- transform_query:优化用户的问题。
- web_search:基于优化后的问题进行网络搜索。
- route_question:路由问题到网络搜索或向量存储。
- decide_to_generate:决定是生成回答还是优化问题。
- grade_generation_v_documents_and_question:评估生成的回答是否基于文档并解决了问题。
- 图的构建与编译:
- 使用 LangGraph 的
StateGraph
定义工作流图。 - 添加节点和边,定义流程控制逻辑。
- 编译图为可执行的应用对象
app
。
- 使用 LangGraph 的
- 运行图:
- 提供输入问题,运行整个工作流,生成并打印最终的回答。
- 示例中包含两个问题进行演示。
执行示例
运行脚本后,您将看到如下类似的输出:
plaintext复制代码=== 示例 1 ===
---路由问题---
---路由问题到网络搜索---
---网络搜索---
节点 'web_search':
---
---生成回答---
---检查幻觉---
---决策:生成的回答基于文档---
---评分生成的回答是否解决问题---
---决策:生成的回答解决了问题---
节点 'generate':
---
('It is expected that the Chicago Bears could have the opportunity to draft the first defensive player in the 2024 NFL draft. The Bears have the first overall pick in the draft, giving them a prime position to select top talent. The top wide receiver Marvin Harrison Jr. from Ohio State is also mentioned as a potential pick for the Cardinals.')
=== 示例 2 ===
---路由问题---
---路由问题到向量存储---
---检索---
---检查文档与问题的相关性---
---评分:文档相关---
---评分:文档相关---
---评分:文档不相关---
---评分:文档相关---
---评估已评分的文档---
---决策:生成回答---
节点 'grade_documents':
---
---生成回答---
---检查幻觉---
---决策:生成的回答基于文档---
---评分生成的回答是否解决问题---
---决策:生成的回答解决了问题---
节点 'generate':
---
('The types of agent memory include Sensory Memory, Short-Term Memory (STM) or Working Memory, and Long-Term Memory (LTM) with subtypes of Explicit / declarative memory and Implicit / procedural memory. Sensory memory retains sensory information briefly, STM stores information for cognitive tasks, and LTM stores information for a long time with different types of memories.')
注意事项
- 包版本:请确保安装的包版本与代码兼容,尤其是
langchain
和langgraph
。如果遇到兼容性问题,请参考相应包的官方文档进行调整。 - Prompt 模板:代码中使用了
hub.pull("rlm/rag-prompt")
来获取 RAG 提示语。请确保该提示语存在于 LangChain 的 Hub 中,或者根据需要自定义提示语。 - 错误处理:为了简化代码示例,未添加详细的错误处理逻辑。在实际应用中,建议添加适当的异常处理,以提高代码的鲁棒性。
- LangSmith:代码中提到了 LangSmith,用于调试和监控 LangGraph 项目。如果需要使用,请参考 LangSmith 官方文档 进行配置。
总结
Adaptive RAG (ARAG) 通过将查询分析与主动/自我纠正的 RAG 方法结合起来,动态地选择最合适的数据源(如向量存储或网络搜索),进一步提升了文档检索和回答生成的质量和相关性。通过使用 LangGraph 构建流程图,可以实现自动化的工作流控制,确保每一步骤的逻辑和评估都得到有效执行。整体流程包括:
- 路由问题:根据查询内容决定使用向量存储还是网络搜索。
- 检索相关文档(向量存储)或 进行网络搜索。
- 评估文档的相关性。
- 生成基于文档的回答。
- 评估回答的准确性和相关性。
- 如果所有文档不相关或评分器不确定,优化问题并重新检索。
这种方法不仅提高了回答的准确性,还通过网络搜索补充了信息来源,减少了因信息不足或错误信息引发的问题,是实现高质量问答系统的有效策略。