<think>我们确实可以将上述RAG系统实现在一个.py文件中,但需要注意以下几点:
1. 代码会相对较长,但逻辑集中,便于理解和运行。
2. 我们将所有必要的类、函数和配置集中在一个文件中。
3. 为了简化,我们使用本地嵌入模型(如sentence-transformers)和FAISS向量数据库。
4. 使用OpenAI作为LLM,但也可以替换为其他模型。
以下是单文件实现的完整代码:</think>### 单文件 RAG 系统实现
是的,完全可以将整个 RAG 系统实现在一个 Python 文件中。以下是完整的单文件实现,包含所有核心功能:
```python
import os
import json
import pickle
import faiss
import numpy as np
import requests
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
from langchain.docstore.document import Document
from typing import List, Dict, Any, Tuple, Optional
class RAGSystem:
def __init__(self, config: Dict[str, Any]):
"""初始化RAG系统"""
self.config = config
self.embedding_model = self._load_embedding_model()
self.llm = self._load_llm()
self.vector_db = None
self.doc_id_map = {}
# 自动加载或创建索引
self._initialize_index()
def _load_embedding_model(self) -> HuggingFaceEmbeddings:
"""加载嵌入模型"""
model_name = self.config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
return HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={"device": self.config.get("embedding_device", "cpu")}
)
def _load_llm(self):
"""加载大语言模型"""
llm_provider = self.config.get("llm_provider", "openai")
if llm_provider == "openai":
return OpenAI(
api_key=self.config["openai_api_key"],
model_name=self.config.get("llm_model", "gpt-3.5-turbo"),
temperature=self.config.get("temperature", 0.7)
)
elif llm_provider == "huggingface":
# 使用本地Hugging Face模型
from langchain.llms import HuggingFacePipeline
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
model_name = self.config.get("hf_model", "gpt2")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_length=self.config.get("max_length", 200)
)
return HuggingFacePipeline(pipeline=pipe)
else:
raise ValueError(f"不支持的LLM提供者: {llm_provider}")
def _initialize_index(self):
"""初始化索引:加载现有或创建新索引"""
index_path = self.config["index_path"]
if os.path.exists(index_path):
print(f"加载现有索引: {index_path}")
self._load_vector_index()
else:
print(f"创建新索引: {index_path}")
self._create_new_index()
def _create_new_index(self):
"""创建新索引"""
data_dir = self.config["data_dir"]
if not os.path.exists(data_dir):
print(f"数据目录不存在: {data_dir}")
if self.config.get("auto_download", False):
self._download_sample_data()
else:
raise FileNotFoundError(f"数据目录不存在: {data_dir}")
documents = self._load_and_process_documents()
self._create_vector_index(documents)
def _download_sample_data(self):
"""下载示例数据"""
print("下载示例数据...")
data_dir = self.config["data_dir"]
os.makedirs(data_dir, exist_ok=True)
sample_urls = [
"https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs_skeleton.json"
]
for url in sample_urls:
response = requests.get(url)
filename = os.path.basename(url)
file_path = os.path.join(data_dir, filename)
with open(file_path, "wb") as f:
f.write(response.content)
print(f"下载完成: {filename}")
def _load_and_process_documents(self) -> List[Document]:
"""加载并处理文档"""
documents = []
data_dir = self.config["data_dir"]
# 支持多种文件格式
for filename in os.listdir(data_dir):
file_path = os.path.join(data_dir, filename)
if filename.endswith(".json") or filename.endswith(".jsonl"):
documents.extend(self._load_json_documents(file_path))
elif filename.endswith(".txt"):
documents.extend(self._load_text_documents(file_path))
if not documents:
raise ValueError(f"在 {data_dir} 中没有找到可处理的文档")
# 文本分块
return self._split_documents(documents)
def _load_json_documents(self, file_path: str) -> List[Document]:
"""加载JSON或JSONL文档"""
documents = []
with open(file_path, "r") as f:
if file_path.endswith(".jsonl"):
# 处理JSONL文件
for line in f:
try:
data = json.loads(line)
doc = self._create_document_from_data(data)
documents.append(doc)
except json.JSONDecodeError:
print(f"跳过无效的JSON行: {line.strip()}")
else:
# 处理JSON文件
try:
data = json.load(f)
if isinstance(data, list):
for item in data:
doc = self._create_document_from_data(item)
documents.append(doc)
elif isinstance(data, dict):
doc = self._create_document_from_data(data)
documents.append(doc)
except json.JSONDecodeError:
print(f"无效的JSON文件: {file_path}")
return documents
def _load_text_documents(self, file_path: str) -> List[Document]:
"""加载纯文本文档"""
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return [Document(
page_content=content,
metadata={
"source": file_path,
"title": os.path.basename(file_path),
"category": "text"
}
)]
def _create_document_from_data(self, data: Dict) -> Document:
"""从数据创建文档对象"""
return Document(
page_content=data.get("content", data.get("text", "")),
metadata={
"source": data.get("url", data.get("source", "")),
"title": data.get("title", ""),
"category": data.get("category", "unknown"),
"timestamp": data.get("timestamp", "")
}
)
def _split_documents(self, documents: List[Document]) -> List[Document]:
"""分割文档为块"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.config.get("chunk_size", 1000),
chunk_overlap=self.config.get("chunk_overlap", 200),
length_function=len
)
return text_splitter.split_documents(documents)
def _create_vector_index(self, documents: List[Document]):
"""创建FAISS向量索引"""
# 创建向量数据库
self.vector_db = FAISS.from_documents(
documents=documents,
embedding=self.embedding_model
)
# 保存索引
os.makedirs(os.path.dirname(self.config["index_path"]), exist_ok=True)
self.vector_db.save_local(self.config["index_path"])
# 创建文档ID映射
for idx, doc in enumerate(documents):
self.doc_id_map[idx] = {
"source": doc.metadata["source"],
"title": doc.metadata["title"],
"category": doc.metadata["category"]
}
# 保存映射表
map_path = os.path.join(os.path.dirname(self.config["index_path"]), "doc_id_map.pkl")
with open(map_path, "wb") as f:
pickle.dump(self.doc_id_map, f)
print(f"✅ 向量索引已创建并保存至 {self.config['index_path']}")
def _load_vector_index(self):
"""加载现有的FAISS向量索引"""
index_path = self.config["index_path"]
# 加载向量数据库
self.vector_db = FAISS.load_local(
folder_path=index_path,
embeddings=self.embedding_model
)
# 加载文档映射表
map_path = os.path.join(os.path.dirname(index_path), "doc_id_map.pkl")
if os.path.exists(map_path):
with open(map_path, "rb") as f:
self.doc_id_map = pickle.load(f)
print(f"✅ 文档映射表已加载")
else:
print("⚠️ 文档映射表未找到,将使用空映射")
def ask_question(self, question: str, history: Optional[List] = None) -> Dict:
"""提问并获取答案"""
if self.vector_db is None:
raise ValueError("向量数据库未初始化")
# 创建检索器
retriever = self.vector_db.as_retriever(
search_kwargs={
"k": self.config.get("retrieval_top_k", 5),
"score_threshold": self.config.get("score_threshold", 0.4)
}
)
# 创建问答链
qa_chain = self._create_qa_chain(retriever)
# 执行问答
result = qa_chain({"query": question})
# 提取源文档
source_docs = result["source_documents"]
doc_ids = [doc.metadata.get("doc_id", idx) for idx, doc in enumerate(source_docs)]
# 获取完整上下文
full_contexts = [self.doc_id_map.get(did, {"title": "未知", "source": ""}) for did in doc_ids]
return {
"question": question,
"answer": result["result"],
"source_documents": source_docs,
"full_contexts": full_contexts
}
def _create_qa_chain(self, retriever) -> Any:
"""创建问答链"""
# 自定义提示模板
prompt_template = """
基于以下上下文信息,请以专业、准确的方式回答用户的问题。如果上下文信息不足以回答问题,请如实告知用户。
上下文信息:
{context}
问题: {question}
请提供详细的回答:
"""
QA_PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question"]
)
# 创建问答链
return RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=retriever,
return_source_documents=True,
chain_type_kwargs={"prompt": QA_PROMPT}
)
def add_document(self, content: str, metadata: Dict):
"""添加单个文档到索引"""
if self.vector_db is None:
raise ValueError("向量数据库未初始化")
# 创建文档对象
doc = Document(page_content=content, metadata=metadata)
# 分割文档
split_docs = self._split_documents([doc])
# 添加到索引
self.vector_db.add_documents(split_docs)
# 更新文档映射
start_idx = max(self.doc_id_map.keys()) + 1 if self.doc_id_map else 0
for idx, doc in enumerate(split_docs):
self.doc_id_map[start_idx + idx] = {
"source": doc.metadata["source"],
"title": doc.metadata["title"],
"category": doc.metadata["category"]
}
print(f"✅ 添加了 {len(split_docs)} 个文档块")
def save_index(self):
"""保存索引到磁盘"""
if self.vector_db is None:
raise ValueError("向量数据库未初始化")
# 保存索引
self.vector_db.save_local(self.config["index_path"])
# 保存映射表
map_path = os.path.join(os.path.dirname(self.config["index_path"]), "doc_id_map.pkl")
with open(map_path, "wb") as f:
pickle.dump(self.doc_id_map, f)
print(f"✅ 索引已保存至 {self.config['index_path']}")
def create_default_config() -> Dict:
"""创建默认配置"""
return {
"data_dir": "data",
"index_path": "index/faiss_index",
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
"embedding_device": "cpu",
"llm_provider": "openai",
"openai_api_key": "your_api_key_here",
"llm_model": "gpt-3.5-turbo",
"temperature": 0.7,
"chunk_size": 1000,
"chunk_overlap": 200,
"retrieval_top_k": 5,
"score_threshold": 0.4,
"auto_download": True # 自动下载示例数据
}
def interactive_cli():
"""交互式命令行界面"""
config = create_default_config()
# 设置OpenAI API密钥
api_key = input("请输入OpenAI API密钥(或直接回车使用默认值): ")
if api_key.strip():
config["openai_api_key"] = api_key
# 初始化RAG系统
rag = RAGSystem(config)
# 交互问答
print("\nRAG系统已就绪,输入问题开始查询(输入'q'退出)")
history = []
while True:
question = input("\n>>> 问题: ")
if question.lower() in ['q', 'quit', 'exit']:
break
try:
result = rag.ask_question(question, history)
# 显示结果
print(f"\n💡 答案: {result['answer']}")
if result["full_contexts"]:
print("\n📚 信息来源:")
for i, ctx in enumerate(result["full_contexts"]):
print(f" {i+1}. {ctx['title']} ({ctx['category']})")
print(f" 来源: {ctx['source']}")
# 添加到历史
history.append({"question": question, "answer": result["answer"]})
except Exception as e:
print(f"❌ 错误: {str(e)}")
# 保存索引
rag.save_index()
print("\n索引已保存,再见!")
if __name__ == "__main__":
interactive_cli()
```
### 单文件实现说明
这个单文件实现包含了 RAG 系统的所有核心功能:
#### 1. 完整功能集成
- **文档处理**:支持 JSON、JSONL 和 TXT 格式
- **索引管理**:自动创建或加载 FAISS 索引
- **问答系统**:完整的检索增强生成流程
- **交互界面**:内置命令行交互界面
#### 2. 自包含特性
- **自动数据下载**:如果本地没有数据,会自动下载示例数据
```python
def _download_sample_data(self):
"""下载示例数据"""
print("下载示例数据...")
# ...下载逻辑...
```
- **配置管理**:内置默认配置
```python
def create_default_config() -> Dict:
"""创建默认配置"""
return {
"data_dir": "data",
"index_path": "index/faiss_index",
# ...其他配置...
}
```
- **错误处理**:健壮的错误处理和用户提示
```python
try:
result = rag.ask_question(question, history)
# ...显示结果...
except Exception as e:
print(f"❌ 错误: {str(e)}")
```
#### 3. 简化使用
- **零配置运行**:只需提供 OpenAI API 密钥即可运行
- **交互式 CLI**:内置命令行界面
```python
def interactive_cli():
"""交互式命令行界面"""
# ...初始化...
while True:
question = input("\n>>> 问题: ")
# ...处理问题...
```
#### 4. 文件结构
当运行此脚本时,会自动创建以下目录结构:
```
项目目录/
├── data/ # 存储原始数据
├── index/ # 存储FAISS索引
│ ├── faiss_index # FAISS索引文件
│ └── doc_id_map.pkl # 文档ID映射
└── rag_system.py # 主脚本文件
```
### 使用示例
#### 1. 首次运行
```bash
python rag_system.py
```
系统将提示输入 OpenAI API 密钥,然后自动下载示例数据并创建索引。
#### 2. 添加自定义文档
```python
# 在脚本中添加以下代码
rag.add_document(
content="东北石油大学创建于1960年,是一所以石油石化为特色的多科性大学。",
metadata={
"source": "https://www.nepu.edu.cn/about",
"title": "学校简介",
"category": "学校概况"
}
)
rag.save_index()
```
#### 3. 提问示例
```
>>> 问题: 东北石油大学是什么时候创建的?
💡 答案: 东北石油大学创建于1960年。
📚 信息来源:
1. 学校简介 (学校概况)
来源: https://www.nepu.edu.cn/about
```
### 单文件优势
1. **简化部署**:只需一个 Python 文件即可运行整个系统
2. **易于分享**:方便与他人共享和协作
3. **快速原型**:适合快速验证和演示
4. **学习友好**:所有代码集中在一个文件中,便于理解系统架构
5. **减少依赖**:只需安装必要的 Python 包即可运行
### 扩展建议
如果需要扩展功能,可以在现有基础上添加:
1. **支持更多文件格式**:
```python
def _load_documents(self, file_path: str) -> List[Document]:
if file_path.endswith(".pdf"):
return self._load_pdf_document(file_path)
elif file_path.endswith(".docx"):
return self._load_docx_document(file_path)
# ...其他格式...
```
2. **添加缓存机制**:
```python
from functools import lru_cache
@lru_cache(maxsize=100)
def ask_question(self, question: str) -> Dict:
# ...实现...
```
3. **集成简单 Web 界面**:
```python
from flask import Flask, request, jsonify
app = Flask(__name__)
rag = RAGSystem(create_default_config())
@app.route('/ask', methods=['POST'])
def ask_endpoint():
data = request.json
question = data.get('question')
result = rag.ask_question(question)
return jsonify(result)
if __name__ == '__main__':
app.run(port=5000)
```
这个单文件实现包含了 RAG 系统的所有核心功能,同时保持了代码的简洁性和易用性,非常适合快速部署和演示。