RAG-不写SQL也能查询MySQL数据

0 阅读8分钟

上一篇文章我们介绍了Milvus向量数据库,它是用来存储向量数据的,我们将各种非结构化的文档转为向量,存储在向量数据库。但是RAG系统不只有非结构化的数据,也有结构化的数据,比如说存储在MySQL或者其他关系型数据库中的数据。如何在RAG系统里面对这类数据进行检索查询呢,这就是本文要分享的内容:文本转SQL Text2SQL

文本转SQL

文本转SQL通俗的来说,就是用户输入自然语言,通过LLM大语言模型将自然语言结合表结构生成SQL的过程,实现思路如下:

流程图

Text2SQL2.png

代码示例
import os
from dotenv import load_dotenv
load_dotenv()
​
import pymysql
import json
​
​
# ====================== MYSQL的配置 ======================
MYSQL_CONFIG = {
    "host": "localhost",
    "port": 3306,
    "user": "root",
    "password": "root",
    "database": "finance_enterprise_db",
    "charset": "utf8mb4"
}
​
# 配置你的大模型API
MIMO_API_KEY = os.getenv("M_PROXY_AI_API_KEY")
MIMO_BASE_URL = os.getenv("M_PROXY_AI_BASE_URL")
MIMO_MODEL = "mimo-v2.5-pro"# ====================== LLM 接口 ======================
def llm_generate_sql(prompt):
    """调用大模型生成SQL"""
    from openai import OpenAI
​
    # (兼容OpenAI协议)
    client = OpenAI(
        api_key=MIMO_API_KEY,
        base_url=MIMO_BASE_URL
    )
​
    response = client.chat.completions.create(
        model=MIMO_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0  # 生成SQL必须=0
    )
​
    sql = response.choices[0].message.content.strip()
    # 清理格式
    sql = sql.replace("```sql", "").replace("```", "").strip()
    return sql
​
# ====================== 工具函数 ======================
def get_db_connection():
    """获取MySQL连接"""
    return pymysql.connect(**MYSQL_CONFIG, cursorclass=pymysql.cursors.DictCursor)
​
def execute_sql(sql):
    """执行SQL,返回结果 or 异常"""
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql)
            if sql.strip().upper().startswith("SELECT"):
                return cursor.fetchall(), None
            else:
                conn.commit()
                return "执行成功", None
    except Exception as e:
        return None, str(e)
    finally:
        conn.close()
​
def get_all_table_info():
    """获取数据库所有表名 + 字段结构 + 注释 + 示例SQL"""
    # 【稳定写法】直接用 INFORMATION_SCHEMA 查表名,100% 是字典
    sql = f"""
        SELECT TABLE_NAME 
        FROM information_schema.TABLES 
        WHERE TABLE_SCHEMA = '{MYSQL_CONFIG["database"]}'
    """
    tables, _ = execute_sql(sql)
    table_info_list = []
​
    for row in tables:
        # 这里取到表名
        table_name = row["TABLE_NAME"]
​
        # 表注释
        comment_sql = f"""
            SELECT TABLE_COMMENT 
            FROM information_schema.TABLES 
            WHERE TABLE_SCHEMA = '{MYSQL_CONFIG["database"]}' 
            AND TABLE_NAME = '{table_name}'
        """
        table_comment, _ = execute_sql(comment_sql)
        table_comment = table_comment[0]["TABLE_COMMENT"] if table_comment else ""
​
        # 字段信息
        columns, _ = execute_sql(f"DESC {table_name}")
​
        # 示例SQL
        example_sql = get_example_sql_by_table(table_name)
​
        # 拼接表结构描述
        table_info = f"""
表名:{table_name}
表注释:{table_comment}
字段:{json.dumps(columns, ensure_ascii=False, indent=2)}
示例SQL:
{example_sql}
        """
        table_info_list.append(table_info)
​
    return "\n=====================\n".join(table_info_list)
​
def get_example_sql_by_table(table_name):
    """根据表名返回业务示例SQL"""
    sql_map = {
        "fin_account_subject": "SELECT * FROM fin_account_subject LIMIT 5;",
        "fin_company_contact": "SELECT contact_name,contact_credit FROM fin_company_contact WHERE contact_type=1;",
        "fin_balance_sheet": "SELECT * FROM fin_balance_sheet WHERE report_year=2026 AND report_month=2;",
        "fin_profit_statement": "SELECT report_year,report_month,main_income,net_profit FROM fin_profit_statement ORDER BY report_year,report_month;",
        "fin_cash_flow": "SELECT * FROM fin_cash_flow WHERE report_year=2026 AND report_month=2;",
        "fin_expense_record": "SELECT expense_type,SUM(expense_amount) AS total FROM fin_expense_record GROUP BY expense_type;"
    }
    return sql_map.get(table_name, "SELECT * FROM " + table_name + " LIMIT 3;")
​
# ====================== 核心流程:提问 → 生成SQL → 执行 → 重试 ======================
def text_to_sql_with_retry(question, max_retry=3):
    """
    流程:
    1. 获取表结构
    2. 拼接prompt给LLM
    3. 生成SQL → 执行
    4. 失败重试(最多3次)
    """
    print(f"【用户问题】:{question}\n")
​
    # 1. 获取所有表结构+描述+示例SQL
    table_info = get_all_table_info()
​
    # 2. 构造Prompt
    prompt = f"""
你是专业SQL生成专家,请根据下面的数据库表结构、表注释、示例SQL,
严格按照用户问题生成【可直接运行】的MySQL语句,只返回SQL,不要任何解释。
​
表结构信息:
{table_info}
​
用户问题:{question}
​
要求:
1. 只返回SQL,不要```、不要文字
2. 必须使用提供的表和字段
3. 日期、金额、分组、排序严格按业务逻辑
4. 字段名不要编造
"""
​
    # 3. 重试机制
    for retry in range(max_retry):
        print(f"=== 第 {retry+1} 次生成SQL ===")
        try:
            # 小米 MiMo 生成SQL
            sql = llm_generate_sql(prompt)
            print(f"【生成SQL】:\n{sql}\n")
​
            # 执行SQL
            result, error = execute_sql(sql)
​
            if not error:
                print("【执行成功】")
                return {
                    "question": question,
                    "sql": sql,
                    "result": result,
                    "error": None
                }
            else:
                print(f"【执行失败】:{error}")
                # 把错误追加到prompt,让LLM修正
                prompt += f"\n\n上一次生成的SQL执行报错:{error},请修正SQL,只返回正确SQL"
​
        except Exception as e:
            print(f"【重试异常】{str(e)}")
​
    return {"error": "重试次数耗尽,生成SQL失败"}
​
# ====================== 测试 ======================
if __name__ == "__main__":
    # 你可以随便换问题
    question = "查询2026年1月和2月的主营业务收入、净利润,按月份排序"
​
    # 执行
    result = text_to_sql_with_retry(question, max_retry=3)
​
    # 打印最终结果
    print("\n" + "=" * 50)
    print("最终结果:")
    print(json.dumps(result, ensure_ascii=False, indent=2))
优化版本

我们通过LLM理解数据库的表结构来生成SQL语句,拿生成的SQL查询数据并返回结果。这逻辑本身是没问题的,但实际生产环境我们的数据库表很多,不可能一下全部查询出来放到prompt中,这样做即不合理,生成SQL也不够精准。所以我们可以引入上文所说的Milvus向量数据库,我们把数据库的表结构、字段注释、实例SQL向量化放入Milvus向量数据库。将用户的问题向量化,用Milvus向量检索出我们需要的表信息和相关实例SQL

流程图

text2SQL3.png

代码示例

初始化

import os
import openai
from dotenv import load_dotenv
load_dotenv()
​
​
from pymilvus import MilvusClient, DataType
from milvus_model.dense import OpenAIEmbeddingFunction
from sqlalchemy import create_engine, text
​
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
​
# 修复:使用新包
embedding_fn = OpenAIEmbeddingFunction(
    model_name='text-embedding-3-large',
    api_key=openai.api_key
)
​
# 修复:不能用文件模式,必须连接 Milvus 服务
client = MilvusClient(uri="http://localhost:19530")
​
DB_URL = "mysql+pymysql://root:root@localhost:3306/finance_enterprise_db"
engine = create_engine(DB_URL)
​
# ====================== 创建3个集合 ======================
def create_collections():
    for name, dim in [("ddl_knowledge", 3072), ("dbdesc_knowledge", 3072), ("q2sql_knowledge", 3072)]:
        if client.has_collection(name):
            client.drop_collection(name)
        schema = client.create_schema()
        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
        schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=8192)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
        index = client.prepare_index_params()
        index.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
        client.create_collection(name, schema=schema, index_params=index)
​
# ====================== 获取所有表结构 ======================
def get_all_tables():
    with engine.connect() as conn:
        tables = conn.execute(text("SHOW TABLES")).fetchall()
        return [t[0] for t in tables]
​
def get_ddl(table):
    return f"CREATE TABLE {table} (...)"def get_columns(table):
    return [{"col": "id", "desc": "主键"}, {"col": "name", "desc": "名称"}]
​
def get_examples():
    return [
        ("查询2026年1-2月净利润", "SELECT * FROM fin_profit_statement WHERE report_year=2026 AND report_month IN (1,2)"),
        ("查询各类型支出总额", "SELECT expense_type,SUM(expense_amount) FROM fin_expense_record GROUP BY expense_type")
    ]
​
# ====================== 入库 ======================
def build_all():
    create_collections()
    tables = get_all_tables()
    print("表:", tables)
    print("Milvus 财务库构建完成")
​
if __name__ == "__main__":
    build_all()

文本转SQL

import os
import logging
import re
import openai
from dotenv import load_dotenv
from pymilvus import MilvusClient
from milvus_model.dense import OpenAIEmbeddingFunction
from sqlalchemy import create_engine, text
​
# ====================== 基础配置 ======================
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
load_dotenv()
​
# ====================== OpenAI ======================
openai.api_key = os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
​
# ====================== 嵌入函数 ======================
embedding_fn = OpenAIEmbeddingFunction(
    model_name='text-embedding-3-large',
    api_key=openai.api_key
)
​
# ====================== 连接 Milvus 服务 ======================
MILVUS_DB = "http://localhost:19530"
client = MilvusClient(MILVUS_DB)
​
# ====================== 财务数据库连接 ======================
DB_URL = "mysql+pymysql://root:root@localhost:3306/finance_enterprise_db"
engine = create_engine(DB_URL)
​
# ====================== 向量检索工具 ======================
def retrieve(collection: str, query_emb: list, top_k=3, fields=None):
    results = client.search(
        collection_name=collection,
        data=[query_emb],
        limit=top_k,
        output_fields=fields
    )
    return results[0] if results else []
​
# ====================== SQL 提取 ======================
def extract_sql(text: str) -> str:
    sql_blocks = re.findall(r'```sql\n(.*?)\n```', text, re.DOTALL)
    if sql_blocks:
        return sql_blocks[0].strip()
    select_match = re.search(r'SELECT.*?;', text, re.DOTALL)
    if select_match:
        return select_match.group(0).strip()
    return text.strip()
​
# ====================== SQL 执行 ======================
def execute_sql(sql: str):
    try:
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            cols = result.keys()
            rows = result.fetchall()
            return True, cols, rows
    except Exception as e:
        return False, None, str(e)
​
# ====================== LLM 生成 SQL ======================
def generate_sql(prompt: str, error_msg=None):
    if error_msg:
        prompt += f"\n上一次执行报错:{error_msg},请修正SQL,只返回正确SQL语句"
​
    response = openai.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}]
    )
    raw = response.choices[0].message.content.strip()
    sql = extract_sql(raw)
    logging.info(f"生成SQL: {sql}")
    return sql
​
# ====================== 核心财务库 Text2SQL ======================
def text2sql_finance(question: str, max_retries=3):
    # 1. 问题向量化
    q_emb = embedding_fn([question])[0]
​
    # 2. 三大向量检索
    ddl_hits = retrieve("ddl_knowledge", q_emb, top_k=3, fields=["text"])
    q2sql_hits = retrieve("q2sql_knowledge", q_emb, top_k=3, fields=["question", "sql_text"])
    desc_hits = retrieve("dbdesc_knowledge", q_emb, top_k=5, fields=["table_name", "column_name", "description"])
​
    # 3. 拼接上下文
    ddl_text = "\n".join([h.get("text", "") for h in ddl_hits])
    example_text = "\n".join([f"问题:{h['question']}\nSQL:{h['sql_text']}" for h in q2sql_hits])
    desc_text = "\n".join([f"{h['table_name']}.{h['column_name']}{h['description']}" for h in desc_hits])
​
    # 4. 构造Prompt
    prompt = f"""
你是财务SQL专家,请根据下面的表结构、字段说明、示例,生成可直接运行的MySQL语句,只返回SQL。
​
### 表结构:
{ddl_text}
​
### 字段含义:
{desc_text}
​
### 参考示例:
{example_text}
​
### 用户问题:
{question}
​
要求:
1. 只返回SQL,不要任何解释
2. 字段必须真实存在
3. 日期、分组、排序严格按财务逻辑
4. 不要编造字段
"""
​
    # 5. 重试执行
    last_err = None
    for i in range(max_retries):
        logging.info(f"第 {i+1} 次生成")
        sql = generate_sql(prompt, last_err)
        ok, cols, res = execute_sql(sql)
        if ok:
            print("\n执行成功")
            print("字段:", cols)
            for row in res:
                print(row)
            return
        last_err = res
        logging.error(f"失败:{last_err}")
​
    print("超过最大重试次数")
    print("最后错误:", last_err)
​
# ====================== 测试 ======================
if __name__ == "__main__":
    q = input("请输入财务查询问题:")
    text2sql_finance(q)
总结

好了,文本转SQL的实现思路就分享到这儿,在座的亦菲、彦祖们有什么问题欢迎到评论区留言哦!