1. Introduction to RAG
RAG (Retrieval-Augmented Generation) is an AI technique that combines information retrieval with generative language models. It retrieves relevant information from an external knowledge base and feeds it to a large language model (LLM) as part of the prompt, strengthening the model's ability to handle knowledge-intensive tasks.
In plain terms: an embedding model first encodes the documents in your dataset. When a user asks a question, the same model encodes the question, and the system searches the encoded collection for the content most similar to it. The retrieved passages are then inserted into the prompt you wrote, and the LLM composes its answer with reference to that retrieved material.
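To make the flow concrete, here is a minimal sketch of that retrieve-then-generate loop. It is only an illustration of the idea: the embedding model name, the toy corpus, and the ask_llm helper are assumed placeholders, not part of the full example in the next section.

# Toy retrieve-then-generate loop; model name, corpus, and ask_llm() are placeholders.
from sentence_transformers import SentenceTransformer
import numpy as np

encoder = SentenceTransformer("BAAI/bge-large-zh-v1.5")   # assumed embedding model; any sentence encoder works
corpus = ["document one ...", "document two ...", "document three ..."]   # placeholder knowledge base
corpus_vecs = encoder.encode(corpus, normalize_embeddings=True)           # encode the dataset once

def retrieve(question, k=2):
    # Encode the question and return the k most similar chunks by cosine similarity
    q_vec = encoder.encode([question], normalize_embeddings=True)[0]
    scores = corpus_vecs @ q_vec          # dot product of normalized vectors = cosine similarity
    top = np.argsort(scores)[::-1][:k]
    return [corpus[i] for i in top]

question = "your question here"
context = "\n".join(retrieve(question))
prompt = f"Answer the question using only the material below.\nMaterial:\n{context}\nQuestion: {question}"
# answer = ask_llm(prompt)   # hypothetical call to whichever chat model you deploy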
2. Code Implementation
Below is a simple RAG example. It loads local CSV data into a FAISS vector store, combines BM25 and vector similarity for retrieval, and wraps everything in a self-RAG loop that decides when to retrieve and grades its own answers:
import os
import re

import numpy as np
from rank_bm25 import BM25Okapi
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import UnstructuredWordDocumentLoader
from file_loader import CustomDocumentLoader

device = "cuda:0"  # the device to load the embedding model onto

# Load the embedding model
model_name = "/home/media4090/wz/model/vllm/Xorbits/bge-large-zh-v1.5"
hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': device})
def get_files_in_folder(folder):
    file_paths = []  # list of collected file paths
    # Walk every file and subfolder under the given folder
    for root, directories, files in os.walk(folder):
        for filename in files:
            file_path = os.path.abspath(os.path.join(root, filename))
            file_paths.append(file_path)
    return file_paths
# Load the data
def get_data_folder(data_folder_path):
    file_paths = get_files_in_folder(data_folder_path)
    # Dispatch each file by its extension
    all_data = []
    document_data = []
    for file_path in file_paths:
        extension = os.path.splitext(file_path)[-1].lower()
        if extension == '.csv':
            # Only structured CSV data goes through the custom loader
            loader = CustomDocumentLoader(file_path, 'query')
            all_data.extend(loader.load())
        elif extension == '.docx':
            doc_loader = UnstructuredWordDocumentLoader(file_path)
            document_data.extend(doc_loader.load())
    return all_data, document_data
# Build the CSV vector store and persist it locally
csv_data, _ = get_data_folder('./data/')
text_splitter_csv = CharacterTextSplitter(separator="\n\n", chunk_size=256, chunk_overlap=20)
split_csv = text_splitter_csv.split_documents(csv_data)
# Character-level tokenization of each chunk, used later as the BM25 corpus
text = []
for chunk in split_csv:
    text.append(list(chunk.page_content))
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
store = LocalFileStore("./cache/")
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    hf, store, namespace='csv_embeddings'
)
vector_store = FAISS.from_documents(split_csv, cached_embedder)
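# Chat model served from a local OpenAI-compatible endpoint; the model path, port, and api_key are environment-specific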
llm = ChatOpenAI(model="/home/media4090/wz/model/qwen/qwen2-7b-instruct", max_tokens=1000, temperature=0, base_url="http://127.0.0.1:9997/v1",api_key="not-needed")
# Build the self-RAG pipeline
class RetrievalResponse(BaseModel):
    response: str = Field(..., title="Determines if retrieval is necessary", description="Output only '是' or '否'.")
retrieval_prompt = PromptTemplate(
    input_variables=["query"],
    template="给定一个问题 '{query}', 确定是否需要检索。仅输出“是”或“否”。"
)
class RelevanceResponse(BaseModel):
    response: str = Field(..., title="Determines if context is relevant", description="Output only '相关' or '不相关'.")
relevance_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="给定一个问题 '{query}' 以及对应的资料 '{context}', 确定上下文是否相关。仅输出“相关”或“不相关”。"
)
generation_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="给定一个问题 '{query}' 以及对应的资料 '{context}', 依据上述资料对问题进行回答,尽可能不使用超出资料的知识。"
)
support_prompt = PromptTemplate(
    input_variables=["response", "context"],
    template="给定一个回答 '{response}' 以及对应的资料 '{context}', 确定回答是否得到资料的支持。仅输出“完全支持”、“部分支持”或“不支持”。",
    validate_template=True
)
utility_prompt = PromptTemplate(
    input_variables=["query", "response"],
    template="给定一个问题 '{query}' 以及对应的回答 '{response}'。 判断回答是否对问题有帮助,并在1-5之间进行打分。",
)
# Build a runnable chain (prompt | llm) for each self-RAG step
retrieval_chain = retrieval_prompt | llm.with_structured_output(RetrievalResponse)
relevance_chain = relevance_prompt | llm.with_structured_output(RelevanceResponse)
generation_chain = generation_prompt | llm
support_chain = support_prompt | llm
utility_chain = utility_prompt | llm
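# self_rag() walks the self-RAG loop step by step:
#   1. decide whether retrieval is needed for the query;
#   2. if so, retrieve candidate chunks with a BM25 + vector hybrid score;
#   3. keep only the chunks judged relevant;
#   4. generate one answer per relevant chunk;
#   5. grade each answer for support by its chunk and for utility (1-5);
#   6. return the best-supported, highest-utility answer.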
def self_rag(query, vectorstore, top_k=5):
    print(f"\nProcessing query: {query}")
    # Step 1: Determine if retrieval is necessary
    print("Step 1: Determining if retrieval is necessary...")
    input_data = {"query": query}
    retrieval_decision = retrieval_chain.invoke(input_data).response.strip().lower()
    print(f"Retrieval decision: {retrieval_decision}")
    if retrieval_decision == '是':
        # Step 2: Retrieve relevant documents with a BM25 + vector hybrid score
        print("Step 2: Retrieving relevant documents...")
        # BM25 over the character-tokenized corpus built above
        bm25 = BM25Okapi(text)
        bm25_scores = bm25.get_scores(list(query))
        # bm25_scores[i] corresponds to split_csv[i]; key the scores by chunk content
        # so they can be aligned with the vector-search results below
        content_to_bm25 = {doc.page_content: s for doc, s in zip(split_csv, bm25_scores)}
        # Dense retrieval over the whole index (FAISS returns distances: lower is better)
        docs_with_scores = vectorstore.similarity_search_with_score(query, k=vectorstore.index.ntotal)
        all_docs = [doc for doc, _ in docs_with_scores]
        vector_scores = np.array([score for _, score in docs_with_scores])
        bm25_aligned = np.array([content_to_bm25.get(doc.page_content, 0.0) for doc in all_docs])
        # Normalize both score arrays to [0, 1] (distances are inverted so higher = better)
        vector_scores = 1 - (vector_scores - np.min(vector_scores)) / (np.max(vector_scores) - np.min(vector_scores))
        bm25_aligned = (bm25_aligned - np.min(bm25_aligned)) / (np.max(bm25_aligned) - np.min(bm25_aligned))
        # Combine and rank; keep the top 3 chunks as contexts
        alpha = 0.5
        combined_scores = alpha * vector_scores + (1 - alpha) * bm25_aligned
        sorted_indices = np.argsort(combined_scores)[::-1]
        contexts = [all_docs[i].page_content for i in sorted_indices[:3]]
        print(f"Retrieved {len(contexts)} documents")
        # Step 3: Evaluate relevance of retrieved documents
        print("Step 3: Evaluating relevance of retrieved documents...")
        relevant_contexts = []
        for i, context in enumerate(contexts):
            input_data = {"query": query, "context": context}
            relevance = relevance_chain.invoke(input_data).response.strip().lower()
            print(f"Document {i+1} relevance: {relevance}")
            if relevance == '相关':
                relevant_contexts.append(context)
        print(f"Number of relevant contexts: {len(relevant_contexts)}")
        # If no relevant contexts found, generate without retrieval
        if not relevant_contexts:
            print("No relevant contexts found. Generating without retrieval...")
            input_data = {"query": query, "context": "No relevant context found."}
            return generation_chain.invoke(input_data).content
        # Step 4: Generate a response for each relevant context
        print("Step 4: Generating responses using relevant contexts...")
        responses = []
        for i, context in enumerate(relevant_contexts):
            print(f"Generating response for context {i+1}...")
            input_data = {"query": query, "context": context}
            response = generation_chain.invoke(input_data).content
            # Step 5: Assess whether the response is supported by its context
            print(f"Step 5: Assessing support for response {i+1}...")
            input_data = {"response": response, "context": context}
            support = support_chain.invoke(input_data).content.strip().lower()
            print(f"Support assessment: {support}")
            # Step 6: Evaluate utility (parse the last number in the model's reply as the 1-5 score)
            print(f"Step 6: Evaluating utility for response {i+1}...")
            input_data = {"query": query, "response": response}
            utility_reply = utility_chain.invoke(input_data).content
            print(utility_reply)
            digits = re.findall(r'\d+', utility_reply)
            utility = int(digits[-1]) if digits else 0
            print(f"Utility score: {utility}")
            responses.append((response, support, utility))
        # Select the best response: fully supported first, then highest utility
        print("Selecting the best response...")
        best_response = max(responses, key=lambda x: (x[1] == '完全支持', x[2]))
        print(f"Best response support: {best_response[1]}, utility: {best_response[2]}")
        return best_response[0]
    else:
        # Generate without retrieval
        print("Generating without retrieval...")
        input_data = {"query": query, "context": "No retrieval necessary."}
        return generation_chain.invoke(input_data).content
query = "气功带是不是最强大专?"
response = self_rag(query, vector_store)
print("\nFinal response:")
print(response)