A Basic RAG Implementation

1. Introduction to RAG

RAG (Retrieval-Augmented Generation) is an AI technique that combines information retrieval with language-generation models. It retrieves relevant information from an external knowledge base and feeds it to a large language model (LLM) as part of the prompt, strengthening the model's ability to handle knowledge-intensive tasks.

In plain terms: a model first encodes your dataset into vectors. When a user asks a question, the same model encodes the question too, and the system searches the encoded dataset for content whose encoding is similar to the question's. Once similar content is found, the LLM follows your prompt and uses the retrieved content as reference material to compose the answer.
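
To make that loop concrete, here is a minimal, self-contained sketch of encode-retrieve-generate using sentence-transformers and numpy. The model name and the three-line toy corpus are illustrative placeholders, not the project's actual data.

import numpy as np
from sentence_transformers import SentenceTransformer

# Illustrative model and toy corpus (placeholders, not the project's data)
model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
corpus = ["气功师的毕业装备搭配", "剑魂的护石推荐", "深渊副本的入场材料"]

# 1. Encode the dataset once
corpus_vecs = model.encode(corpus, normalize_embeddings=True)

# 2. Encode the user's question with the same model
question = "气功师穿什么装备?"
query_vec = model.encode(question, normalize_embeddings=True)

# 3. Cosine similarity (dot product of normalized vectors); keep the best match
best = corpus[int(np.argmax(corpus_vecs @ query_vec))]

# 4. The retrieved text becomes reference context inside the LLM prompt
prompt = f"根据下面的资料回答问题。\n资料: {best}\n问题: {question}"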

2. Code Implementation

Below is a simple RAG example:

import os
import re

import numpy as np
from rank_bm25 import BM25Okapi

from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings, CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.vectorstores import FAISS

from langchain_community.document_loaders import UnstructuredWordDocumentLoader

from file_loader import CustomDocumentLoader

device = "cuda:0" # the device to load the model onto
# Load the embedding model
model_name = "/home/media4090/wz/model/vllm/Xorbits/bge-large-zh-v1.5"
hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': device})

def get_files_in_folder(folder):
    file_paths = []  # list that collects the discovered file paths
    # Walk every file and subfolder under the given folder
    for root, directories, files in os.walk(folder):
        for filename in files:
            file_path = os.path.abspath(os.path.join(root, filename))
            file_paths.append(file_path)
    return file_paths

# Load the data
def get_data_folder(data_folder_path):
    file_paths = get_files_in_folder(data_folder_path)
    # Dispatch each file by its extension
    all_data = []
    document_data = []
    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        extension = file_name.split('.')[-1]
        if extension == 'csv':
            # Only csv files (structured data) go through the custom loader
            loader = CustomDocumentLoader(file_path, 'query')
            all_data.extend(loader.load())
        elif extension == 'docx':
            doc_loader = UnstructuredWordDocumentLoader(file_path)
            document_data.extend(doc_loader.load())
    return all_data, document_data

# Build the vector store for the csv files and persist the embeddings locally
csv_data, _ = get_data_folder('./data/')
text_splitter_csv = CharacterTextSplitter(separator="\n\n", chunk_size=256, chunk_overlap=20)
split_csv = text_splitter_csv.split_documents(csv_data)

# Character-level tokenization of every chunk for BM25 (avoids needing a
# Chinese word segmenter; single characters work reasonably well here)
text = []
for i in split_csv:
    text.append(list(i.page_content))

# Cache embeddings on disk so repeated runs do not re-encode unchanged chunks
store = LocalFileStore("./cache/")
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    hf, store, namespace='csv_embeddings'
)
vector_store = FAISS.from_documents(split_csv, cached_embedder)
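
# (Sketch, not in the original project: the finished FAISS index can also be
# saved to disk and reloaded later, skipping the rebuild on every run. Newer
# langchain versions require allow_dangerous_deserialization=True on load.)
# vector_store.save_local("./faiss_csv_index")
# vector_store = FAISS.load_local("./faiss_csv_index", cached_embedder,
#                                 allow_dangerous_deserialization=True)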

# Local Qwen2 model served behind an OpenAI-compatible endpoint
llm = ChatOpenAI(model="/home/media4090/wz/model/qwen/qwen2-7b-instruct", max_tokens=1000, temperature=0, base_url="http://127.0.0.1:9997/v1", api_key="not-needed")

# Build the Self-RAG pipeline: (1) decide whether retrieval is needed,
# (2) judge the relevance of each retrieved chunk, (3) generate an answer per
# relevant chunk, (4) check the answer is supported by the chunk, (5) score utility.
# Note: the structured-output descriptions must request the same Chinese labels
# that the string comparisons below check for.
class RetrievalResponse(BaseModel):
    response: str = Field(..., title="Determines if retrieval is necessary", description="Output only '是' or '否'.")

retrieval_prompt = PromptTemplate(
    input_variables=["query"],
    template="给定一个问题 '{query}', 确定是否需要检索。仅输出“是”或“否”。"
)

class RelevanceResponse(BaseModel):
    response: str = Field(..., title="Determines if context is relevant", description="Output only '相关' or '不相关'.")

relevance_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="给定一个问题 '{query}' 以及对应的资料 '{context}', 确定上下文是否相关。仅输出“相关”或“不相关”。"
)

generation_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="给定一个问题 '{query}' 以及对应的资料 '{context}', 依据上述资料对问题进行回答,尽可能不使用超出资料的知识。"
)

support_prompt = PromptTemplate(
    input_variables=["response", "context"],
    template="给定一个问题 '{response}' 以及对应的资料 '{context}', 确定响应是否得到上下文的支持。输出“完全支持”、“部分支持”或“不支持”。",
    validate_template=True
)

utility_prompt = PromptTemplate(
    input_variables=["query", "response"],
    template="给定一个问题 '{query}' 以及对应的回答 '{response}'。 判断回答是否对解决问题有帮助,并在1-5之间进行打分。",
)

# Compose a chain for each step (the first two parse into structured outputs)
retrieval_chain = retrieval_prompt | llm.with_structured_output(RetrievalResponse)
relevance_chain = relevance_prompt | llm.with_structured_output(RelevanceResponse)
generation_chain = generation_prompt | llm
support_chain = support_prompt | llm
utility_chain = utility_prompt | llm


def self_rag(query, vectorstore, top_k=5):
    print(f"\nProcessing query: {query}")
    
    # Step 1: Determine if retrieval is necessary
    print("Step 1: Determining if retrieval is necessary...")
    input_data = {"query": query}
    retrieval_decision = retrieval_chain.invoke(input_data).response.strip().lower()
    print(f"Retrieval decision: {retrieval_decision}")
    
    if retrieval_decision == '是':
        # Step 2: Retrieve relevant documents with a hybrid of BM25 and vector scores
        print("Step 2: Retrieving relevant documents...")

        # BM25 over the character-tokenized chunks; tokenize the query the same way
        bm25 = BM25Okapi(text)
        bm25_scores = bm25.get_scores(list(query))

        # Vector scores, computed in docstore insertion order so they line up
        # index-by-index with bm25_scores (which follows the split_csv order)
        alpha = 0.5
        n = vectorstore.index.ntotal
        all_docs = [vectorstore.docstore.search(vectorstore.index_to_docstore_id[i]) for i in range(n)]
        query_vec = np.array([hf.embed_query(query)], dtype=np.float32)
        distances, indices = vectorstore.index.search(query_vec, n)
        vector_scores = np.empty(n, dtype=np.float32)
        vector_scores[indices[0]] = distances[0]

        # Normalize both arrays to [0, 1]; for L2 distances, smaller is better
        vector_scores = 1 - (vector_scores - np.min(vector_scores)) / (np.max(vector_scores) - np.min(vector_scores))
        bm25_scores = (bm25_scores - np.min(bm25_scores)) / (np.max(bm25_scores) - np.min(bm25_scores))

        # Blend the two signals and keep the top 3 chunks as contexts
        combined_scores = alpha * vector_scores + (1 - alpha) * bm25_scores
        sorted_indices = np.argsort(combined_scores)[::-1]
        contexts = [all_docs[i].page_content for i in sorted_indices[:3]]
        print(f"Retrieved {len(contexts)} documents")
        
        # Step 3: Evaluate relevance of retrieved documents
        print("Step 3: Evaluating relevance of retrieved documents...")
        relevant_contexts = []
        for i, context in enumerate(contexts):
            input_data = {"query": query, "context": context}
            relevance = relevance_chain.invoke(input_data).response.strip().lower()
            print(f"Document {i+1} relevance: {relevance}")
            if relevance == '相关':
                relevant_contexts.append(context)
        
        print(f"Number of relevant contexts: {len(relevant_contexts)}")
        
        # If no relevant contexts found, generate without retrieval
        if not relevant_contexts:
            print("No relevant contexts found. Generating without retrieval...")
            input_data = {"query": query, "context": "No relevant context found."}
            return generation_chain.invoke(input_data).content
        
        # Step 4: Generate response using relevant contexts
        print("Step 4: Generating responses using relevant contexts...")
        responses = []
        for i, context in enumerate(relevant_contexts):
            print(f"Generating response for context {i+1}...")
            input_data = {"query": query, "context": context}
            response = generation_chain.invoke(input_data).content
            
            # Step 5: Assess support
            print(f"Step 5: Assessing support for response {i+1}...")
            input_data = {"response": response, "context": context}
            support = support_chain.invoke(input_data).content.strip().lower()
            print(f"Support assessment: {support}")
            
            # Step 6: Evaluate utility
            print(f"Step 6: Evaluating utility for response {i+1}...")
            input_data = {"query": query, "response": response}
            utility_text = utility_chain.invoke(input_data).content
            print(utility_text)
            # Pull the last number from the model's reply; default to 0 if none appears
            scores_found = re.findall(r'\d+', utility_text)
            utility = int(scores_found[-1]) if scores_found else 0
            print(f"Utility score: {utility}")
            
            responses.append((response, support, utility))
        
        # Select the best response based on support and utility
        print("Selecting the best response...")
        best_response = max(responses, key=lambda x: ('完全支持' in x[1], x[2]))  # prefer fully supported answers, then higher utility
        print(f"Best response support: {best_response[1]}, utility: {best_response[2]}")
        return best_response[0]
    else:
        # Generate without retrieval
        print("Generating without retrieval...")
        input_data = {"query": query, "context": "No retrieval necessary."}
        return generation_chain.invoke(input_data).content
    
query = "气功带是不是最强大专?"  # example query
response = self_rag(query, vector_store)

print("\nFinal response:")
print(response)
