Developing an Agent Application with LangGraph [NL2SQL]

Preface

For a brief introduction to LangGraph, see this post:

LangGraph开发Agent智能体应用【基础聊天机器人】 (CSDN blog)

NL2SQL with LangChain, for Comparison

A complete code implementation of an NL2SQL agent built with LangChain is available in this post:

LangChain开发LLM应用【入门指南】 (CSDN blog)

Here is a quick recap:

from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase
 
# Database connection info
username = 'root'
password = 'MyNewPass1!'
host = 'desk04v.mlprod.bjpdc.qihoo.net'
port = '3306'
database = 'test'
 
engine = create_engine(f'mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}')
db = SQLDatabase(engine)
result = db.run("SELECT * FROM courses LIMIT 5;")
print(result)
 
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
 
llm = ChatOpenAI(model="gpt-4o", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
 
agent_executor.invoke(
    "找到学分最高的课程"  # "Find the course with the highest credits"
)

Here is what a run looks like (screenshot omitted):

Note that I never defined an execution chain in this code; I simply handed the LLM a set of tools and let the agent decide how to use them.

As you can see, the agent first listed the tables in the database; after finding one that looked relevant, it inspected that table's schema and previewed a few rows, then generated the SQL and executed it (validating it first), and finally synthesized the results into a conclusion for me.

That is already remarkably smart.

NL2SQL with LangGraph

LangGraph works differently from LangChain: you develop by repeatedly adding "nodes" and "edges" to a graph, which together form a workflow.

Note: this workflow is not a simple linear pipeline. A LangGraph workflow and a LangChain chain are not the same kind of thing, as this example should make clear.

As usual, code first!

Step 1: define the tool set

LangChain and LangGraph interoperate (more precisely, LangGraph is a higher-level framework within the LangChain ecosystem),

so we can use LangChain's SQLDatabaseToolkit directly.

If you dig into the source code, you will find the toolkit bundles four tools (you can also verify this at runtime; see the sketch after the next code block):
QuerySQLDataBaseTool: executes SQL
InfoSQLDatabaseTool: inspects table schemas and sample rows
QuerySQLCheckerTool: checks SQL syntax
ListSQLDatabaseTool: lists all tables

from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

# Database connection info
username = 'root'
password = 'MyNewPass1!'
host = 'desk04v.mlprod.bjpdc.qihoo.net'
port = '3306'
database = 'test'

engine = create_engine(f'mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}')
db = SQLDatabase(engine)
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(temperature=0))
context = toolkit.get_context()
tools = toolkit.get_tools()
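
To confirm those four tools at runtime, here is a quick sketch (assuming the toolkit built above; tool names and descriptions may vary slightly across langchain_community versions):

for tool in tools:
    # Every LangChain tool exposes a name and a description that the LLM sees
    print(tool.name, "-", tool.description[:60])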

Step 2: define the LLM node and add it to the graph

Bind the tools to the LLM. This binding is essential: it is how you tell the LLM which tools it may call, and only then will it generate tool-call plans.

from typing import Annotated

from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict

from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages


class State(TypedDict):
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

# expt_llm = "gpt-4-1106-preview"
expt_llm = "gpt-4o"
llm = ChatOpenAI(temperature=0, model=expt_llm)
# Modification: tell the LLM which tools it can call
llm_with_tools = llm.bind_tools(tools)


def chatbot(state: State):
    # Invoke the tool-aware LLM on the accumulated message history
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


graph_builder.add_node("chatbot", chatbot)

Step 3: define the tool node and add it to the graph

import json

from langchain_core.messages import ToolMessage


class BasicToolNode:
    """运行最后一个AIMessage中请求的工具"""

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        outputs = []
        for tool_call in message.tool_calls:
            # Debug output: the tool name and the tool object's full repr
            print(tool_call["name"])
            print(self.tools_by_name[tool_call["name"]])
            tool_result = self.tools_by_name[tool_call["name"]].invoke(
                tool_call["args"]
            )
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}


tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)

Step 4: define the "edges"

The add_edge method defines an unconditional edge; in this example it expresses tools -> chatbot.

The add_conditional_edges method adds a conditional routing edge; in this example it expresses chatbot -> tools or chatbot -> __end__, depending on the state.

from typing import Literal


def route_tools(
    state: State,
) -> Literal["tools", "__end__"]:
    """如果最后一条消息,在conditional_edge中使用路由到ToolNode,就调用工具。否则,路线到终点。"""
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"在tool_edge的输入状态中没有找到消息: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return "__end__"


# route_tools returns "tools" when the LLM has asked to call a tool, or "__end__" to finish.
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node
    # It defaults to the identity function, but if you
    # want to use a node named something else apart from "tools",
    # You can update the value of the dictionary to something else
    # e.g., "tools": "my_tools"
    {"tools": "tools", "__end__": "__end__"},
)
# Any time a tool is invoked, control flows back to the chatbot
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")
graph = graph_builder.compile()
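
Likewise, the hand-written route_tools can be swapped for langgraph's prebuilt condition; a sketch, assuming a recent langgraph version:

from langgraph.prebuilt import tools_condition

# Routes to the node named "tools" if the last AI message contains tool calls, else to __end__
graph_builder.add_conditional_edges("chatbot", tools_condition)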

Step 5: draw the graph (optional)

from IPython.display import Image, display

try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # Rendering the PNG requires extra dependencies and is optional
    pass
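
If the PNG rendering toolchain is not available, a workable fallback is to print the Mermaid source and paste it into any Mermaid renderer:

# Emit the graph as Mermaid text instead of rendering a PNG
print(graph.get_graph().draw_mermaid())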

The rendered graph looks like this (screenshot omitted):

The whole flow is simple. In plain terms:

the question is passed to the LLM; the LLM decides which tool to use; the graph invokes that tool and feeds the result back to the LLM; with the result in hand, the LLM may call another tool or emit the final answer, looping or terminating as needed.

Step 6: run it

Stream the graph, passing in the user's question:

from langchain_core.messages import BaseMessage

while True:
    user_input = input("User: ")
    if user_input.lower() in ["quit", "exit", "q"]:
        print("Goodbye!")
        break
    for event in graph.stream({"messages": [("user", user_input)]}):
        for value in event.values():
            if isinstance(value["messages"][-1], BaseMessage):
                print("Assistant:", value["messages"][-1].content)

The output looks like this:

User:  找到5月各地区的运费最贵的用户 (Find the customer with the highest freight in each region in May)
Assistant: 
sql_db_list_tables
db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "arbitraments, courses, orders, scores, sink_chunjun_1, source_chunjun_1, students, test_binlog_1"
Assistant: 
sql_db_schema
description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "\nCREATE TABLE orders (\n\torder_id INTEGER NOT NULL COMMENT '\u8ba2\u5355ID', \n\tcustomer_id VARCHAR(255) COMMENT '\u5ba2\u6237ID', \n\temployee_id INTEGER COMMENT '\u5458\u5de5ID', \n\torder_date DATE COMMENT '\u8ba2\u5355\u65e5\u671f', \n\trequired_date DATE COMMENT '\u8981\u6c42\u4ea4\u8d27\u65e5\u671f', \n\tshipped_date DATE COMMENT '\u53d1\u8d27\u65e5\u671f', \n\tshipper_id INTEGER COMMENT '\u53d1\u8d27\u65b9\u5f0f', \n\tfreight DECIMAL(10, 2) COMMENT '\u8fd0\u8d39', \n\tship_name VARCHAR(255) COMMENT '\u6536\u8d27\u4eba\u540d\u79f0', \n\tship_address VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u5740', \n\tship_city VARCHAR(255) COMMENT '\u6536\u8d27\u57ce\u5e02', \n\tship_region VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u533a', \n\tship_postal_code VARCHAR(255) COMMENT '\u6536\u8d27\u90ae\u7f16', \n\torder_status VARCHAR(50) COMMENT '\u8ba2\u5355\u72b6\u6001', \n\tsnapshot_timestamp TIMESTAMP NULL COMMENT '\u5feb\u7167\u65f6\u95f4\u6233' DEFAULT CURRENT_TIMESTAMP, \n\tPRIMARY KEY (order_id)\n)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB\n\n/*\n3 rows from orders table:\norder_id\tcustomer_id\temployee_id\torder_date\trequired_date\tshipped_date\tshipper_id\tfreight\tship_name\tship_address\tship_city\tship_region\tship_postal_code\torder_status\tsnapshot_timestamp\n1\tCUST001\t1\t2024-05-01\t2024-05-05\t2024-05-03\t1\t100.50\t\u5f20\u4e09\t\u5317\u4eac\u5e02\u671d\u9633\u533a\u5efa\u56fd\u8def100\u53f7\t\u5317\u4eac\t\u534e\u5317\t100022\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n2\tCUST002\t2\t2024-05-02\t2024-05-06\t2024-05-04\t2\t200.75\t\u674e\u56db\t\u4e0a\u6d77\u5e02\u6d66\u4e1c\u65b0\u533a\u4e16\u7eaa\u5927\u9053200\u53f7\t\u4e0a\u6d77\t\u534e\u4e1c\t200120\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n3\tCUST003\t3\t2024-05-03\t2024-05-07\t2024-05-05\t3\t150.00\t\u738b\u4e94\t\u5e7f\u5dde\u5e02\u5929\u6cb3\u533a\u4f53\u80b2\u897f\u8def300\u53f7\t\u5e7f\u5dde\t\u534e\u5357\t510620\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n*/"
Assistant: 
sql_db_query_checker
description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90> llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy='') llm_chain=LLMChain(prompt=PromptTemplate(input_variables=['dialect', 'query'], template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy=''))
Assistant: "SELECT ship_region, customer_id, freight FROM orders \nWHERE MONTH(order_date) = 5 \nORDER BY freight DESC \nLIMIT 1;"
Assistant: 
sql_db_query
description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields." db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "[('\u534e\u4e1c', 'CUST008', Decimal('300.80'))]"
Assistant: Across the regions in May, the customer with the highest freight is:

- Region: 华东 (East China)
- Customer ID: CUST008
- Freight: ¥300.80
User:  统计5月的每个地区的运费最贵的用户 (For May, list the customer with the highest freight in each region)
Assistant: 
sql_db_list_tables
db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "arbitraments, courses, orders, scores, sink_chunjun_1, source_chunjun_1, students, test_binlog_1"
Assistant: 
sql_db_schema
description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "\nCREATE TABLE orders (\n\torder_id INTEGER NOT NULL COMMENT '\u8ba2\u5355ID', \n\tcustomer_id VARCHAR(255) COMMENT '\u5ba2\u6237ID', \n\temployee_id INTEGER COMMENT '\u5458\u5de5ID', \n\torder_date DATE COMMENT '\u8ba2\u5355\u65e5\u671f', \n\trequired_date DATE COMMENT '\u8981\u6c42\u4ea4\u8d27\u65e5\u671f', \n\tshipped_date DATE COMMENT '\u53d1\u8d27\u65e5\u671f', \n\tshipper_id INTEGER COMMENT '\u53d1\u8d27\u65b9\u5f0f', \n\tfreight DECIMAL(10, 2) COMMENT '\u8fd0\u8d39', \n\tship_name VARCHAR(255) COMMENT '\u6536\u8d27\u4eba\u540d\u79f0', \n\tship_address VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u5740', \n\tship_city VARCHAR(255) COMMENT '\u6536\u8d27\u57ce\u5e02', \n\tship_region VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u533a', \n\tship_postal_code VARCHAR(255) COMMENT '\u6536\u8d27\u90ae\u7f16', \n\torder_status VARCHAR(50) COMMENT '\u8ba2\u5355\u72b6\u6001', \n\tsnapshot_timestamp TIMESTAMP NULL COMMENT '\u5feb\u7167\u65f6\u95f4\u6233' DEFAULT CURRENT_TIMESTAMP, \n\tPRIMARY KEY (order_id)\n)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB\n\n/*\n3 rows from orders table:\norder_id\tcustomer_id\temployee_id\torder_date\trequired_date\tshipped_date\tshipper_id\tfreight\tship_name\tship_address\tship_city\tship_region\tship_postal_code\torder_status\tsnapshot_timestamp\n1\tCUST001\t1\t2024-05-01\t2024-05-05\t2024-05-03\t1\t100.50\t\u5f20\u4e09\t\u5317\u4eac\u5e02\u671d\u9633\u533a\u5efa\u56fd\u8def100\u53f7\t\u5317\u4eac\t\u534e\u5317\t100022\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n2\tCUST002\t2\t2024-05-02\t2024-05-06\t2024-05-04\t2\t200.75\t\u674e\u56db\t\u4e0a\u6d77\u5e02\u6d66\u4e1c\u65b0\u533a\u4e16\u7eaa\u5927\u9053200\u53f7\t\u4e0a\u6d77\t\u534e\u4e1c\t200120\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n3\tCUST003\t3\t2024-05-03\t2024-05-07\t2024-05-05\t3\t150.00\t\u738b\u4e94\t\u5e7f\u5dde\u5e02\u5929\u6cb3\u533a\u4f53\u80b2\u897f\u8def300\u53f7\t\u5e7f\u5dde\t\u534e\u5357\t510620\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n*/"
Assistant: 
sql_db_query_checker
description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90> llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy='') llm_chain=LLMChain(prompt=PromptTemplate(input_variables=['dialect', 'query'], template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy=''))
Assistant: "SELECT ship_region, customer_id, MAX(freight) as max_freight\nFROM orders\nWHERE order_date BETWEEN '2024-05-01' AND '2024-05-31'\nGROUP BY ship_region, customer_id"
Assistant: 
sql_db_query
description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields." db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "[('\u534e\u5317', 'CUST001', Decimal('100.50')), ('\u534e\u4e1c', 'CUST002', Decimal('200.75')), ('\u534e\u5357', 'CUST003', Decimal('150.00')), ('\u534e\u5357', 'CUST004', Decimal('120.25')), ('\u897f\u5357', 'CUST005', Decimal('180.90')), ('\u897f\u5357', 'CUST006', Decimal('250.60')), ('\u534e\u4e1c', 'CUST007', Decimal('90.45')), ('\u534e\u4e1c', 'CUST008', Decimal('300.80')), ('\u534e\u4e2d', 'CUST009', Decimal('220.30')), ('\u897f\u5317', 'CUST010', Decimal('170.95'))]"
Assistant: Here are the customers with the highest freight in each region for May:

| Region | Customer ID | Max Freight (¥) |
|--------|-------------|-----------------|
| 华北   | CUST001     | 100.50          |
| 华东   | CUST008     | 300.80          |
| 华南   | CUST003     | 150.00          |
| 西南   | CUST006     | 250.60          |
| 华中   | CUST009     | 220.30          |
| 西北   | CUST010     | 170.95          |

As shown above, these are the top-freight customers in each region along with their freight amounts.
User:  q
Goodbye!

Summary

Did you notice something curious?

For the question "找到5月各地区的运费最贵的用户" (find the customer with the highest freight in each region in May),

the SQL the assistant generated is actually wrong:

SELECT ship_region, customer_id, MAX(freight) as max_freight
FROM orders
WHERE order_date BETWEEN '2024-05-01' AND '2024-05-31'
GROUP BY ship_region, customer_id

This is a medium-difficulty SQL problem I took from LeetCode.

Executing the SQL the assistant generated should yield multiple customers per region (GROUP BY ship_region, customer_id simply lists every customer's maximum freight), yet the final answer it gave me was nonetheless correct...

The reason: GPT-4o cheated. It was not able to produce a query this complex in one shot, so it wrote SQL for an intermediate result and then post-processed the rows itself before replying.

Getting away with that trick is fine on a small test dataset, but at production data volumes it would be a real problem.
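
For reference, answering the question in a single statement needs a window function; a sketch, assuming MySQL 8+ and the db object from earlier (ties within a region would all be returned):

correct_sql = """
SELECT ship_region, customer_id, freight
FROM (
    SELECT ship_region, customer_id, freight,
           RANK() OVER (PARTITION BY ship_region ORDER BY freight DESC) AS rk
    FROM orders
    WHERE order_date BETWEEN '2024-05-01' AND '2024-05-31'
) ranked
WHERE rk = 1;
"""
print(db.run(correct_sql))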

How do we fix it?

One option is to optimize with few-shot prompting.
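
The idea, roughly: prepend worked question-to-SQL pairs so the model imitates the window-function pattern. A minimal sketch with a hypothetical example of mine (the follow-up post linked below covers this properly):

from langchain_core.messages import SystemMessage

# Hypothetical few-shot example nudging the model toward window functions
few_shot = SystemMessage(content=(
    "When a question asks for a per-group maximum, use a window function. Example:\n"
    "Q: the most expensive order in each region\n"
    "SQL: SELECT ship_region, order_id FROM (\n"
    "  SELECT ship_region, order_id,\n"
    "         RANK() OVER (PARTITION BY ship_region ORDER BY freight DESC) AS rk\n"
    "  FROM orders) t WHERE rk = 1;"
))

# Seed the conversation with the few-shot system message before the user turn
for event in graph.stream({"messages": [few_shot, ("user", "统计5月的每个地区的运费最贵的用户")]}):
    for value in event.values():
        print(value["messages"][-1].content)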

I wrote a separate post on few-shot tuning; see: LangGraph开发Agent智能体应用【NL2SQL】(few-shot优化) (CSDN blog).

References

🦜🕸️ LangGraph (official documentation)

LangChain开发LLM应用【入门指南】 (CSDN blog)

LangGraph开发Agent智能体应用【基础聊天机器人】 (CSDN blog)

Introduction | 🦜️🔗 LangChain (official documentation)

The code has been uploaded; download it if you need it. Thanks, everyone!
