上一篇文章我们介绍了Milvus向量数据库,它是用来存储向量数据的:我们将各种非结构化的文档转为向量,存储在向量数据库中。但是RAG系统不只有非结构化的数据,也有结构化的数据,比如存储在MySQL或者其他关系型数据库中的数据。如何在RAG系统里对这类数据进行检索查询呢?这就是本文要分享的内容:文本转SQL(Text2SQL)。
文本转SQL
文本转SQL通俗地说,就是用户输入自然语言,由LLM大语言模型结合表结构将自然语言转换为SQL的过程,实现思路如下:
流程图
代码示例
import os
from dotenv import load_dotenv
load_dotenv()
import pymysql
import json
# ====================== MYSQL的配置 ======================
# Keyword arguments that get_db_connection() unpacks into pymysql.connect().
# NOTE(review): host and credentials are hard-coded here — consider reading
# them from the environment like the LLM settings below.
MYSQL_CONFIG = {
    "host": "localhost",
    "port": 3306,
    "user": "root",
    "password": "root",
    "database": "finance_enterprise_db",
    "charset": "utf8mb4"
}
# LLM API settings: the key and base URL are read from environment variables
# (load_dotenv() is called earlier in this script); the model name is fixed.
MIMO_API_KEY = os.getenv("M_PROXY_AI_API_KEY")
MIMO_BASE_URL = os.getenv("M_PROXY_AI_BASE_URL")
MIMO_MODEL = "mimo-v2.5-pro"
# ====================== LLM 接口 ======================
def llm_generate_sql(prompt):
    """Ask the configured chat model to turn ``prompt`` into a SQL statement.

    The prompt is sent as a single user message through an OpenAI-compatible
    client built from ``MIMO_API_KEY`` / ``MIMO_BASE_URL``. Markdown code
    fences are removed from the reply before it is returned.

    Args:
        prompt: Full prompt text (schema description plus the user question).

    Returns:
        The generated SQL as a stripped string. An empty string is returned
        when the reply carries no text content.
    """
    from openai import OpenAI

    # The endpoint is OpenAI-protocol compatible, so the stock client is used.
    client = OpenAI(api_key=MIMO_API_KEY, base_url=MIMO_BASE_URL)
    response = client.chat.completions.create(
        model=MIMO_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,  # SQL generation must use temperature 0
    )
    # ``content`` is optional in a chat-completion message; fall back to ""
    # instead of failing with AttributeError on ``None.strip()``.
    content = response.choices[0].message.content or ""
    # Drop Markdown fences the model may add despite the instructions.
    return content.replace("```sql", "").replace("```", "").strip()
# ====================== 工具函数 ======================
def get_db_connection():
    """Open a new MySQL connection configured by ``MYSQL_CONFIG``.

    Cursors created from the connection use ``DictCursor``.
    """
    dict_cursor = pymysql.cursors.DictCursor
    return pymysql.connect(cursorclass=dict_cursor, **MYSQL_CONFIG)
def execute_sql(sql):
    """Run one SQL statement and report the outcome as ``(result, error)``.

    Args:
        sql: The SQL text to execute. It is executed as given, and statements
            without a result set are committed.

    Returns:
        ``(rows, None)`` when the statement produced a result set, where
        ``rows`` is what ``cursor.fetchall()`` returns;
        ``("执行成功", None)`` after committing a statement without a result
        set; ``(None, message)`` when connecting or executing fails.
    """
    conn = None
    try:
        # Connect inside the try block so a connection failure is reported
        # through the error slot like any other failure instead of raising.
        conn = get_db_connection()
        with conn.cursor() as cursor:
            cursor.execute(sql)
            # ``description`` is None when the statement returned no result
            # set. Checking it, rather than a "SELECT" prefix, also returns
            # the rows of DESC / SHOW / WITH ... SELECT statements.
            if cursor.description is not None:
                return cursor.fetchall(), None
            conn.commit()
            return "执行成功", None
    except Exception as e:
        return None, str(e)
    finally:
        if conn is not None:
            conn.close()
def get_all_table_info():
    """Describe every table of the connected database for use in a prompt.

    Each table description holds the table name, the table comment, the
    column metadata (including column comments) as JSON, and the example SQL
    from ``get_example_sql_by_table``.

    Returns:
        One text block per table, joined by a line of ``=`` characters.

    Raises:
        RuntimeError: If the schema metadata cannot be read.
    """
    # Both lookups are plain SELECTs on information_schema. DATABASE() is the
    # schema selected by the connection, so nothing is interpolated into SQL.
    tables, error = execute_sql(
        "SELECT TABLE_NAME AS `tbl`, TABLE_COMMENT AS `tbl_comment` "
        "FROM information_schema.TABLES "
        "WHERE TABLE_SCHEMA = DATABASE() "
        "ORDER BY TABLE_NAME"
    )
    if error:
        raise RuntimeError(f"读取表信息失败:{error}")

    # One query for all columns instead of one query per table. The aliases
    # mirror the column headers of DESC, plus the column comment.
    columns, error = execute_sql(
        "SELECT TABLE_NAME AS `tbl`, COLUMN_NAME AS `Field`, "
        "COLUMN_TYPE AS `Type`, IS_NULLABLE AS `Null`, COLUMN_KEY AS `Key`, "
        "COLUMN_DEFAULT AS `Default`, EXTRA AS `Extra`, "
        "COLUMN_COMMENT AS `Comment` "
        "FROM information_schema.COLUMNS "
        "WHERE TABLE_SCHEMA = DATABASE() "
        "ORDER BY TABLE_NAME, ORDINAL_POSITION"
    )
    if error:
        raise RuntimeError(f"读取字段信息失败:{error}")

    columns_by_table = {}
    for column in columns or []:
        column = {key: _as_text(value) for key, value in column.items()}
        columns_by_table.setdefault(column.pop("tbl"), []).append(column)

    table_info_list = []
    for row in tables or []:
        table_name = _as_text(row["tbl"])
        table_comment = _as_text(row["tbl_comment"]) or ""
        columns_json = json.dumps(
            columns_by_table.get(table_name, []), ensure_ascii=False, indent=2
        )
        example_sql = get_example_sql_by_table(table_name)
        table_info_list.append(
            f"\n表名:{table_name}\n"
            f"表注释:{table_comment}\n"
            f"字段:{columns_json}\n"
            f"示例SQL:\n{example_sql}\n"
        )
    return "\n=====================\n".join(table_info_list)


def _as_text(value):
    """Decode bytes values to ``str`` so they can be JSON-encoded."""
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8", "replace")
    return value
def get_example_sql_by_table(table_name):
    """Return the business example query registered for ``table_name``.

    Tables without a registered example get a generic three-row preview.
    """
    generic_preview = "".join(("SELECT * FROM ", table_name, " LIMIT 3;"))
    registered = {
        "fin_account_subject": "SELECT * FROM fin_account_subject LIMIT 5;",
        "fin_company_contact": "SELECT contact_name,contact_credit FROM fin_company_contact WHERE contact_type=1;",
        "fin_balance_sheet": "SELECT * FROM fin_balance_sheet WHERE report_year=2026 AND report_month=2;",
        "fin_profit_statement": "SELECT report_year,report_month,main_income,net_profit FROM fin_profit_statement ORDER BY report_year,report_month;",
        "fin_cash_flow": "SELECT * FROM fin_cash_flow WHERE report_year=2026 AND report_month=2;",
        "fin_expense_record": "SELECT expense_type,SUM(expense_amount) AS total FROM fin_expense_record GROUP BY expense_type;",
    }
    if table_name in registered:
        return registered[table_name]
    return generic_preview
# ====================== 核心流程:提问 → 生成SQL → 执行 → 重试 ======================
def text_to_sql_with_retry(question, max_retry=3):
    """Answer a natural-language question by generating and executing SQL.

    Flow:
      1. Collect the table descriptions with ``get_all_table_info``.
      2. Build the prompt for the LLM.
      3. Generate SQL with ``llm_generate_sql`` and run it with
         ``execute_sql``.
      4. On failure, append the failing SQL and its error to the prompt and
         try again, up to ``max_retry`` attempts.

    Args:
        question: The user's question in natural language.
        max_retry: Maximum number of generate-and-execute attempts.

    Returns:
        A dict with the keys ``question``, ``sql``, ``result`` and ``error``.
        On success ``error`` is None. When all attempts fail, ``result`` is
        None, ``sql`` is the last generated statement (or None), ``error`` is
        a fixed message, and ``last_error`` holds the last underlying error.
    """
    print(f"【用户问题】:{question}\n")
    # 1. Table structures, comments and example SQL for every table.
    table_info = get_all_table_info()
    # 2. Build the prompt.
    prompt = f"""
你是专业SQL生成专家,请根据下面的数据库表结构、表注释、示例SQL,
严格按照用户问题生成【可直接运行】的MySQL语句,只返回SQL,不要任何解释。
表结构信息:
{table_info}
用户问题:{question}
要求:
1. 只返回SQL,不要```、不要文字
2. 必须使用提供的表和字段
3. 日期、金额、分组、排序严格按业务逻辑
4. 字段名不要编造
"""
    last_sql = None
    last_error = None
    # 3. Retry loop.
    for retry in range(max_retry):
        print(f"=== 第 {retry+1} 次生成SQL ===")
        try:
            sql = llm_generate_sql(prompt)
            last_sql = sql
            print(f"【生成SQL】:\n{sql}\n")
            # NOTE(review): the model output is executed as-is; consider
            # restricting this path to read-only statements.
            result, error = execute_sql(sql)
            if not error:
                print("【执行成功】")
                return {
                    "question": question,
                    "sql": sql,
                    "result": result,
                    "error": None
                }
            last_error = error
            print(f"【执行失败】:{error}")
            # Every attempt is a fresh single-message request, so the model
            # needs the failing statement as well as its error to fix it.
            prompt += (
                f"\n\n上一次生成的SQL:\n{sql}\n"
                f"执行报错:{error},请修正SQL,只返回正确SQL"
            )
        except Exception as e:
            # Loop boundary: a failed LLM call must not abort the remaining
            # attempts; remember the error so it is reported at the end.
            last_error = str(e)
            print(f"【重试异常】{str(e)}")
    return {
        "question": question,
        "sql": last_sql,
        "result": None,
        "error": "重试次数耗尽,生成SQL失败",
        "last_error": last_error,
    }
# ====================== 测试 ======================
if __name__ == "__main__":
    # Sample question; replace it with any other question.
    question = "查询2026年1月和2月的主营业务收入、净利润,按月份排序"
    # Run the generate-execute-retry pipeline.
    result = text_to_sql_with_retry(question, max_retry=3)
    # Print the final result.
    print("\n" + "=" * 50)
    print("最终结果:")
    # Result rows come straight from the database driver and may contain
    # values json cannot encode natively (e.g. Decimal amounts, dates).
    # default=str is a deliberate display-only choice: render them as text
    # instead of failing with TypeError.
    print(json.dumps(result, ensure_ascii=False, indent=2, default=str))
优化版本
我们通过LLM理解数据库的表结构来生成SQL语句,再用生成的SQL查询数据并返回结果。这个逻辑本身是没问题的,但实际生产环境中数据库表很多,不可能一次性全部查询出来放到prompt中,这样做既不合理,生成的SQL也不够精准。所以我们可以引入上文所说的Milvus向量数据库:把数据库的表结构、字段注释、示例SQL向量化后存入Milvus向量数据库;查询时将用户的问题向量化,用Milvus向量检索出我们需要的表信息和相关示例SQL。
流程图
代码示例
初始化
import os
import openai
from dotenv import load_dotenv
load_dotenv()
from pymilvus import MilvusClient, DataType
from milvus_model.dense import OpenAIEmbeddingFunction
from sqlalchemy import create_engine, text
# Module-level setup shared by the functions below: OpenAI key, embedding
# function, Milvus client and MySQL engine.
# NOTE: load_dotenv() was already called once among the imports above.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
# Fix: use the new package (OpenAIEmbeddingFunction from milvus_model.dense).
embedding_fn = OpenAIEmbeddingFunction(
    model_name='text-embedding-3-large',
    api_key=openai.api_key
)
# Fix: file mode cannot be used here; connect to a running Milvus service.
client = MilvusClient(uri="http://localhost:19530")
# NOTE(review): database credentials are embedded in the URL — consider
# loading them from the environment.
DB_URL = "mysql+pymysql://root:root@localhost:3306/finance_enterprise_db"
engine = create_engine(DB_URL)
# ====================== 创建3个集合 ======================
def create_collections():
    """Recreate the three knowledge collections from scratch.

    Each collection gets an auto-generated INT64 primary key ``id``, a
    VARCHAR ``text`` field and a FLOAT_VECTOR ``vector`` field indexed with
    AUTOINDEX / COSINE. A collection that already exists is dropped first.
    """
    vector_dims = {
        "ddl_knowledge": 3072,
        "dbdesc_knowledge": 3072,
        "q2sql_knowledge": 3072,
    }
    for collection_name, vector_dim in vector_dims.items():
        if client.has_collection(collection_name):
            client.drop_collection(collection_name)

        collection_schema = client.create_schema()
        collection_schema.add_field(
            field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True
        )
        collection_schema.add_field(
            field_name="text", datatype=DataType.VARCHAR, max_length=8192
        )
        collection_schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=vector_dim
        )

        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="vector", index_type="AUTOINDEX", metric_type="COSINE"
        )
        client.create_collection(
            collection_name, schema=collection_schema, index_params=index_params
        )
# ====================== 获取所有表结构 ======================
def get_all_tables():
    """Return the first column of every ``SHOW TABLES`` row (the table names)."""
    show_tables = text("SHOW TABLES")
    with engine.connect() as connection:
        rows = connection.execute(show_tables).fetchall()
    return [row[0] for row in rows]
def get_ddl(table):
    """Return the DDL of ``table`` as reported by ``SHOW CREATE TABLE``.

    This replaces the earlier placeholder, which returned the literal text
    ``CREATE TABLE <name> (...)`` without any column definitions.

    Args:
        table: Name of a table in the connected database.

    Returns:
        The second column of the ``SHOW CREATE TABLE`` row (the statement).

    Raises:
        LookupError: If the server returns no row for ``table``.
    """
    # Identifiers cannot be sent as bound parameters, so quote the name with
    # backticks and double any backtick it contains.
    quoted = "`" + str(table).replace("`", "``") + "`"
    with engine.connect() as conn:
        row = conn.execute(text(f"SHOW CREATE TABLE {quoted}")).fetchone()
    if row is None:
        raise LookupError(f"未找到表 {table} 的建表语句")
    return row[1]
def get_columns(table):
    """Return the columns of ``table`` together with their comments.

    This replaces the earlier placeholder, which returned the same two fixed
    columns for every table.

    Args:
        table: Name of a table in the connected database.

    Returns:
        A list of ``{"col": column_name, "desc": column_comment}`` dicts in
        column order; ``desc`` is ``""`` when the comment is empty or NULL.
    """
    query = text(
        "SELECT COLUMN_NAME, COLUMN_COMMENT "
        "FROM information_schema.COLUMNS "
        "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table "
        "ORDER BY ORDINAL_POSITION"
    )
    with engine.connect() as conn:
        rows = conn.execute(query, {"table": table}).fetchall()
    return [{"col": name, "desc": comment or ""} for name, comment in rows]
def get_examples():
    """Return the hand-written ``(question, SQL)`` example pairs."""
    net_profit_example = (
        "查询2026年1-2月净利润",
        "SELECT * FROM fin_profit_statement WHERE report_year=2026 AND report_month IN (1,2)",
    )
    expense_total_example = (
        "查询各类型支出总额",
        "SELECT expense_type,SUM(expense_amount) FROM fin_expense_record GROUP BY expense_type",
    )
    return [net_profit_example, expense_total_example]
# ====================== 入库 ======================
def build_all():
    """Rebuild the three Milvus knowledge collections from the MySQL schema.

    Steps:
      1. Recreate the collections with ``create_collections``.
      2. Store one DDL text per table in ``ddl_knowledge``.
      3. Store one ``table.column:description`` text per column in
         ``dbdesc_knowledge``.
      4. Store one question/SQL text per example in ``q2sql_knowledge``.

    Every record is written as ``{"text": ..., "vector": ...}``, matching the
    schema defined in ``create_collections``.
    """
    create_collections()
    tables = get_all_tables()
    print("表:", tables)

    # NOTE(review): the ``text`` field is declared with max_length=8192 in
    # create_collections; confirm that the largest DDL fits.
    ddl_texts = [get_ddl(table) for table in tables]

    desc_texts = []
    for table in tables:
        for column in get_columns(table):
            desc_texts.append(f"{table}.{column['col']}:{column['desc']}")

    example_texts = [
        f"问题:{example_question}\nSQL:{example_sql}"
        for example_question, example_sql in get_examples()
    ]

    _insert_texts("ddl_knowledge", ddl_texts)
    _insert_texts("dbdesc_knowledge", desc_texts)
    _insert_texts("q2sql_knowledge", example_texts)
    print("Milvus 财务库构建完成")


def _insert_texts(collection, contents):
    """Embed ``contents`` and insert them into ``collection`` (no-op if empty)."""
    if not contents:
        return
    vectors = embedding_fn(contents)
    records = [
        {"text": content, "vector": [float(value) for value in vector]}
        for content, vector in zip(contents, vectors)
    ]
    client.insert(collection_name=collection, data=records)


if __name__ == "__main__":
    build_all()
文本转SQL
import os
import logging
import re
import openai
from dotenv import load_dotenv
from pymilvus import MilvusClient
from milvus_model.dense import OpenAIEmbeddingFunction
from sqlalchemy import create_engine, text
# ====================== 基础配置 ======================
# Log at INFO level with a "[LEVEL] message" format, then load variables
# from a .env file into the environment.
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
load_dotenv()
# ====================== OpenAI ======================
# The API key comes from the environment; the chat model can be overridden
# with OPENAI_MODEL and defaults to "gpt-4o-mini".
openai.api_key = os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
# ====================== Embedding function ======================
embedding_fn = OpenAIEmbeddingFunction(
    model_name='text-embedding-3-large',
    api_key=openai.api_key
)
# ====================== Connect to the Milvus service ======================
MILVUS_DB = "http://localhost:19530"
client = MilvusClient(MILVUS_DB)
# ====================== Finance database connection ======================
# NOTE(review): database credentials are embedded in the URL — consider
# loading them from the environment.
DB_URL = "mysql+pymysql://root:root@localhost:3306/finance_enterprise_db"
engine = create_engine(DB_URL)
# ====================== 向量检索工具 ======================
def retrieve(collection: str, query_emb: list, top_k=3, fields=None):
    """Run a vector search in ``collection`` for a single query embedding.

    Args:
        collection: Name of the Milvus collection to search.
        query_emb: Embedding of the query.
        top_k: Maximum number of hits to return.
        fields: Names of the stored fields to include with each hit.

    Returns:
        The hits for the query, or ``[]`` when the search returns nothing.
    """
    search_results = client.search(
        collection_name=collection,
        data=[query_emb],
        limit=top_k,
        output_fields=fields,
    )
    if not search_results:
        return []
    return search_results[0]
# ====================== SQL 提取 ======================
def extract_sql(text: str) -> str:
    """Pull the SQL statement out of an LLM reply.

    Candidates are tried in this order, and the first non-empty one wins:
      1. a fenced block tagged ``sql`` or ``mysql`` (any letter case);
      2. any other fenced block;
      3. the first ``SELECT ... ;`` span, upper-case first, then any case;
      4. the whole reply, stripped.

    Args:
        text: Raw model output. Inside this function the parameter name
            shadows the module-level ``text`` import, which is not needed
            here.

    Returns:
        The extracted SQL, or ``""`` for empty or None input.
    """
    if not text:
        return ""

    # Tagged fences are tried first so that a closing fence of an unrelated
    # block is never mistaken for the opening fence of the SQL block.
    fence_patterns = (
        (r"```(?:sql|mysql)[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE),
        (r"```[\w+-]*[ \t]*\r?\n(.*?)```", re.DOTALL),
    )
    for pattern, flags in fence_patterns:
        for block in re.findall(pattern, text, flags):
            candidate = block.strip()
            if candidate:
                return candidate

    # Upper-case SELECT is preferred so that a lower-case "select" in the
    # surrounding prose cannot win over the real statement.
    select_match = re.search(r"SELECT.*?;", text, re.DOTALL)
    if select_match is None:
        select_match = re.search(
            r"\bselect\b.*?;", text, re.DOTALL | re.IGNORECASE
        )
    if select_match:
        return select_match.group(0).strip()
    return text.strip()
# ====================== SQL 执行 ======================
def execute_sql(sql: str):
    """Execute ``sql`` on the finance database engine.

    Returns:
        ``(True, column_keys, rows)`` on success, or
        ``(False, None, error_message)`` when any exception is raised.
    """
    try:
        with engine.connect() as connection:
            outcome = connection.execute(text(sql))
            column_keys = outcome.keys()
            fetched_rows = outcome.fetchall()
    except Exception as exc:
        return False, None, str(exc)
    return True, column_keys, fetched_rows
# ====================== LLM 生成 SQL ======================
def generate_sql(prompt: str, error_msg=None):
    """Ask the chat model for SQL and return the extracted statement.

    Args:
        prompt: The full prompt (schema context plus user question).
        error_msg: Error text from the previous attempt. When given, it is
            appended to the prompt with a request to correct the SQL.

    Returns:
        The SQL extracted from the reply by ``extract_sql``. A reply without
        text content is passed to ``extract_sql`` as an empty string.
    """
    if error_msg:
        prompt += f"\n上一次执行报错:{error_msg},请修正SQL,只返回正确SQL语句"
    response = openai.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
    )
    # ``content`` is optional in a chat-completion message; fall back to ""
    # instead of failing with AttributeError on ``None.strip()``.
    raw = (response.choices[0].message.content or "").strip()
    sql = extract_sql(raw)
    logging.info("生成SQL: %s", sql)
    return sql
# ====================== 核心财务库 Text2SQL ======================
def text2sql_finance(question: str, max_retries=3):
    """Answer a finance question: retrieve context, generate SQL, execute it.

    Steps:
      1. Embed the question.
      2. Retrieve the nearest entries from the DDL, example and
         column-description collections.
      3. Build the prompt from the retrieved texts.
      4. Generate and execute SQL, feeding the failing statement and its
         error back to the model, up to ``max_retries`` attempts.

    The outcome is printed; nothing is returned.

    Args:
        question: The user's question in natural language.
        max_retries: Maximum number of generate-and-execute attempts.
    """
    # 1. Embed the question.
    q_emb = embedding_fn([question])[0]
    # 2. Search the three collections. Only ``text`` is requested: it is the
    #    one payload field their schema defines besides ``id`` and ``vector``.
    ddl_hits = retrieve("ddl_knowledge", q_emb, top_k=3, fields=["text"])
    q2sql_hits = retrieve("q2sql_knowledge", q_emb, top_k=3, fields=["text"])
    desc_hits = retrieve("dbdesc_knowledge", q_emb, top_k=5, fields=["text"])
    # 3. Join the retrieved texts into the context sections.
    ddl_text = "\n".join(_hit_text(hit) for hit in ddl_hits)
    example_text = "\n".join(_hit_text(hit) for hit in q2sql_hits)
    desc_text = "\n".join(_hit_text(hit) for hit in desc_hits)
    # 4. Build the prompt.
    prompt = f"""
你是财务SQL专家,请根据下面的表结构、字段说明、示例,生成可直接运行的MySQL语句,只返回SQL。
### 表结构:
{ddl_text}
### 字段含义:
{desc_text}
### 参考示例:
{example_text}
### 用户问题:
{question}
要求:
1. 只返回SQL,不要任何解释
2. 字段必须真实存在
3. 日期、分组、排序严格按财务逻辑
4. 不要编造字段
"""
    # 5. Generate, execute, retry.
    last_err = None
    feedback = None
    for i in range(max_retries):
        logging.info("第 %s 次生成", i + 1)
        sql = generate_sql(prompt, feedback)
        ok, cols, res = execute_sql(sql)
        if ok:
            print("\n执行成功")
            print("字段:", list(cols))
            for row in res:
                print(row)
            return
        last_err = res
        # Every attempt is a fresh single-message request, so the model
        # needs the failing statement as well as its error to fix it.
        feedback = f"{res}\n出错的SQL:{sql}"
        logging.error("失败:%s", last_err)
    print("超过最大重试次数")
    print("最后错误:", last_err)


def _hit_text(hit):
    """Return the ``text`` field of one search hit, or ``""`` if absent."""
    # MilvusClient.search nests the requested output fields under "entity";
    # fall back to the hit itself in case the fields are at the top level.
    entity = hit.get("entity") or hit
    return entity.get("text") or ""
# ====================== 测试 ======================
if __name__ == "__main__":
    # Read one question from standard input and run the pipeline on it.
    text2sql_finance(input("请输入财务查询问题:"))
总结
好了,文本转SQL的实现思路就分享到这儿,在座的亦菲、彦祖们有什么问题欢迎到评论区留言哦!