【RAG】基于向量检索的 RAG (BGE示例)

RAG机器人 结构体

  • 文本向量化: 使用 BGE 模型将文档和查询编码为向量。
    (BGE 是专为检索任务优化的开源 Embedding 模型,除了本文API调用,也可以通过Hugging Face 本地部署BGE 开源模型)

  • 向量检索: 从数据库中找到与查询相关的文档片段。

  • 答案生成: 结合检索结果和用户输入,调用文心模型生成最终回答。

class RAG_Bot:
    def __init__(self, vector_db, llm_api, n_results=2):
        self.vector_db = vector_db
        self.llm_api = llm_api
        self.n_results = n_results

    def chat(self, user_query):
        # 1. 检索
        search_results = self.vector_db.search(user_query, self.n_results)

        # 2. 构建 Prompt
        prompt = build_prompt(
            prompt_template, context=search_results['documents'][0], query=user_query)

        # 3. 调用 LLM
        response = self.llm_api(prompt)
        return response
######

# 创建一个RAG机器人
bot = RAG_Bot(
    vector_db,
    llm_api=get_completion
)

user_query = "llama 2有多少参数?"

response = bot.chat(user_query)

print(response)

#####
llama 2有7B, 13B和70B参数。

MyVectorDBConnector:

自定义向量数据库,存储文档向量。
embedding_fn=get_embeddings_bge: 使用 BGE 模型生成向量。
add_documents(paragraphs): 向数据库中添加文档(已提前定义 paragraphs)。

RAG_Bot:

检索增强生成机器人,结合向量搜索与大模型生成。
chat(user_query): 执行“检索→生成”流程:
将用户查询向量化。
从数据库检索相关文档。
将检索结果作为上下文,调用文心模型生成回答。

使用国产模型

import json
import requests
import os

# 通过鉴权接口获取 access token


def get_access_token():
    """
    使用 AK,SK 生成鉴权签名(Access Token)
    :return: access_token,或是None(如果错误)
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials",
        "client_id": os.getenv('ERNIE_CLIENT_ID'),
        "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
    }

    return str(requests.post(url, params=params).json().get("access_token"))

# 调用文心千帆 调用 BGE Embedding 接口


def get_embeddings_bge(prompts):
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en?access_token=" + get_access_token()
    payload = json.dumps({
        "input": prompts
    })
    headers = {'Content-Type': 'application/json'}

    response = requests.request(
        "POST", url, headers=headers, data=payload).json()
    data = response["data"]
    return [x["embedding"] for x in data]


# 调用文心4.0对话接口
def get_completion_ernie(prompt):

    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + get_access_token()
    payload = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    })

    headers = {'Content-Type': 'application/json'}

    response = requests.request(
        "POST", url, headers=headers, data=payload).json()

    return response["result"]

# 创建一个向量数据库对象
new_vector_db = MyVectorDBConnector(
    "demo_ernie",
    embedding_fn=get_embeddings_bge
)
# 向向量数据库中添加文档
new_vector_db.add_documents(paragraphs)

# 创建一个RAG机器人
new_bot = RAG_Bot(
    new_vector_db,
    llm_api=get_completion_ernie
)

user_query = "how many parameters does llama 2 have?"

response = new_bot.chat(user_query)

print(response)

拓展实践

1. 优化 Access Token 管理
  • 缓存 Token:减少鉴权接口调用次数,仅在 Token 过期时刷新。
  • 示例代码
    from datetime import datetime, timedelta
    
    class TokenManager:
        _token = None
        _expires_at = None
    
        @classmethod
        def get_token(cls):
            if cls._token is None or datetime.now() > cls._expires_at:
                cls._refresh_token()
            return cls._token
    
        @classmethod
        def _refresh_token(cls):
            url = "https://aip.baidubce.com/oauth/2.0/token"
            params = {
                "grant_type": "client_credentials",
                "client_id": os.getenv('ERNIE_CLIENT_ID'),
                "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
            }
            response = requests.post(url, params=params)
            response.raise_for_status()
            data = response.json()
            cls._token = data["access_token"]
            # 默认 Token 有效期为 30 天,但建议按实际返回的 expires_in 设置
            cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)  # 提前 5 分钟刷新
    
2. 增强错误处理与重试
  • 重试网络请求:使用 tenacity 库自动重试失败请求。
  • 捕获异常:明确处理常见错误(如网络超时、无效响应)。
  • 示例代码
    from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
    import requests.exceptions as req_exceptions
    
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
    )
    def safe_api_request(url, headers, payload):
        try:
            response = requests.post(url, headers=headers, data=payload, timeout=10)
            response.raise_for_status()
            return response.json()
        except req_exceptions.HTTPError as e:
            if response.status_code == 401:
                TokenManager._refresh_token()  # Token 可能过期,强制刷新
                raise
            raise ValueError(f"API 错误: {e.response.text}")
    
3. 验证环境变量
  • 启动时检查:确保关键配置已正确设置。
  • 示例代码
    def validate_env_vars():
        required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']
        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            raise EnvironmentError(f"缺少环境变量: {', '.join(missing_vars)}")
    
    # 在程序初始化时调用
    validate_env_vars()
    
4. 优化向量数据库交互
  • 批量插入文档:减少 API 调用次数。
  • 分块策略:根据 Embedding 模型的最大输入长度分块文本。
  • 示例优化(假设使用 MyVectorDBConnector):
    class MyVectorDBConnector:
        def __init__(self, name, embedding_fn, chunk_size=512):
            self.embedding_fn = embedding_fn
            self.chunk_size = chunk_size  # 根据模型支持的最大长度设置
    
        def add_documents(self, documents):
            chunks = self._chunk_documents(documents)
            embeddings = self.embedding_fn(chunks)
            # 批量存储到向量数据库
    
        def _chunk_documents(self, documents):
            # 实现基于句子或固定长度的分块逻辑
            pass
    

优化后的代码示例

整合上述改进后的核心逻辑:

import os
import json
import logging
from datetime import datetime, timedelta
import requests
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import requests.exceptions as req_exceptions

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 环境变量校验
def validate_env_vars():
    required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise EnvironmentError(f"Missing env vars: {', '.join(missing_vars)}")
validate_env_vars()

# Token 管理
class TokenManager:
    _token = None
    _expires_at = None

    @classmethod
    def get_token(cls):
        if cls._token is None or datetime.now() > cls._expires_at:
            cls._refresh_token()
        return cls._token

    @classmethod
    def _refresh_token(cls):
        logger.info("Refreshing access token...")
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {
            "grant_type": "client_credentials",
            "client_id": os.getenv('ERNIE_CLIENT_ID'),
            "client_secret": os.getenv('ERNIE_CLIENT_SECRET')
        }
        response = requests.post(url, params=params)
        response.raise_for_status()
        data = response.json()
        cls._token = data["access_token"]
        cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)

# 安全 API 请求
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
)
def safe_api_request(url, headers, payload):
    try:
        response = requests.post(url, headers=headers, data=payload, timeout=10)
        response.raise_for_status()
        return response.json()
    except req_exceptions.HTTPError as e:
        if response.status_code == 401:
            TokenManager._refresh_token()
            raise
        logger.error(f"API Error: {e.response.text}")
        raise

# 公共 API 调用封装
def call_ernie_api(endpoint, payload):
    base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
    url = f"{base_url}/{endpoint}?access_token={TokenManager.get_token()}"
    headers = {'Content-Type': 'application/json'}
    return safe_api_request(url, headers, json.dumps(payload))

# Embedding 接口
def get_embeddings_bge(prompts):
    logger.info(f"Generating embeddings for {len(prompts)} prompts")
    response = call_ernie_api("embeddings/bge_large_en", {"input": prompts})
    return [x["embedding"] for x in response["data"]]

# 文心 4.0 对话接口
def get_completion_ernie(prompt):
    logger.info(f"Generating completion for prompt: {prompt[:50]}...")
    response = call_ernie_api("chat/completions_pro", {
        "messages": [{"role": "user", "content": prompt}]
    })
    return response["result"]
<think>好的,我需要帮助用户了解或实现RAG检索增强生成)的相关技术。首先,我应该明确RAG的基本概念。根据引用[2],RAG结合了检索外部知识库和大语言模型的生成能力,类似于开卷考试,让模型在生成答案时可以参考外部信息。接下来,用户可能需要一个具体的实现步骤或教程,所以我得整理出一个清晰的流程。 第一步是确定知识库的来源。用户需要收集和预处理相关数据,比如PDF、网页文本或数据库内容。然后,将这些数据转化为嵌入向量,可能要用到像Sentence-BERT这样的模型。这里需要提到向量数据库的重要性,比如FAISS或Pinecone,用于高效检索。 接下来是检索阶段。当用户输入查询时,系统需要从知识库中找到最相关的文档片段。这部分需要说明如何计算相似度,比如余弦相似度,以及如何选择top-k个结果。 然后是生成部分。将检索到的上下文和用户问题一起输入大模型,比如GPT-3或Llama,生成最终答案。这里可以举例说明不同的提示模板,比如LangChain的RAG链。 此外,用户可能关心如何优化RAG系统。可以提到调整检索策略,比如混合搜索或重新排序,以及评估检索结果的质量指标。引用[3]提到的Langchain和llama_index作为工具,应该给出具体的代码示例,比如使用LangChain的VectorstoreIndexCreator。 还要注意常见问题,比如数据预处理不当导致检索效果差,或者大模型生成时的幻觉问题。需要强调评估和迭代的重要性,引用[1]提到的上下文过滤可能也是优化点。 最后,生成相关问题,覆盖RAG的原理、工具选择、评估方法、优化技巧和应用场景,帮助用户进一步探索。确保所有数学表达式使用正确的LaTeX格式,比如嵌入向量的距离计算$d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}$,并检查引用标记的正确添加。</think>以下分步说明如何实现一个基础的RAG系统,并结合代码示例和关键概念解析: ### 一、RAG实现核心步骤 1. **知识库构建** - 数据收集:从PDF/网页/数据库获取原始数据,格式如: ```python documents = [ "RAG通过检索外部知识增强生成...", "LangChain提供RAG工具链..." ] ``` - 文本分块:使用LangChain的`RecursiveCharacterTextSplitter` ```python from langchain.text_splitter import RecursiveCharacterTextSplitter text_splitter = RecursiveCharacterTextSplitter(chunk_size=500) chunks = text_splitter.split_documents(documents) ``` 2. **向量化处理** $$ \text{embed}(text) \rightarrow \mathbb{R}^d $$ 使用HuggingFace嵌入模型: ```python from sentence_transformers import SentenceTransformer encoder = SentenceTransformer('all-MiniLM-L6-v2') embeddings = encoder.encode(chunks) ``` 3. **向量数据库** 使用FAISS实现相似度检索: ```python import faiss index = faiss.IndexFlatL2(384) index.add(embeddings) ``` 4. **检索-生成流程** ```python query = "如何实现RAG系统?" query_embed = encoder.encode(query) D, I = index.search(query_embed, k=3) # 取top3 context = [chunks[i] for i in I[0]] ``` 5. **提示工程** LangChain的RAG提示模板: ```python from langchain.prompts import PromptTemplate template = """基于以下上下文回答: {context} 问题:{question}""" prompt = PromptTemplate(template=template, input_variables=["context", "question"]) ``` ### 二、关键技术组件 1. **检索优化** - 混合搜索:结合$BM25$与向量相似度 - 重新排序:使用`bge-reranker`等模型 2. **生成控制** 通过温度参数调节创造性: $$ P(w_t|w_{<t},c) = \text{softmax}(f(w_{<t},c)/\tau) $$ 3. **评估指标** - 检索召回率:$Recall@k = \frac{\text{相关文档命中数}}{总相关文档数}$ - 生成质量:BLEU、ROUGE等 ### 三、完整代码示例(LangChain实现) ```python from langchain.embeddings import HuggingFaceEmbeddings from langchain.vectorstores import FAISS from langchain.llms import LlamaCpp # 知识库构建 embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") vector_store = FAISS.from_texts(chunks, embeddings) # 检索增强生成 retriever = vector_store.as_retriever(search_kwargs={"k": 3}) llm = LlamaCpp(model_path="llama-2-7b.Q4_K_M.gguf") from langchain.chains import RetrievalQA qa_chain = RetrievalQA.from_chain_type( llm=llm, retriever=retriever, chain_type="stuff" ) print(qa_chain.run("RAG的实现步骤有哪些?")) ``` ### 四、常见问题与优化 1. **冷启动问题**:可使用Wikipedia预构建索引 2. **长上下文处理**:采用`sliding window`分块策略 3. **多模态扩展**:CLIP支持图像检索增强
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

星星点点洲

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值