LangChain简明使用笔记(4)认知结构

第四部分 认知架构

架构1:模型调用

from typing import Annotated, TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage
​
# 定义模型
model = ChatOllama(model="deepseek-r1:32b", temperature=0)
​
# 定义状态类型
class State(TypedDict):
​
    messages: Annotated[list, add_messages]
​
# 定义聊天机器人函数
def chatbot(state: State):
    answer = model.invoke(state["messages"])
    return {"messages": [answer]}
​
# 创建状态图构建器
builder = StateGraph(State)
​
# 添加聊天机器人节点
builder.add_node("chatbot", chatbot)
​
# 添加边
builder.add_edge(START, "chatbot")
builder.add_edge("chatbot", END)
​
# 编译状态图
graph = builder.compile()
​
# 示例用法
input = {"messages": [HumanMessage("你好!")]}
for chunk in graph.stream(input):
    print(chunk)

架构2:链

链按照预定义的顺序调用多个大语言模型

from typing import Annotated, TypedDict
​
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
​
# 低温度模型,用于生成SQL查询
model_low_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.1)
# 高温度模型,用于生成自然语言输出
model_high_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.7)
​
​
# 定义状态类型
class State(TypedDict):
    # 跟踪对话历史
    messages: Annotated[list, add_messages]
    # 输入
    user_query: str
    # 输出
    sql_query: str
    sql_explanation: str
​
​
# 定义输入类型
class Input(TypedDict):
    user_query: str
​
​
# 定义输出类型
class Output(TypedDict):
    sql_query: str
    sql_explanation: str
​
​
# 生成SQL查询的提示
generate_prompt = SystemMessage(
    "你是一个乐于助人的数据分析师,负责根据用户的问题生成SQL查询。"
)
​
​
# 生成SQL查询的函数
def generate_sql(state: State) -> State:
    user_message = HumanMessage(state["user_query"])
    messages = [generate_prompt, *state["messages"], user_message]
    res = model_low_temp.invoke(messages)
    return {
        "sql_query": res.content,
        "messages": [user_message, res],
    }
​
​
# 解释SQL查询的提示
explain_prompt = SystemMessage(
    "你是一个乐于助人的数据分析师,负责向用户解释SQL查询。"
)
​
​
# 解释SQL查询的函数 - 这里是唯一的修改点
def explain_sql(state: State) -> State:
    # 创建明确的解释请求
    explanation_request = HumanMessage(
        f"请解释以下SQL查询是如何解决我的问题的:\n\n{state['sql_query']}\n\n我的原始问题是:{state['user_query']}"
    )
​
    messages = [
        explain_prompt,
        explanation_request,  # 添加明确的解释请求
    ]
​
    res = model_high_temp.invoke(messages)
​
    return {
        "sql_explanation": res.content,
        "messages": [*state["messages"], explanation_request, res],  # 保留完整历史
    }
​
​
# 创建状态图构建器
builder = StateGraph(State, input=Input, output=Output)
builder.add_node("generate_sql", generate_sql)
builder.add_node("explain_sql", explain_sql)
builder.add_edge(START, "generate_sql")
builder.add_edge("generate_sql", "explain_sql")
builder.add_edge("explain_sql", END)
​
# 编译状态图
graph = builder.compile()
​
# 示例用法
result = graph.invoke({"user_query": "每个产品的总销售额是多少?"})
print(result)
# 生成SQL查询的函数
def generate_sql(state: State) -> State:
    # 从state字典中访问"user_query"键,获取用户问题 并将其封装成HumanMessage类
    user_message = HumanMessage(state["user_query"])
    # 构建完整的对话上下文:
    # 1. 系统提示(generate_prompt)
    # 2. 历史对话消息(state["messages"])
    # 使用解包操作符*将所有之前的对话消息展开并添加到列表中
    # 3. 新的用户消息
    messages = [generate_prompt, *state["messages"], user_message]
    res = model_low_temp.invoke(messages)
    return {
        "sql_query": res.content,
        # 更新对话历史
        "messages": [user_message, res],
    }

架构3:路由

from typing import Annotated, Literal, TypedDict
​
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores.in_memory import InMemoryVectorStore
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
​
​
embeddings = OllamaEmbeddings(
    model="nomic-embed-text:latest",
)
​
# 低温度模型,用于生成SQL查询
model_low_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.1)
# 高温度模型,用于生成自然语言输出
model_high_temp = ChatOllama(model="deepseek-r1:32b", temperature=0.7)
​
# 定义状态类型
class State(TypedDict):
    # 跟踪对话历史
    messages: Annotated[list, add_messages]
    # 输入
    user_query: str
    # 只能从这两个领域进行选择
    domain: Literal["records", "insurance"]
    documents: list[Document]
    answer: str
​
# 定义输入类型
class Input(TypedDict):
    user_query: str
​
# 定义输出类型
class Output(TypedDict):
    documents: list[Document]
    answer: str
​
# 示例文档用于测试
sample_docs = [
    Document(page_content="患者病历...", metadata={"domain": "records"}),
    Document(
        page_content="保险政策详情...", metadata={"domain": "insurance"}
    ),
]
​
# 初始化向量存储
medical_records_store = InMemoryVectorStore.from_documents(sample_docs, embeddings)
medical_records_retriever = medical_records_store.as_retriever()
​
insurance_faqs_store = InMemoryVectorStore.from_documents(sample_docs, embeddings)
insurance_faqs_retriever = insurance_faqs_store.as_retriever()
​
router_prompt = SystemMessage(
    """你需要决定将用户查询路由到哪个领域。有两个领域可以选择:
- records: 包含患者的病历,如诊断、治疗和处方。
- insurance: 包含关于保险政策、索赔和覆盖范围的常见问题。
​
只输出领域名称。"""
)
​
# 定义路由节点
def router_node(state: State) -> State:
    user_message = HumanMessage(state["user_query"])
    messages = [router_prompt, *state["messages"], user_message]
    res = model_low_temp.invoke(messages)
    return {
        "domain": res.content,
        # 更新对话历史
        "messages": [user_message, res],
    }
​
# 选择检索器
def pick_retriever(
    state: State,
) -> Literal["retrieve_medical_records", "retrieve_insurance_faqs"]:
    if state["domain"] == "records":
        return "retrieve_medical_records"
    else:
        return "retrieve_insurance_faqs"
​
# 检索医疗记录
def retrieve_medical_records(state: State) -> State:
    documents = medical_records_retriever.invoke(state["user_query"])
    return {
        "documents": documents,
    }
​
# 检索保险常见问题
def retrieve_insurance_faqs(state: State) -> State:
    documents = insurance_faqs_retriever.invoke(state["user_query"])
    return {
        "documents": documents,
    }
​
# 医疗记录提示
medical_records_prompt = SystemMessage(
    "你是一个乐于助人的医疗聊天机器人,基于患者的病历回答问题,如诊断、治疗和处方。"
)
​
# 保险常见问题提示
insurance_faqs_prompt = SystemMessage(
    "你是一个乐于助人的医疗保险聊天机器人,回答关于保险政策、索赔和覆盖范围的常见问题。"
)
​
# 生成回答
def generate_answer(state: State) -> State:
    if state["domain"] == "records":
        prompt = medical_records_prompt
    else:
        prompt = insurance_faqs_prompt
    messages = [
        prompt,
        *state["messages"],
        HumanMessage(f"Documents: {state['documents']}"),
    ]
    res = model_high_temp.invoke(messages)
    return {
        "answer": res.content,
        # 更新对话历史
        "messages": res,
    }
​
# 创建状态图构建器
builder = StateGraph(State, input=Input, output=Output)
builder.add_node("router", router_node)
builder.add_node("retrieve_medical_records", retrieve_medical_records)
builder.add_node("retrieve_insurance_faqs", retrieve_insurance_faqs)
builder.add_node("generate_answer", generate_answer)
builder.add_edge(START, "router")
builder.add_conditional_edges("router", pick_retriever)
builder.add_edge("retrieve_medical_records", "generate_answer")
builder.add_edge("retrieve_insurance_faqs", "generate_answer")
builder.add_edge("generate_answer", END)
​
# 编译状态图
graph = builder.compile()
​
# 示例用法
input = {"user_query": "我是否覆盖COVID-19治疗?"}
for chunk in graph.stream(input):
    print(chunk)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值