RAG Principles and Local Implementation in Practice

LLM-based applications face confidentiality concerns and model hallucination in specific domains; RAG can effectively improve an LLM's accuracy. This article introduces the principles of RAG, which consists of three steps: retrieval, augmentation, and generation. It then walks through a local implementation based on the simple-local-rag project, covering data preparation, embedding, loading the embeddings, loading the LLM, prompt augmentation, and answer generation.

LLM-based applications have proven enormously useful for question answering and information retrieval. The data used to train these general-purpose models comes mainly from conversations on the internet or from data supplied by individual organizations. They can hold human-like conversations, but they fall short when a specific domain is involved. In practice, general-purpose LLMs have two main problems:

  1. Confidentiality: users with strict data-security requirements do not want to hand their data over to train a general-purpose model, and without that data the model cannot produce answers that meet their needs.
  2. Hallucination: an inherent weakness of large models. The classic symptom is an answer that responds to every sentence yet never actually addresses the question, because the training data is too broad and not focused on the user's specific scenario.

Is there a technique that can effectively improve an LLM's accuracy so that it retrieves and returns information within the user's domain of interest? Yes: that technique is RAG.

How RAG Works

In reference 1, Patrick Lewis et al. proposed a new LLM-based architecture and named it RAG. RAG is not a new deep-learning network model; I prefer to think of it as a framework built around an LLM and combined with several other techniques.

RAG stands for Retrieval-Augmented Generation, and the three words also name the three steps of the technique:

  • Retrieval: the data-fetching step. Given the user's query, search the user-supplied material for content related to that query.
  • Augmented: the augmentation step. Use the content retrieved in the previous step to rewrite the user's query, and feed this augmented input to the LLM.
  • Generation: the LLM's usual job of generating output, now from the augmented input. The output at this point is highly relevant to the user's domain of interest.

The RAG architecture (see the diagram in reference 1) works as follows: given an input query x, the goal is to produce a prediction y. The query encoder q(x) embeds the input, and MIPS (maximum inner product search) compares that embedding against the previously stored document embeddings d(z) to pick the top_k relevant documents d(z_i). These documents augment the input, which is then passed to the generator, i.e. the LLM p_θ, to produce the final prediction y.
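
To make the flow concrete, here is a minimal, self-contained sketch of the retrieve, augment, and generate loop described above. It is illustrative only: the embed and generate callables are hypothetical stand-ins for an embedding model and an LLM, and a plain dot-product top-k search stands in for MIPS.

import numpy as np


def retrieve(query_vec: np.ndarray, doc_vecs: np.ndarray, top_k: int = 5) -> np.ndarray:
    # MIPS in its simplest form: score every stored d(z) by inner product with q(x)
    scores = doc_vecs @ query_vec
    return np.argsort(scores)[::-1][:top_k]  # indices of the top_k most similar documents


def rag_answer(query: str, docs: list[str], doc_vecs: np.ndarray, embed, generate, top_k: int = 5) -> str:
    # Retrieval: embed the query and find the top_k related passages
    query_vec = embed(query)
    context = [docs[i] for i in retrieve(query_vec, doc_vecs, top_k)]
    # Augmentation: splice the retrieved passages into the prompt
    prompt = "Context:\n" + "\n".join(f"- {c}" for c in context) + f"\n\nQuestion: {query}\nAnswer:"
    # Generation: let the LLM answer from the augmented prompt
    return generate(prompt)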

A Local RAG Implementation

Several third-party frameworks already support RAG, such as LlamaIndex and LangChain. This article instead uses the simple-local-rag project on GitHub to present the most basic RAG implementation. The original code is a notebook, which is great for explaining each concept alongside its code but inconvenient to study as plain code, so I refactored and repackaged it. If you want the detailed step-by-step walkthrough, open the Colab notebook in the GitHub repository from reference 2; if you just want an overview of RAG in code, read on.

Based on the code from reference 2, I rewrote a version and pushed it to GitHub (https://github.com/yyaaron/simple-rag-practice).

Data Preparation

  • Download the PDF document; the PDF is the content source used to augment the prompt. The example uses a nutrition textbook downloaded from https://pressbooks.oer.hawaii.edu/humannutrition2/open/download?type=pdf. FileDownloader.download_if_not_exist downloads it into the current local folder.
  • Split the page content so it can be embedded later.
    • FilePreprocessor.open_and_read_pdf loads the content of every page and groups its sentences into chunks of ten.
    • FilePreprocessor.pages_and_chunks turns those groups into individual chunks of the specified size. Why chunk at all? Because the model constrains its input, so the content has to be sliced before it goes in. The method also drops chunks that are too short, since those are mostly transition sentences or page headers and footers with no real content.
"""
Download file and process it for later use
"""
import os
import random
import re

import requests
import fitz
from tqdm.auto import tqdm
from spacy.lang.en import English


class FileDownloader:
    file_name = 'human-nutrition-text.pdf'
    download_url = 'https://pressbooks.oer.hawaii.edu/humannutrition2/open/download?type=pdf'

    def __init__(self, file_name: str = 'human-nutrition-text.pdf',
                 download_url: str = 'https://pressbooks.oer.hawaii.edu/humannutrition2/open/download?type=pdf'):
        self.download_url = download_url
        self.file_name = file_name

    @classmethod
    def download_if_not_exist(cls, file_name: str, download_url: str):
        if file_name == "" or file_name is None:
            file_name = cls.file_name

        if download_url == "" or download_url is None:
            download_url = cls.download_url

        if not os.path.exists(file_name):
            # download file from download_url
            response = requests.get(download_url)
            if response.status_code == 200:
                with open(file_name, "wb") as file:
                    file.write(response.content)

                print("[INFO]File downloaded with name", file_name)
            else:
                print("[INFO]Exception occurs while downloading")
        else:
            print("[INFO]File", file_name, "exists, ignore downloading")


class FilePreprocessor:
    num_sentence_chunk_size = 10

    def __init__(self):
        pass

    @classmethod
    def clean_text(cls, text) -> str:
        return text.replace("\n", " ").strip()

    @classmethod
    def text_to_sentences(cls, text) -> list:
        nlp = English()
        nlp.add_pipe("sentencizer")

        doc = nlp(text)
        return list(doc.sents)

    @classmethod
    def split_list(cls, input_list: list, slice_size: int) -> list[list[str]]:
        return [input_list[i: i + slice_size] for i in range(0, len(input_list), slice_size)]

    @classmethod
    def open_and_read_pdf(cls, file_path: str) -> list[dict]:
        doc = fitz.open(file_path)
        pages_and_texts = []
        print("[INFO]Loading content into memory...")
        for page_number, page in tqdm(enumerate(doc)):
            text = page.get_text()
            text = cls.clean_text(text)
            pages_and_texts.append(
                {"page_number": page_number - 41,
                 "page_char_count": len(text),
                 "page_word_count": len(text.split(" ")),
                 "page_sentence_count_raw": len(text.split(". ")),
                 "page_token_count": len(text) / 4,
                 "text": text
                 })

        print("[INFO]Make sentences into chunks...")
        for item in tqdm(pages_and_texts):
            item["sentences"] = cls.text_to_sentences(item["text"])
            item["sentences"] = [str(sentence) for sentence in item["sentences"]]
            item["page_sentence_count_spacy"] = len(item["sentences"])
            item["sentence_chunks"] = cls.split_list(input_list=item["sentences"],
                                                     slice_size=cls.num_sentence_chunk_size)
            item["num_chunks"] = len(item["sentence_chunks"])

        print(f"[INFO]Finish loading file({file_path})")
        return pages_and_texts

    @classmethod
    def pages_and_chunks(cls, pages_and_texts: list[dict], min_token_len: int = 30) -> list[dict]:
        print(f"[INFO]Divide chunks into specified-size pieces and filter out small chunks(size <= {min_token_len})...")
        pages_and_chunks = []
        for item in tqdm(pages_and_texts):
            for sentence_chunk in item["sentence_chunks"]:
                chunk_dict = {}
                chunk_dict["page_number"] = item["page_number"]

                joined_sentence_chunk = "".join(sentence_chunk).replace("  ", " ").strip()
                # ".A" -> ". A" for any full-stop/capital letter combo
                joined_sentence_chunk = re.sub(r'\.([A-Z])', r'. \1', joined_sentence_chunk)

                chunk_token_count = len(joined_sentence_chunk) / 4  # 1 token = ~4 characters
                if chunk_token_count <= min_token_len:
                    continue  # chunks this short are usually headers or footers with no real content, so skip them
                chunk_dict["chunk_token_count"] = chunk_token_count
                chunk_dict["sentence_chunk"] = joined_sentence_chunk
                chunk_dict["chunk_char_count"] = len(joined_sentence_chunk)
                chunk_dict["chunk_word_count"] = len([word for word in joined_sentence_chunk.split(" ")])

                pages_and_chunks.append(chunk_dict)

        print("[INFO]Finish processing chunks for each pages")

        return pages_and_chunks

Embedding the Data

  • Embed the document content. The model does not compute or learn directly on raw text, so the text needs a mathematical representation.

You cannot simply invent your own encoding, say 0 for a, 1 for b, and so on, because such a scheme gives the downstream algorithms nothing useful to work with. Embeddings are therefore produced by a dedicated model; this article uses "all-mpnet-base-v2".

import time
from typing import Union, List

import torch
from numpy import ndarray
from sentence_transformers import SentenceTransformer
from torch import Tensor


class Embedding:
    model_name = ""
    device = ""
    embedding_model = None

    def __init__(self, model_name="all-mpnet-base-v2", device="cuda" if torch.cuda.is_available() else "cpu"):
        self.model_name = model_name
        self.device = device
        self.embedding_model = SentenceTransformer(model_name_or_path=model_name, device=device)

    def encode(self, sentences: str | list[str], batch_size: int = 32,
               convert_to_tensor: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
        embeddings = self.embedding_model.encode(sentences, batch_size=batch_size,
                                                 convert_to_tensor=convert_to_tensor, device=self.device)

        return embeddings

The encoded text is now a set of numeric vectors, no longer convenient for humans to read but exactly what the machine needs. The embeddings can be stored in any vector database or saved to a local CSV file; this article uses a CSV file because the data volume is small. A short round-trip sketch follows.
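
Below is a minimal sketch of that CSV round trip (the saving half mirrors STEP 3 of main.py further down). It uses a tiny toy vector instead of the real 768-dimensional all-mpnet-base-v2 embeddings; the point is that pandas writes each vector out as a bracketed string such as "[0.1 0.2 ...]", which is why the loader in the next section parses it back with np.fromstring.

import numpy as np
import pandas as pd

# Toy stand-in for the embedded chunks produced earlier
chunks = [{"sentence_chunk": "Example sentence.", "embedding": np.array([0.1, 0.2, 0.3, 0.4])}]

# Save: each ndarray cell is written out as its string representation
pd.DataFrame(chunks).to_csv("text_chunks_and_embeddings_df.csv", index=False)

# Load: turn the "[0.1 0.2 0.3 0.4]" string back into a float array
df = pd.read_csv("text_chunks_and_embeddings_df.csv")
df["embedding"] = df["embedding"].apply(lambda s: np.fromstring(s.strip("[]"), sep=" "))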

Loading the Embeddings

Mirroring the previous step, load the embeddings from the CSV file back into memory.

import pandas as pd
import numpy as np
import torch


class EmbeddingLoader:
    device = ""
    file_name = ""

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu",
                 file_name="text_chunks_and_embeddings_df.csv"):
        self.device = device
        self.file_name = file_name

    @classmethod
    def load(cls, file_path: str = "text_chunks_and_embeddings_df.csv",
             device="cuda" if torch.cuda.is_available() else "cpu"):
        print(f"[INFO]Load embeddings of texts from file:{file_path}...")
        text_chunks_and_embedding_df = pd.read_csv(file_path)
        text_chunks_and_embedding_df["embedding"] = text_chunks_and_embedding_df["embedding"].apply(
            lambda x: np.fromstring(x.strip("[]"), sep=" "))

        pages_and_chunks = text_chunks_and_embedding_df.to_dict(orient="records")

        embeddings = (torch.tensor(np.array(text_chunks_and_embedding_df["embedding"].tolist()), dtype=torch.float32)
                      .to(device))
        print(f"[INFO]Finish loading embeddings of texts")
        return pages_and_chunks, embeddings

Loading the LLM

With the data prepared and embedded, we now need to choose an LLM to generate the answers.

  • Choosing the LLM: pick one whose parameter count your local CPU/GPU resources can support and that fits your use case. This article was run on an RTX 3060, so I settled on gemma-2b-it.
  • Downloading the LLM: AutoTokenizer.from_pretrained downloads the tokenizer and AutoModelForCausalLM.from_pretrained downloads the model itself. This article uses Hugging Face transformers directly, which requires accepting the license on the Gemma model card and logging in with a token from your Hugging Face settings page (https://huggingface.co/settings/tokens); the login sketch after this list shows that step.
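
A minimal sketch of the one-off Hugging Face login, assuming the huggingface_hub package is available (it ships with transformers) and that HF_TOKEN is an environment variable holding a token created on the settings page above; run it once before the code below tries to pull the gated Gemma weights.

import os

from huggingface_hub import login

# Read the token from an environment variable rather than hard-coding it
login(token=os.environ["HF_TOKEN"])
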
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available
from transformers import BitsAndBytesConfig

from rag_utils import RagUtils


class LlmUtils:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def __init__(self):
        pass

    @classmethod
    def get_gpu_mem_size(cls) -> int:
        gpu_memory_bytes = torch.cuda.get_device_properties(0).total_memory
        gpu_memory_gb = round(gpu_memory_bytes / (2 ** 30))
        return gpu_memory_gb

    @classmethod
    def mode_selector(cls) -> tuple[bool, str]:
        gpu_mem_size = cls.get_gpu_mem_size()

        use_quantization_config = False
        model_id = ""
        if gpu_mem_size < 5.1:
            print(
                f"Your available GPU memory is {gpu_mem_size}GB, you may not able to run Gemma locally without quantization")
        elif gpu_mem_size < 8.1:
            print(f"GPU memory: {gpu_mem_size}GB | Recommended model: Gemma 2B in 4-bit precision")
            use_quantization_config = True
            model_id = "google/gemma-2b-it"
        elif gpu_mem_size < 19.0:
            print(
                f"GPU memory: {gpu_mem_size}GB | Recommended model: Gemma 2B in float16 or Gemma 7B in 4-bit precision")
            use_quantization_config = False
            model_id = "google/gemma-2b-it"
        elif gpu_mem_size >= 19.0:
            print(f"GPU memory: {gpu_mem_size}GB | Recommended model: Gemma 7B in 4-bit or float16 precision")
            use_quantization_config = False
            model_id = "google/gemma-7b-it"

        print(f"[INFO]use_quantization_config set to: {use_quantization_config}\n model_id set to: {model_id}")
        return use_quantization_config, model_id

    @classmethod
    def init_model(cls):
        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

        if (is_flash_attn_2_available()) and (torch.cuda.get_device_capability(0)[0] >= 8):
            attn_implementation = "flash_attention_2"
        else:
            attn_implementation = "sdpa"
        print(f"[INFO]Attention implementation set to \"{attn_implementation}\"")

        use_quantization_config, model_id = cls.mode_selector()
        print(f"use_quantization_config: {use_quantization_config}, model_id: {model_id}")

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

        llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id,
                                                         torch_dtype=torch.float32,
                                                         quantization_config=quantization_config if use_quantization_config else None,
                                                         low_cpu_mem_usage=False,
                                                         attn_implementation=attn_implementation,
                                                         )

        if not use_quantization_config and torch.cuda.is_available():
            llm_model.to("cuda")

        return tokenizer, llm_model

    @classmethod
    def get_model_num_params(cls, model: torch.nn.Module):
        return sum([param.numel() for param in model.parameters()])

    @classmethod
    def get_model_mem_size(cls, model: torch.nn.Module):
        mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
        mem_buffers = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])

        model_mem_bytes = mem_params + mem_buffers
        model_mem_mb = model_mem_bytes / (1024 ** 2)
        model_mem_gb = model_mem_bytes / (1024 ** 3)

        return {"model_mem_bytes": model_mem_bytes,
                "model_mem_mb": round(model_mem_mb, 2),
                "model_mem_gb": round(model_mem_gb, 2)}

    @classmethod
    def gen_prompt_with_context(cls, query: str, context_items: list[dict], tokenizer):
        print(f"input text: \n{query}")

        context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])

        base_prompt = """Based on the following context items, please answer the query.
        Give yourself room to think by extracting relevant passages from the context before answering the query.
        Don't return the thinking, only return the answer.
        Make sure your answers are as explanatory as possible.
        Use the following examples as reference for the ideal answer style.
        \nExample 1:
        Query: What are the fat-soluble vitamins?
        Answer: The fat-soluble vitamins include Vitamin A, Vitamin D, Vitamin E, and Vitamin K. These vitamins are absorbed along with fats in the diet and can be stored in the body's fatty tissue and liver for later use. Vitamin A is important for vision, immune function, and skin health. Vitamin D plays a critical role in calcium absorption and bone health. Vitamin E acts as an antioxidant, protecting cells from damage. Vitamin K is essential for blood clotting and bone metabolism.
        \nExample 2:
        Query: What are the causes of type 2 diabetes?
        Answer: Type 2 diabetes is often associated with overnutrition, particularly the overconsumption of calories leading to obesity. Factors include a diet high in refined sugars and saturated fats, which can lead to insulin resistance, a condition where the body's cells do not respond effectively to insulin. Over time, the pancreas cannot produce enough insulin to manage blood sugar levels, resulting in type 2 diabetes. Additionally, excessive caloric intake without sufficient physical activity exacerbates the risk by promoting weight gain and fat accumulation, particularly around the abdomen, further contributing to insulin resistance.
        \nExample 3:
        Query: What is the importance of hydration for physical performance?
        Answer: Hydration is crucial for physical performance because water plays key roles in maintaining blood volume, regulating body temperature, and ensuring the transport of nutrients and oxygen to cells. Adequate hydration is essential for optimal muscle function, endurance, and recovery. Dehydration can lead to decreased performance, fatigue, and increased risk of heat-related illnesses, such as heat stroke. Drinking sufficient water before, during, and after exercise helps ensure peak physical performance and recovery.
        \nNow use the following context items to answer the user query:
        {context}
        \nRelevant passages: <extract relevant passages from the context here>
        User query: {query}
        Answer:"""

        base_prompt = base_prompt.format(context=context, query=query)

        dialogue_template = [
            {"role": "user",
             "content": base_prompt}
        ]

        prompt = tokenizer.apply_chat_template(conversation=dialogue_template,
                                               tokenize=False,
                                               add_generation_prompt=True)
        return prompt

    @classmethod
    def ask(cls, query, embeddings, embedding_model, pages_and_chunks, model, tokenizer,
            temperature=0.7,
            max_new_tokens=512,
            format_answer_text=True,
            return_answer_only=True):
        scores, indices = RagUtils.retrieve_relevant_resources(query=query,
                                                               embeddings=embeddings, model=embedding_model)
        context_items = [pages_and_chunks[i] for i in indices]

        for i, item in enumerate(context_items):
            item["score"] = scores[i].cpu()

        prompt = cls.gen_prompt_with_context(query=query, context_items=context_items, tokenizer=tokenizer)

        input_ids = tokenizer(prompt, return_tensors="pt").to(cls.device)

        outputs = model.generate(**input_ids,
                                 temperature=temperature,
                                 do_sample=True,
                                 max_new_tokens=max_new_tokens)

        output_text = tokenizer.decode(outputs[0])

        if format_answer_text:
            output_text = output_text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "").replace(
                "Sure, here's the answer to the user's query:\n\n", "")

        if return_answer_only:
            return output_text

        return output_text, context_items

Prompt Augmentation

  • Similarity search: given the user's query, find the top_k most similar chunks in the PDF content. RagUtils.retrieve_relevant_resources computes the similarity and is called from LlmUtils.ask to fetch the relevant context (a sketch of this method follows the list).
  • Building the augmented prompt: LlmUtils.gen_prompt_with_context merges the user's query with the top_k similar chunks found in the previous step into a single context-enriched prompt.
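
The rag_utils module itself is not listed in this post, so here is a hedged sketch of what RagUtils.retrieve_relevant_resources roughly does, inferred from how LlmUtils.ask calls it: encode the query with the same embedding model, score it against every stored chunk embedding with a dot product, and return the top_k scores and indices. Treat it as an approximation, not the repository's exact code.

import torch
from sentence_transformers import util


class RagUtils:
    @classmethod
    def retrieve_relevant_resources(cls, query: str, embeddings: torch.Tensor, model,
                                    n_resources_to_return: int = 5):
        # Embed the query with the same model used for the document chunks
        query_embedding = model.encode(query, convert_to_tensor=True)
        # Dot-product similarity between the query and every chunk embedding
        dot_scores = util.dot_score(query_embedding, embeddings)[0]
        # Keep the scores and indices of the top_k most relevant chunks
        scores, indices = torch.topk(dot_scores, k=n_resources_to_return)
        return scores, indices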

Generating the Answer

After building the augmented prompt, LlmUtils.ask calls the LLM's generate method to produce the answer.

Entry Point: main.py

The code below chains the steps above together and produces an answer grounded in the local content.

import numpy as np
import torch.cuda

import rag_utils
from embedding_loader import EmbeddingLoader
from file_processor import FileDownloader, FilePreprocessor
from embedding import Embedding
from rag_utils import RagUtils
from llm_utils import LlmUtils
import pandas as pd

query = "macronutrients functions"

if __name__ == '__main__':
    # ***********************************
    # STEP 1: File preprocess
    # ***********************************

    # download example PDF file
    file_name = 'human-nutrition-text.pdf'
    download_url = 'https://pressbooks.oer.hawaii.edu/humannutrition2/open/download?type=pdf'
    FileDownloader().download_if_not_exist(file_name=file_name, download_url=download_url)

    # by default, we use a pdf file named "human-nutrition-text.pdf"
    pages_and_texts = FilePreprocessor.open_and_read_pdf(file_name)

    # turn texts into chunks; drop chunks under 30 tokens, since those are usually meaningless
    pages_and_trunk_over_min_size = FilePreprocessor.pages_and_chunks(pages_and_texts=pages_and_texts, min_token_len=30)

    # ***********************************
    # STEP 2: Embedding text chunks
    # ***********************************

    # embedding text chunks
    device = "cuda" if torch.cuda.is_available() else "cpu"  # set device "cuda" if gpu available
    embedding_tool = Embedding(model_name="all-mpnet-base-v2", device=device)
    print(f"[INFO]Embedding sentences by model {embedding_tool.model_name}, on device({embedding_tool.device})...")
    for item in pages_and_trunk_over_min_size:
        item["embedding"] = embedding_tool.encode(item["sentence_chunk"], batch_size=32, convert_to_tensor=False)
    print(f"[INFO]Finish embedding.")

    # ***********************************
    # STEP 3: Save embeddings into csv
    # ***********************************

    # embeddings can be saved in any vector database, but here we use csv file for simplicity
    text_chunks_and_embeddings_df = pd.DataFrame(pages_and_trunk_over_min_size)
    embeddings_df_save_path = "text_chunks_and_embeddings_df.csv"
    text_chunks_and_embeddings_df.to_csv(embeddings_df_save_path, index=False)

    # ***********************************
    # STEP 4: Load embeddings from CSV
    # ***********************************

    # text_chunks_and_embedding_df = pd.read_csv(embeddings_df_save_path)
    # text_chunks_and_embedding_df["embedding"] = text_chunks_and_embedding_df["embedding"].apply(
    #                                                                 lambda x: np.fromstring(x.strip("[]"), sep=" "))
    # pages_and_chunks = text_chunks_and_embedding_df.to_dict(orient="records")
    # embeddings = (torch.tensor(np.array(text_chunks_and_embedding_df["embedding"].tolist()), dtype=torch.float32)
    #               .to(device))

    pages_and_chunks, embeddings = EmbeddingLoader.load(file_path=embeddings_df_save_path, device=device)

    # ***********************************
    # STEP 5: Load LLM
    # ***********************************

    # you can select any LLM according to your scenario, here we select google/gemma-2b-it for demo use
    # LLM can be downloaded from huggingface or kaggle.
    # here we download LLM from huggingface, you should have account of it and login to find your token.
    # access token setting page: https://huggingface.co/settings/tokens
    tokenizer, llm = LlmUtils.init_model()

    # ***********************************
    # STEP 6: generate answer
    # ***********************************

    # at last, we ask LLM with query and contexts with the highest relevance.
    # relevance between context and query is computed in RagUtils.retrieve_relevant_resources
    query = "What are the macronutrients, and what roles do they play in the human body?"
    answer = LlmUtils.ask(query=query, embeddings=embeddings, embedding_model=embedding_tool.embedding_model,
                          pages_and_chunks=pages_and_chunks, model=llm, tokenizer=tokenizer,
                          return_answer_only=True)

    print(f"Answer:\n{answer}")

References:

1. Patrick Lewis et al., "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks", arXiv:2005.11401

2. GitHub repo: simple-local-rag

3. YouTube: Local Retrieval Augmented Generation (RAG) from Scratch (step-by-step tutorial), https://www.youtube.com/watch?v=qN_2fnOPY-M
