Part 4: Cognitive Architectures
Architecture 1: LLM Call
from typing import Annotated, TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage
# Define the model
model = ChatOllama(model="deepseek-r1:32b", temperature=0)
# Define the state type
class State(TypedDict):
    messages: Annotated[list, add_messages]
# Define the chatbot node
def chatbot(state: State):
    answer = model.invoke(state["messages"])
    return {"messages": [answer]}
# Create the state graph builder
builder = StateGraph(State)
# Add the chatbot node
builder.add_node("chatbot", chatbot)
# Add the edges
builder.add_edge(START, "chatbot")
builder.add_edge("chatbot", END)
# Compile the graph
graph = builder.compile()
# Example usage
input = {"messages": [HumanMessage("Hello!")]}
for chunk in graph.stream(input):
    print(chunk)
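Since the graph is compiled, it can also be called synchronously with invoke instead of streamed. A minimal sketch using the same graph and model; the question text is only illustrative:

# Sketch: synchronous call; invoke returns the final state
result = graph.invoke({"messages": [HumanMessage("Summarize what you can help with in one sentence.")]})
# add_messages appends the model's reply, so the last entry is the AI answer
print(result["messages"][-1].content)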
Architecture 2: Chain
A chain calls multiple LLM steps in a predefined, fixed order.
from typing import Annotated, TypedDict
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
# Low-temperature model, used to generate the SQL query
model_low_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.1)
# High-temperature model, used to generate the natural-language output
model_high_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.7)
# Define the state type
class State(TypedDict):
    # tracks the conversation history
    messages: Annotated[list, add_messages]
    # input
    user_query: str
    # outputs
    sql_query: str
    sql_explanation: str
# Define the input type
class Input(TypedDict):
    user_query: str
# Define the output type
class Output(TypedDict):
    sql_query: str
    sql_explanation: str
# Prompt for generating the SQL query
generate_prompt = SystemMessage(
    "You are a helpful data analyst who generates SQL queries for users based on their questions."
)
# Function that generates the SQL query
def generate_sql(state: State) -> State:
    user_message = HumanMessage(state["user_query"])
    messages = [generate_prompt, *state["messages"], user_message]
    res = model_low_temp.invoke(messages)
    return {
        "sql_query": res.content,
        "messages": [user_message, res],
    }
# Prompt for explaining the SQL query
explain_prompt = SystemMessage(
    "You are a helpful data analyst who explains SQL queries to users."
)
# Function that explains the SQL query - this is the only change from the original version
def explain_sql(state: State) -> State:
    # Build an explicit explanation request
    explanation_request = HumanMessage(
        f"Please explain how the following SQL query answers my question:\n\n{state['sql_query']}\n\nMy original question was: {state['user_query']}"
    )
    messages = [
        explain_prompt,
        explanation_request,  # add the explicit explanation request
    ]
    res = model_high_temp.invoke(messages)
    return {
        "sql_explanation": res.content,
        "messages": [*state["messages"], explanation_request, res],  # keep the full history
    }
# Create the state graph builder
builder = StateGraph(State, input=Input, output=Output)
builder.add_node("generate_sql", generate_sql)
builder.add_node("explain_sql", explain_sql)
builder.add_edge(START, "generate_sql")
builder.add_edge("generate_sql", "explain_sql")
builder.add_edge("explain_sql", END)
# Compile the graph
graph = builder.compile()
# Example usage
result = graph.invoke({"user_query": "What is the total sales amount for each product?"})
print(result)
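To observe the intermediate steps, the same chain can also be streamed; with the default stream mode each node emits its state update as it finishes. A sketch against the graph compiled above:

# Sketch: stream the chain to watch each node's update in order
for chunk in graph.stream({"user_query": "What is the total sales amount for each product?"}):
    print(chunk)  # first the "generate_sql" update, then the "explain_sql" update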
# Annotated walkthrough of generate_sql (the same function as above, with line-by-line comments)
def generate_sql(state: State) -> State:
    # Read the "user_query" key from the state dict to get the user's question,
    # and wrap it in a HumanMessage
    user_message = HumanMessage(state["user_query"])
    # Build the full conversation context:
    # 1. the system prompt (generate_prompt)
    # 2. the previous conversation messages (state["messages"]),
    #    unpacked into the list with the * operator
    # 3. the new user message
    messages = [generate_prompt, *state["messages"], user_message]
    res = model_low_temp.invoke(messages)
    return {
        "sql_query": res.content,
        # update the conversation history
        "messages": [user_message, res],
    }
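"Update the conversation history" here relies on the add_messages reducer declared on State: a node returns only the new messages, and the reducer appends them to the existing list (matching by message id) rather than overwriting it. A standalone sketch of that behavior, outside the graph:

# Sketch: calling the add_messages reducer directly
from langchain_core.messages import AIMessage
history = add_messages([], [HumanMessage("What is the total sales amount for each product?")])
history = add_messages(history, [AIMessage("SELECT product, SUM(amount) AS total FROM sales GROUP BY product;")])
print(len(history))  # 2 -- new messages are appended, existing ones are kept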
Architecture 3: Router
A router lets an LLM decide which branch to run next; here it routes the user query to one of two retrievers before generating the answer.
from typing import Annotated, Literal, TypedDict
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores.in_memory import InMemoryVectorStore
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
embeddings = OllamaEmbeddings(
    model="nomic-embed-text:latest",
)
# Low-temperature model, used for the routing decision
model_low_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.1)
# High-temperature model, used to generate the natural-language answer
model_high_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.7)
# Define the state type
class State(TypedDict):
    # tracks the conversation history
    messages: Annotated[list, add_messages]
    # input
    user_query: str
    # the routing decision can only be one of these two domains
    domain: Literal["records", "insurance"]
    documents: list[Document]
    answer: str
# Define the input type
class Input(TypedDict):
    user_query: str
# Define the output type
class Output(TypedDict):
    documents: list[Document]
    answer: str
# Sample documents for testing
sample_docs = [
    Document(page_content="Patient medical records...", metadata={"domain": "records"}),
    Document(
        page_content="Insurance policy details...", metadata={"domain": "insurance"}
    ),
]
# Initialize the vector stores
medical_records_store = InMemoryVectorStore.from_documents(sample_docs, embeddings)
medical_records_retriever = medical_records_store.as_retriever()
insurance_faqs_store = InMemoryVectorStore.from_documents(sample_docs, embeddings)
insurance_faqs_retriever = insurance_faqs_store.as_retriever()
router_prompt = SystemMessage(
    """You need to decide which domain to route the user query to. There are two domains to choose from:
- records: contains patient medical records, such as diagnoses, treatments, and prescriptions.
- insurance: contains FAQs about insurance policies, claims, and coverage.
Output only the domain name."""
)
# Router node
def router_node(state: State) -> State:
    user_message = HumanMessage(state["user_query"])
    messages = [router_prompt, *state["messages"], user_message]
    res = model_low_temp.invoke(messages)
    return {
        "domain": res.content,
        # update the conversation history
        "messages": [user_message, res],
    }
# Pick the retriever based on the routing decision
def pick_retriever(
    state: State,
) -> Literal["retrieve_medical_records", "retrieve_insurance_faqs"]:
    if state["domain"] == "records":
        return "retrieve_medical_records"
    else:
        return "retrieve_insurance_faqs"
# Retrieve medical records
def retrieve_medical_records(state: State) -> State:
    documents = medical_records_retriever.invoke(state["user_query"])
    return {
        "documents": documents,
    }
# Retrieve insurance FAQs
def retrieve_insurance_faqs(state: State) -> State:
    documents = insurance_faqs_retriever.invoke(state["user_query"])
    return {
        "documents": documents,
    }
# Prompt for the medical-records branch
medical_records_prompt = SystemMessage(
    "You are a helpful medical chatbot who answers questions based on the patient's medical records, such as diagnoses, treatments, and prescriptions."
)
# Prompt for the insurance-FAQ branch
insurance_faqs_prompt = SystemMessage(
    "You are a helpful medical insurance chatbot who answers frequently asked questions about insurance policies, claims, and coverage."
)
# Generate the answer
def generate_answer(state: State) -> State:
    if state["domain"] == "records":
        prompt = medical_records_prompt
    else:
        prompt = insurance_faqs_prompt
    messages = [
        prompt,
        *state["messages"],
        HumanMessage(f"Documents: {state['documents']}"),
    ]
    res = model_high_temp.invoke(messages)
    return {
        "answer": res.content,
        # update the conversation history
        "messages": res,
    }
# Create the state graph builder
builder = StateGraph(State, input=Input, output=Output)
builder.add_node("router", router_node)
builder.add_node("retrieve_medical_records", retrieve_medical_records)
builder.add_node("retrieve_insurance_faqs", retrieve_insurance_faqs)
builder.add_node("generate_answer", generate_answer)
builder.add_edge(START, "router")
builder.add_conditional_edges("router", pick_retriever)
builder.add_edge("retrieve_medical_records", "generate_answer")
builder.add_edge("retrieve_insurance_faqs", "generate_answer")
builder.add_edge("generate_answer", END)
# Compile the graph
graph = builder.compile()
# Example usage
input = {"user_query": "Am I covered for COVID-19 treatment?"}
for chunk in graph.stream(input):
    print(chunk)
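As a quick check of the other branch, the same graph can be invoked with a records-style question (the query text is illustrative); if the router replies with "records", the medical-records retriever runs instead of the insurance one:

# Sketch: a query intended to be routed to the medical-records branch
records_result = graph.invoke({"user_query": "What was my most recent diagnosis and prescription?"})
print(records_result["answer"])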