这段代码实现了一个基于向量数据库(ChromaDB)和文本嵌入模型(SentenceTransformer)的问答系统,并通过 HTTP 服务器提供服务,用户可以向服务器发送问题,服务器会尝试在预加载的标准问题向量数据库中查找最相似的问题,并返回对应的答案。主要功能包括文本清洗、关键词提取、问题向量表示生成、在向量数据库中查询匹配问题以及选择最佳映射问题等,同时还具备从本地 JSON 文件加载数据到向量数据库的功能。
前端代码参照
代码结构与各部分功能详细分析
- 导入模块部分:
import http.server
import socketserver
import cgi
import json
import os
from chromadb.config import Settings
from chromadb import Client
import chromadb
import re
import string
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
导入了多个必要的模块,用于处理 HTTP 请求、操作向量数据库、文本处理、特征提取以及使用预训练的文本嵌入模型等操作。
2. 初始化工作目录及相关组件部分:
work_dir = "E:/hw2024/rag"
model_path = os.path.join(work_dir, "text2vec-base-chinese-sentence")
# 初始化文本嵌入模型
try:
model = SentenceTransformer("text2vec-base-chinese-sentence")
print("成功加载 SentenceTransformer 模型")
except Exception as e:
print(f"加载 SentenceTransformer 模型失败,原因: {e}")
raise
chroma_db_dir = os.path.join(work_dir, ".chromadb")
if not os.path.exists(chroma_db_dir):
os.makedirs(chroma_db_dir)
print(f"创建 ChromaDB 数据目录: {chroma_db_dir}")
# 初始化ChromaDB客户端
try:
client = chromadb.Client()
print("成功初始化 ChromaDB 客户端")
except Exception as e:
print(f"初始化 ChromaDB 客户端失败,原因: {e}")
raise
# 获取或创建一个集合
try:
collection = client.get_or_create_collection(name="standard_questions")
print("成功获取或创建 ChromaDB 集合")
except Exception as e:
print(f"获取或创建 ChromaDB 集合失败,原因: {e}")
raise
- 首先定义了工作目录
work_dir
,并基于此构建了模型文件所在目录model_path
(虽然模型通常会从默认仓库下载,但保留此结构方便理解),然后尝试加载SentenceTransformer
模型,如果失败则抛出异常。 - 创建了 ChromaDB 的数据目录
chroma_db_dir
,接着初始化了 ChromaDB 客户端以及获取或创建了名为standard_questions
的集合,用于存储标准问题向量等相关信息,每个步骤若失败都会打印相应错误原因并抛出异常。
- 文本处理相关函数部分:
clean_text
函数:
def clean_text(text):
"""
文本清洗与规范化函数
参数:
text (str): 输入文本
返回:
str: 清洗后的文本
"""
print(f"开始清洗文本: {text}")
# 去除HTML标签
text = re.sub(r'<.*?>', '', text)
# 去除特殊字符和标点符号
text = text.translate(str.maketrans('', '', string.punctuation))
# 转换为小写
text = text.lower()
# 去除多余空格
text =''.join(text.split())
print(f"清洗后的文本: {text}")
return text
用于对输入文本进行清洗和规范化操作,依次去除 HTML 标签、特殊字符和标点符号、将文本转换为小写以及去除多余空格,最后返回清洗后的文本,并在过程中打印相关信息。
extract_keywords
函数:
def extract_keywords(text):
"""
关键词提取函数(使用TF-IDF算法)
参数:
text (str): 输入文本
返回:
list: 关键词列表
"""
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform([text])
feature_names = vectorizer.get_feature_names_out()
keywords = []
for col in tfidf_matrix.nonzero()[1]:
keywords.append(feature_names[col])
print(f"提取的关键词: {keywords}")
return keywords
使用TF-IDF
算法来提取输入文本中的关键词,先创建TF-IDF
向量器,对输入文本进行拟合和转换得到矩阵,然后从矩阵中获取非零元素对应的特征名称作为关键词,最后返回关键词列表并打印相关信息。
vectorize_question
函数:
def vectorize_question(question):
"""
将问题转换为向量表示
参数:
question (str): 输入问题
返回:
list: 问题向量
"""
print(f"开始将问题 {question} 转换为向量表示")
try:
vector = model.encode([question])[0]
print(f"成功将问题转换为向量,向量维度: {len(vector)}")
return vector
except Exception as e:
print(f"将问题转换为向量时失败,原因: {e}")
raise
利用之前加载的SentenceTransformer
模型将输入的问题转换为向量表示,如果转换过程中出现异常则打印错误原因并抛出异常,转换成功后返回向量并打印向量维度信息。
4. 查询与选择相关函数部分:
some_other_condition
函数:
def some_other_condition(candidate_question, user_question):
candidate_keywords = set(extract_keywords(candidate_question))
user_keywords = set(extract_keywords(user_question))
keyword_overlap = len(user_keywords.intersection(candidate_keywords))
return keyword_overlap >= 2 # 这里假设关键词重叠度大于等于2就满足条件,可以根据实际情况修改这个数值
通过计算用户问题和候选问题的关键词重叠度(以集合交集的元素个数来衡量),判断是否满足一定条件(当前设定为关键词重叠度大于等于 2),可以根据实际情况调整这个判断阈值。
query_vector_database
函数:
def query_vector_database(user_question_vector, user_question, top_k=5, similarity_threshold=0.7):
print(f"开始在向量数据库中查询,用户问题向量: {user_question_vector}")
print(f"当前 collection 对象信息: {collection}")
try:
results = collection.query(
query_embeddings=[user_question_vector],
n_results=top_k
)
print(f"查询得到的原始结果: