In a Q&A application, it is often important to show users the sources that were used to generate the answer. The simplest way to do this is to have the chain return the documents retrieved for each generation. Building on the Q&A app we built over Lilian Weng's "LLM Powered Autonomous Agents" blog post in the RAG tutorial, this article covers two approaches:
- Using the built-in create_retrieval_chain, which returns sources by default.
- Using a simple LCEL implementation, to show how it works under the hood.
We will also show how to structure source information into the model response, so that the model can report which specific sources it used when generating its answer.
Setup
We will use OpenAI embeddings and a Chroma vector store, but everything shown here works with any Embeddings, VectorStore, or Retriever. The following packages need to be installed:
%pip install --upgrade --quiet langchain langchain-community langchainhub langchain-openai langchain-chroma bs4
The OPENAI_API_KEY environment variable needs to be set, either directly or loaded from a .env file:
import getpass
import os
os.environ["OPENAI_API_KEY"] = getpass.getpass()
# import dotenv
# dotenv.load_dotenv()
Using create_retrieval_chain
First we select an LLM:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")
Here is the Q&A app with sources that we built over Lilian Weng's blog post in the RAG tutorial:
import bs4
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# 1. Load, chunk and index the contents of the blog to create a retriever.
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs=dict(
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )
    ),
)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(docs)
vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()
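# Optional sanity check (illustrative addition, not part of the original tutorial):
# query the retriever directly to preview which chunks would be surfaced as sources.
sample_docs = retriever.invoke("What is Task Decomposition?")
print(len(sample_docs), sample_docs[0].metadata)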
# 2. Incorporate the retriever into a question-answering chain.
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
    "\n\n"
    "{context}"
)
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, prompt)
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
result = rag_chain.invoke({"input": "What is Task Decomposition?"})
在结果中,"context"
包含了LLM在生成"answer"
时使用的来源。
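To surface these sources to a user, you can read them off the returned dict. Below is a minimal, illustrative sketch (assuming the result object produced by rag_chain.invoke above) that prints each retrieved document's source metadata and a snippet of its content:
# Illustrative only: inspect the retrieved source documents in `result`.
for i, doc in enumerate(result["context"]):
    print(f"Source {i}: {doc.metadata.get('source')}")
    print(doc.page_content[:200])
    print("---")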
Custom LCEL implementation
Below we construct a chain similar to create_retrieval_chain. It works by building up a dict step by step:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain_from_docs = (
    {
        "input": lambda x: x["input"],  # input query
        "context": lambda x: format_docs(x["context"]),  # context
    }
    | prompt  # format query and context into prompt
    | llm  # generate response
    | StrOutputParser()  # coerce to string
)
retrieve_docs = (lambda x: x["input"]) | retriever
chain = RunnablePassthrough.assign(context=retrieve_docs).assign(
    answer=rag_chain_from_docs
)
chain.invoke({"input": "What is Task Decomposition"})
Structuring sources into the model response
So far we have simply propagated the documents returned by the retrieval step through to the final response. However, this may not tell us exactly what information the model relied on when generating its answer. Below we show how to structure source information into the model response, so that the model reports exactly which context it relied on.
from typing import List
from langchain_core.runnables import RunnablePassthrough
from typing_extensions import Annotated, TypedDict
# Desired schema for response
class AnswerWithSources(TypedDict):
    answer: str
    sources: Annotated[
        List[str],
        ...,
        "List of sources (author + year) used to answer the question",
    ]
# Our rag_chain_from_docs has the following changes:
rag_chain_from_docs = (
    {
        "input": lambda x: x["input"],
        "context": lambda x: format_docs(x["context"]),
    }
    | prompt
    | llm.with_structured_output(AnswerWithSources)
)
retrieve_docs = (lambda x: x["input"]) | retriever
chain = RunnablePassthrough.assign(context=retrieve_docs).assign(
    answer=rag_chain_from_docs
)
response = chain.invoke({"input": "What is Chain of Thought?"})
import json
print(json.dumps(response["answer"], indent=2))
With this approach, the model can clearly report the specific sources it relied on when answering the question.
If you run into any issues, feel free to discuss them in the comments.
—END—