1. 什么是RAG技术
RAG is short for Retrieval Augmented Generation。结合了检索模型和生成模型的能力,以提高文本生成任务的性能。具体来说,RAG技术允许大型语言模型(Large Language Model, LLM)在生成回答时,不仅依赖于其内部知识,还能检索并利用外部数据源中的信息。
对于这个概念,我自己的理解是,大模型相当于是一个人,而RAG技术检索并利用的外部数据源就是书本、或者电子/数据资料。而RAG就是人检索并根据书本或者电子资料生成任务的能力。
比如一个人一目十行,理解能力强,可以快速地汲取知识并加以理解从而输出,就代表这个人的学习能力强,就相当于RAG技术性能优越。而另一个人阅读能力差,不容易理解新知识,就相当于RAG技术没做好,性能不行。
在这张图中,我把人类智能比作RAG技术,人类比作AI,外部知识来源比作向量数据库(一般与RAG一起使用)。RAG的实现越好,那么相当于越智能,则AI的能力越强。
2. RAG技术的Working Pipeline
首先我们要搜集插入到向量数据库 中,也即实体的文档、结构化知识、手册,读取文本内容,进行文本分割,进行向量嵌入后插入向量数据库中。
当用户请求大模型时,首先将查询向量化,随后检索向量库得到相似度高的知识,作为背景注入到prompt,随后大模型再生成回答。
3. RAG的实现
在github上,有一个RAG实现的Web应用的Demo。Langchain-Chatchat
我们同样打算以Web应用的模式构建一个能够被请求用来检索知识的向量数据库。因此先学习阅读一下这个项目的代码。
3.1. Web应用的入口:挂载Web应用路径
这一部分其实和RAG本身关系不大了,属于是网络通信方面的部分。但因为它是整个应用的入口,所以有必要探索一下。
首先在这个项目的README文件中,我们发现了这个Web应用还有个在线的接口文档。
从这个接口文档中,可以看到对于知识库(Knowledge Base) 的接口,这一部分就涉及了向量数据库。
我们可以通过在IDE中全局搜索这些接口,来找到暴露这些应用路径的地方。
可以看到,server/api.py下挂载了这些接口,我们来到这个文件一探究竟。其中不乏这样的函数:
app.post("/knowledge_base/create_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="创建知识库"
)(create_kb)
app.post("/knowledge_base/delete_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库"
)(delete_kb)
app.get("/knowledge_base/list_files",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库内的文件列表"
)(list_files)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithVSId],
summary="搜索知识库"
)(search_docs)
我们点到每个函数中的参数,即create_kb这样的参数,来到了一个名叫kb_api.py的文件,其中暴露了这个函数(create_kb)。
此时我们就通过挂载Web应用路径的入口,找到了与向量数据库交互的模块。
3.2. 与向量数据库交互
现在来看看这些与向量数据库交互的函数。
通过交互函数看知识库工程架构
首先我们关注到create_kb中的这样一部分代码:
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
try:
kb.create_kb()
光看这个名字,我们就能知道,这是一个工厂方法的设计模式。获取知识库的方式并不是直接拿到知识库的操作柄,而是先通过提供知识库服务的工厂拿到一项知识库的服务。
对于get_service函数,如下:
@staticmethod
def get_service(kb_name: str,
vector_store_type: Union[str, SupportedVSType],
embed_model: str = EMBEDDING_MODEL,
) -> KBService:
if isinstance(vector_store_type, str):
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
if SupportedVSType.FAISS == vector_store_type:
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
return FaissKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.PG == vector_store_type:
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
return PGKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name,embed_model=embed_model)
elif SupportedVSType.ZILLIZ == vector_store_type:
from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
return ZillizKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.DEFAULT == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name,
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.ES == vector_store_type:
from server.knowledge_base.kb_service.es_kb_service import ESKBService
return ESKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.CHROMADB == vector_store_type:
from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
return ChromaKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)
那么这个是在干什么?显然,他根据向量嵌入的方式,确定要创建的数据库服务是基于哪个向量数据库的,可能是chroma,也可能是Faiss,等等。
总之,它返回了一个KBService子类的实例。而这里KBService并非是一个可实例化的类,因为它是抽象类。
在server/knowledge_base/kb_service中,我们可以看到Class Definition。
@abstractmethod
def do_create_kb(self):
"""
创建知识库子类实自己逻辑
"""
pass
在类定义中,出现了@abstractmethod注解,说明这是个抽象类。
那么其实现都在哪里呢?经过一番翻阅,在server/knowledge_base/kb_service下,包括了大量的基于不同数据库的实现类。
在翻阅代码时,我关注到了项目默认的向量数据库是faiss,因此我们可以来到faiss_kb_service中查看。
class FaissKBService(KBService):
vs_path: str
kb_path: str
vector_name: str = None
类定义中,对于KBService的继承赫然在目。
再回到通过KBServiceFactory创建KBService处:
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
try:
kb.create_kb()
我们溯源create_kb,可以发现:
def create_kb(self):
"""
创建知识库
"""
if not os.path.exists(self.doc_path):
os.makedirs(self.doc_path)
self.do_create_kb()
status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
return status
可以看到,create_kb调用了self(实例自身)的do_create_kb()。而这就是刚才提到的抽象方法,也就是它会根据不同类对其的覆写,执行不同的逻辑。
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
self.load_vector_store()
def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
vector_name=self.vector_name,
embed_model=self.embed_model)
例如faiss就有自己独特的创建数据库的方式。
因此这个设计架构就明确了,是一个四层的Web-静态工厂-抽象类-实体类的架构。如下图所示:
Mapping from Abstract Working Pipeline to Code
现在我们知道了如何获取一个向量数据库的服务。但在哪里使用它,如何使用它呢?正如先前RAG的Working Pipeline中所说,用户在请求大模型进行任务时,先通过检索向量数据库获取相似知识优化Prompt,再进行提问。那么这样一套流程,是如何映射到代码中的,我们是如何使用向量数据库提供的检索功能的?
找到RAG流程的入口
为了找到这个接口的入口,我还是先翻看了server/api.py文件,其中包括了:
app.post("/chat/chat",
tags=["Chat"],
summary="与llm模型对话(通过LLMChain)",
)(chat)
app.post("/chat/search_engine_chat",
tags=["Chat"],
summary="与搜索引擎对话",
)(search_engine_chat)
app.post("/chat/feedback",
tags=["Chat"],
summary="返回llm模型对话评分",
)(chat_feedback)
app.post("/chat/knowledge_base_chat",
tags=["Chat"],
summary="与知识库对话")(knowledge_base_chat)
app.post("/chat/file_chat",
tags=["Knowledge Base Management"],
summary="文件对话"
)(file_chat)
app.post("/chat/agent_chat",
tags=["Chat"],
summary="与agent对话")(agent_chat)
一开始我以为/chat/chat这个接口是包括了RAG流程的接口,但后来我翻了翻代码,发觉并没有检索向量数据库。
随后经过一些翻阅,我找到了/chat/knowledge_base_chat这个一接口:
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(
SCORE_THRESHOLD,
description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
ge=0,
le=2
),
history: List[History] = Body(
[],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
None,
description="限制LLM生成Token数量,默认None代表模型最大值"
),
prompt_name: str = Body(
"default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(
query: str,
top_k: int,
history: Optional[List[History]],
model_name: str = model_name,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold)
# 加入reranker
if USE_RERANKER:
reranker_model_path = get_model_path(RERANKER_MODEL)
reranker_model = LangchainReranker(top_n=top_k,
device=embedding_device(),
max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path
)
print("-------------before rerank-----------------")
print(docs)
docs = reranker_model.compress_documents(documents=docs,
query=query)
print("------------after rerank------------------")
print(docs)
context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: # 如果没有找到相关文档,使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0: # 没有找到相关文档
source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token}, ensure_ascii=False)
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
他这个函数签名非常长,一堆参数,但实际有用的其实主要还是集中在query,也即用户查询上,其他的都是要调用langchain的库或者与向量数据库交互的必要参数。top k个相关向量是RAG技术的一部分,也是必要的参数。
源码解读
首先,先获取了数据库服务。(当然也可能数据库不存在)
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
随后选择LLM模型实例:
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
再在对应的向量数据库中检索相关文档(top k个)
docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold)
这个异步调用中的search_docs暴露自server/knowledge_basekb_doc_api.py,如下:
def search_docs(
query: str = Body("", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD,
description="知识库匹配相关度阈值,取值范围在0-1之间,"
"SCORE越小,相关度越高,"
"取到1相当于不筛选,建议设置在0.5左右",
ge=0, le=1),
file_name: str = Body("", description="文件名称,支持 sql 通配符"),
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[DocumentWithVSId]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
data = []
if kb is not None:
if query:
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata)
for d in data:
if "vector" in d.metadata:
del d.metadata["vector"]
return data
首先还是获取数据库服务,随后调用服务类暴露的search_docs函数(这个很显然,对于不同向量数据库来说,肯定是具体实现不一样), 随后返回相似度在阈值内的top_k个结果。
if len(docs) == 0: # 如果没有找到相关文档,使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
随后,建立prompt模板。然后根据历史会话信息建立当前对话的prompt。
之后通过LangChain提供的LLMChain,获取能够进行用户任务的中间件。
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
随后启动一个后台的异步任务,将向量数据库中检索到的文档作为知识背景,用户的输入作为问题。
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0: # 没有找到相关文档
source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
一般LLM回答问题,会把自己参考的文献放出来(比如说Kimi),这一部分做的就是拼接参考文献字符串。
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
最后返回大模型的回答。
这个过程就是RAG的Working Pipeline在代码部分中的映射。
将知识嵌入到知识库
这一部分相对而言比较直接。在server/api.py中,有这么一段:
app.post("/knowledge_base/upload_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="上传文件到知识库,并/或进行向量化"
)(upload_docs)
找到对应的upload_docs,在server/knowledge_basekb_doc_api.py中。
def upload_docs(
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
"""
API接口:上传文件,并/或向量化
"""
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
failed_files = {}
file_names = list(docs.keys())
# 先将上传的文件保存到磁盘
for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
filename = result["data"]["file_name"]
if result["code"] != 200:
failed_files[filename] = result["msg"]
if filename not in file_names:
file_names.append(filename)
# 对保存的文件进行向量化
if to_vector_store:
result = update_docs(
knowledge_base_name=knowledge_base_name,
file_names=file_names,
override_custom_docs=True,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance,
docs=docs,
not_refresh_vs_cache=True,
)
failed_files.update(result.data["failed_files"])
if not not_refresh_vs_cache:
kb.save_vector_store()
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
这一部分最重要的还是save_vector_store函数,不过这一部分属于每种数据库自己的实现了。
我们可以看一个faiss的
def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
vector_name=self.vector_name,
embed_model=self.embed_model)
def load_vector_store(
self,
kb_name: str,
vector_name: str = None,
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
vector_name = vector_name or embed_model
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
if cache is None:
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
vector_store.save_local(vs_path)
else:
raise RuntimeError(f"knowledge base {kb_name} not exist.")
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get((kb_name, vector_name))
其实这个模块是个缓存机制,也就是说每次检索都会查看是否已经有这个向量数据库的操作柄了。如果有直接返回,如果没有则加载一遍,这个加载的过程集中在:
def get(self, key: str) -> ThreadSafeObject:
if cache := self._cache.get(key):
cache.wait_for_loading()
return cache
那么他返回的是什么呢?是一个对应数据库的操作柄,定义如下:
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
def docs_count(self) -> int:
return len(self._obj.docstore._dict)
def save(self, path: str, create_path: bool = True):
with self.acquire():
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self.key} 保存到磁盘")
return ret
def clear(self):
ret = []
with self.acquire():
ids = list(self._obj.docstore._dict.keys())
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self.key} 清空")
return ret
本质上是存储向量化文档的一个对象。
4. 体验这个应用
虽然README中说了怎么用,但这里想补充下。
首先大模型你可以不下载(如果不用这个服务),但向量嵌入模型必须下载。如果你hugging-face用git clone拉不下来,上去手动下也行。
其次如果你的电脑配不了cuda环境,那么你就没办法加载运行大模型。不过你可以选择放弃大模型服务,因为还有向量知识库的服务可以用。
只需要在启动脚本里把加载运行大模型部分的代码注释掉就行(以下是完整的启动脚本):
import asyncio
import multiprocessing as mp
import os
import subprocess
import sys
from multiprocessing import Process
from datetime import datetime
from pprint import pprint
from langchain_core._api import deprecated
try:
import numexpr
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs import (
LOG_PATH,
log_verbose,
logger,
LLM_MODELS,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API,
FSCHAT_MODEL_WORKERS,
API_SERVER,
WEBUI_SERVER,
HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, get_httpx_client, get_model_worker_config,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import List, Dict
from configs import VERSION
@deprecated(
since="0.3.0",
message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
removal="0.3.0")
def create_controller_app(
dispatch_method: str,
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.controller import app, Controller, logger
logger.setLevel(log_level)
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
app._controller = controller
return app
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
"""
kwargs包含的字段如下:
host:
port:
model_names:[`model_name`]
controller_address:
worker_address:
对于Langchain支持的模型:
langchain_model:True
不会使用fschat
对于online_api:
online_api:True
worker_class: `provider`
对于离线模型:
model_path: `model_name_or_path`,huggingface的repo-id或本地路径
device:`LLM_DEVICE`
"""
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
import argparse
parser = argparse.ArgumentParser()
args = parser.parse_args([])
for k, v in kwargs.items():
setattr(args, k, v)
if worker_class := kwargs.get("langchain_model"): # Langchian支持的模型不用做操作
from fastchat.serve.base_model_worker import app
worker = ""
# 在线模型API
elif worker_class := kwargs.get("worker_class"):
from fastchat.serve.base_model_worker import app
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
# sys.modules["fastchat.serve.base_model_worker"].worker = worker
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
# 本地模型
else:
from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
args.tokenizer = args.model_path
args.tokenizer_mode = 'auto'
args.trust_remote_code = True
args.download_dir = None
args.load_format = 'auto'
args.dtype = 'auto'
args.seed = 0
args.worker_use_ray = False
args.pipeline_parallel_size = 1
args.tensor_parallel_size = 1
args.block_size = 16
args.swap_space = 4 # GiB
args.gpu_memory_utilization = 0.90
args.max_num_batched_tokens = None # 一个批次中的最大令牌(tokens)数量,这个取决于你的显卡和大模型设置,设置太大显存会不够
args.max_num_seqs = 256
args.disable_log_stats = False
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
args.engine_use_ray = False
args.disable_log_requests = False
# 0.2.1 vllm后要加的参数, 但是这里不需要
args.max_model_len = None
args.revision = None
args.quantization = None
args.max_log_len = None
args.tokenizer_revision = None
# 0.2.2 vllm需要新加的参数
args.max_paddings = 256
if args.model_path:
args.model = args.model_path
if args.num_gpus > 1:
args.tensor_parallel_size = args.num_gpus
for k, v in kwargs.items():
setattr(args, k, v)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
worker = VLLMWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
llm_engine=engine,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
args.max_gpu_memory = "22GiB"
args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量
args.load_8bit = False
args.cpu_offloading = None
args.gptq_ckpt = None
args.gptq_wbits = 16
args.gptq_groupsize = -1
args.gptq_act_order = False
args.awq_ckpt = None
args.awq_wbits = 16
args.awq_groupsize = -1
args.model_names = [""]
args.conv_template = None
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
args.embed_in_truncate = False
for k, v in kwargs.items():
setattr(args, k, v)
if args.gpus:
if args.num_gpus is None:
args.num_gpus = len(args.gpus.split(','))
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
# sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"
app._worker = worker
return app
def create_openai_api_app(
controller_address: str,
api_keys: List = [],
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
from fastchat.utils import build_logger
logger = build_logger("openai_api", "openai_api.log")
logger.setLevel(log_level)
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
sys.modules["fastchat.serve.openai_api_server"].logger = logger
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@app.on_event("startup")
async def on_startup():
if started_event is not None:
started_event.set()
def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import httpx
from fastapi import Body
import time
import sys
from server.utils import set_httpx_config
set_httpx_config()
app = create_controller_app(
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
log_level=log_level,
)
_set_app_event(app, started_event)
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
available_models = app._controller.list_models()
if new_model_name in available_models:
msg = f"要切换的LLM模型 {new_model_name} 已经存在"
logger.info(msg)
return {"code": 500, "msg": msg}
if new_model_name:
logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
else:
logger.info(f"即将停止LLM模型: {model_name}")
if model_name not in available_models:
msg = f"the model {model_name} is not available"
logger.error(msg)
return {"code": 500, "msg": msg}
worker_address = app._controller.get_worker_address(model_name)
if not worker_address:
msg = f"can not find model_worker address for {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
with get_httpx_client() as client:
r = client.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
if new_model_name:
timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
break
time.sleep(1)
timer -= 1
if timer > 0:
msg = f"sucess change model from {model_name} to {new_model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
else:
msg = f"failed change model from {model_name} to {new_model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
else:
msg = f"sucess to release model: {model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_model_worker(
model_name: str = LLM_MODELS[0],
controller_address: str = "",
log_level: str = "INFO",
q: mp.Queue = None,
started_event: mp.Event = None,
):
import uvicorn
from fastapi import Body
import sys
from server.utils import set_httpx_config
set_httpx_config()
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
port = kwargs.pop("port")
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_event(app, started_event)
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
q.put([model_name, "start", new_model_name])
else:
if new_model_name:
q.put([model_name, "replace", new_model_name])
else:
q.put([model_name, "stop", None])
return {"code": 200, "msg": "done"}
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import sys
from server.utils import set_httpx_config
set_httpx_config()
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr, log_level=log_level)
_set_app_event(app, started_event)
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port)
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
from server.api import create_app
import uvicorn
from server.utils import set_httpx_config
set_httpx_config()
app = create_app(run_mode=run_mode)
_set_app_event(app, started_event)
host = API_SERVER["host"]
port = API_SERVER["port"]
uvicorn.run(app, host=host, port=port)
def run_webui(started_event: mp.Event = None, run_mode: str = None):
from server.utils import set_httpx_config
set_httpx_config()
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
cmd = ["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
]
if run_mode == "lite":
cmd += [
"--",
"lite",
]
p = subprocess.Popen(cmd)
started_event.set()
p.wait()
def parse_args() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"-a",
"--all-webui",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
dest="all_webui",
)
parser.add_argument(
"--all-api",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
dest="all_api",
)
parser.add_argument(
"--llm-api",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers",
dest="llm_api",
)
parser.add_argument(
"-o",
"--openai-api",
action="store_true",
help="run fastchat's controller/openai_api servers",
dest="openai_api",
)
parser.add_argument(
"-m",
"--model-worker",
action="store_true",
help="run fastchat's model_worker server with specified model name. "
"specify --model-name if not using default LLM_MODELS",
dest="model_worker",
)
parser.add_argument(
"-n",
"--model-name",
type=str,
nargs="+",
default=LLM_MODELS,
help="specify model name for model worker. "
"add addition names with space seperated to start multiple model workers.",
dest="model_name",
)
parser.add_argument(
"-c",
"--controller",
type=str,
help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
dest="controller_address",
)
parser.add_argument(
"--api",
action="store_true",
help="run api.py server",
dest="api",
)
parser.add_argument(
"-p",
"--api-worker",
action="store_true",
help="run online model api such as zhipuai",
dest="api_worker",
)
parser.add_argument(
"-w",
"--webui",
action="store_true",
help="run webui.py server",
dest="webui",
)
parser.add_argument(
"-q",
"--quiet",
action="store_true",
help="减少fastchat服务log信息",
dest="quiet",
)
parser.add_argument(
"-i",
"--lite",
action="store_true",
help="以Lite模式运行:仅支持在线API的LLM对话、搜索引擎对话",
dest="lite",
)
args = parser.parse_args()
return args, parser
def dump_server_info(after_start=False, args=None):
import platform
import langchain
import fastchat
from server.utils import api_address, webui_address
print("\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本:{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
print("\n")
models = LLM_MODELS
if args and args.model_name:
models = args.model_name
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
print(f"当前启动的LLM模型:{models} @ {llm_device()}")
for model in models:
pprint(get_model_worker_config(model))
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n")
async def start_main_server():
import time
import signal
def handler(signalname):
"""
Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
"""
def f(signal_received, frame):
raise KeyboardInterrupt(f"{signalname} received")
return f
# This will be inherited by the child process if it is forked (not spawned)
signal.signal(signal.SIGINT, handler("SIGINT"))
signal.signal(signal.SIGTERM, handler("SIGTERM"))
mp.set_start_method("spawn")
manager = mp.Manager()
run_mode = None
queue = manager.Queue()
args, parser = parse_args()
if args.all_webui:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
args.api_worker = True
args.api = False
args.webui = False
if args.lite:
args.model_worker = False
run_mode = "lite"
dump_server_info(args=args)
if len(sys.argv) > 1:
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {"online_api": {}, "model_worker": {}}
def process_count():
return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2
if args.quiet or not log_verbose:
log_level = "ERROR"
else:
log_level = "INFO"
controller_started = manager.Event()
if args.openai_api:
process = Process(
target=run_controller,
name=f"controller",
kwargs=dict(log_level=log_level, started_event=controller_started),
daemon=True,
)
processes["controller"] = process
process = Process(
target=run_openai_api,
name=f"openai_api",
daemon=True,
)
processes["openai_api"] = process
# model_worker_started = []
# if args.model_worker:
# for model_name in args.model_name:
# config = get_model_worker_config(model_name)
# if not config.get("online_api"):
# e = manager.Event()
# model_worker_started.append(e)
# process = Process(
# target=run_model_worker,
# name=f"model_worker - {model_name}",
# kwargs=dict(model_name=model_name,
# controller_address=args.controller_address,
# log_level=log_level,
# q=queue,
# started_event=e),
# daemon=True,
# )
# processes["model_worker"][model_name] = process
#
# if args.api_worker:
# for model_name in args.model_name:
# config = get_model_worker_config(model_name)
# if (config.get("online_api")
# and config.get("worker_class")
# and model_name in FSCHAT_MODEL_WORKERS):
# e = manager.Event()
# model_worker_started.append(e)
# process = Process(
# target=run_model_worker,
# name=f"api_worker - {model_name}",
# kwargs=dict(model_name=model_name,
# controller_address=args.controller_address,
# log_level=log_level,
# q=queue,
# started_event=e),
# daemon=True,
# )
# processes["online_api"][model_name] = process
api_started = manager.Event()
if args.api:
process = Process(
target=run_api_server,
name=f"API Server",
kwargs=dict(started_event=api_started, run_mode=run_mode),
daemon=True,
)
processes["api"] = process
webui_started = manager.Event()
if args.webui:
process = Process(
target=run_webui,
name=f"WEBUI Server",
kwargs=dict(started_event=webui_started, run_mode=run_mode),
daemon=True,
)
processes["webui"] = process
if process_count() == 0:
parser.print_help()
else:
try:
# 保证任务收到SIGINT后,能够正常退出
if p := processes.get("controller"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait() # 等待controller启动完成
if p := processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"
for n, p in processes.get("model_worker", {}).items():
p.start()
p.name = f"{p.name} ({p.pid})"
for n, p in processes.get("online_api", []).items():
p.start()
p.name = f"{p.name} ({p.pid})"
# for e in model_worker_started:
# e.wait()
if p := processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
api_started.wait()
if p := processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
webui_started.wait()
dump_server_info(after_start=True, args=args)
while True:
cmd = queue.get()
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd
if cmd == "start": # 运行新模型
logger.info(f"准备启动新模型进程:{new_model_name}")
process = Process(
target=run_model_worker,
name=f"model_worker - {new_model_name}",
kwargs=dict(model_name=new_model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
process.start()
process.name = f"{process.name} ({process.pid})"
processes["model_worker"][new_model_name] = process
e.wait()
logger.info(f"成功启动新模型进程:{new_model_name}")
elif cmd == "stop":
if process := processes["model_worker"].get(model_name):
time.sleep(1)
process.terminate()
process.join()
logger.info(f"停止模型进程:{model_name}")
else:
logger.error(f"未找到模型进程:{model_name}")
elif cmd == "replace":
if process := processes["model_worker"].pop(model_name, None):
logger.info(f"停止模型进程:{model_name}")
start_time = datetime.now()
time.sleep(1)
process.terminate()
process.join()
process = Process(
target=run_model_worker,
name=f"model_worker - {new_model_name}",
kwargs=dict(model_name=new_model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
process.start()
process.name = f"{process.name} ({process.pid})"
processes["model_worker"][new_model_name] = process
e.wait()
timing = datetime.now() - start_time
logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。")
else:
logger.error(f"未找到模型进程:{model_name}")
# for process in processes.get("model_worker", {}).values():
# process.join()
# for process in processes.get("online_api", {}).values():
# process.join()
# for name, process in processes.items():
# if name not in ["model_worker", "online_api"]:
# if isinstance(p, dict):
# for work_process in p.values():
# work_process.join()
# else:
# process.join()
except Exception as e:
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally:
for p in processes.values():
logger.warning("Sending SIGKILL to %s", p)
# Queues and other inter-process communication primitives can break when
# process is killed, but we don't care here
if isinstance(p, dict):
for process in p.values():
process.kill()
else:
p.kill()
for p in processes.values():
logger.info("Process status: %s", p)
if __name__ == "__main__":
create_tables()
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(start_main_server())
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm3-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)
随后启动起来长这样:
当然大模型对话还是不能用的,因为根本没加载运行大模型。不过亲测向量知识库可以用。我就往知识库里传了个tmp.txt文件。
Web服务这边也是显示向量嵌入正常。