基于 llm 的自然语言转 sql 实现

285 阅读6分钟

1. 前言

在数据驱动决策的当下,如何让非技术人员轻松从数据库获取信息成了关键。自然语言转 SQL 技术应运而生,借助大语言模型(LLM)的强大语义理解与生成能力,可将自然语言问题转化为可执行的 SQL 查询。

LangChain 作为一款强大的开源框架,为开发者提供了丰富工具和接口,能高效地将 LLM 集成到应用中,简化开发流程。

而检索增强生成(RAG)则是应对数据安全和知识更新问题的有效策略。它通过在生成文本过程中引入外部知识检索,既避免了敏感数据的直接暴露,又能让生成的 SQL 更贴合实际需求和最新数据。

本文将详细探讨如何结合 LangChain 与 RAG 技术,实现基于 LLM 的自然语言转 SQL 功能,解决实际应用中的数据安全与知识精准性难题。

2. 实现

本文使用火山引擎上提供的豆包系列模型实现

  1. 导入相关包,并设置环境变量
import os

from langchain.vectorstores import FAISS
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter

from model.embedding.doubao_embedding import DoubaoEmbeddings
from util.constant import MODEL_BASE_URL

# embedding 模型使用
os.environ["ARK_API_KEY"] = "****"
# chat llm 模型使用
os.environ["OPENAI_API_KEY"] = "****"
# 火山引擎上的模型 base_url
os.environ["OPENAI_API_BASE"] = "https://ark.cn-beijing.volces.com/api/v3"

# 模型: doubao-1.5-pro-32k
MODEL_NAME = "****"
  1. 设置 llm cache,减少消耗的 token
# 设置模型的缓存, 减少使用的 token, 并且提高响应速度
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache

set_llm_cache(SQLiteCache(database_path="./data/.model_cache.db"))
  1. 读取并处理数据

我们需要所有的表结构信息进行检索和作为上下文发送给大模型,因此需要把所有的表信息导出到文件(后续生产可以调用 api 获取表信息)
使用表的 DDL 来记录表的结构信息,因为他不仅包含表名、列字段,还有索引相关的信息,可以辅助更好的生成 SQL。
DDL 结构如下,多个表使用`|||`来分隔

create table xxx.xxx
(
    id               int auto_increment comment 'ID' primary key,
    xxx              xxx                           not null comment 'xxx',
    yyy              yyy                           not null comment 'abc',
    constraint uniq_idx_xxx unique (xxx),
    constraint uniq_idx_xxx2 unique (xxx2)
)
    comment 'xxxxxx' charset = utf8; |||
create table xxx.xxx2
(
    id               int auto_increment comment 'ID' primary key,
    xxx              xxx                           not null comment 'xxx',
    yyy              yyy                           not null comment 'abc'
)
    comment 'zzzzzz' charset = utf8;

create index idx_abc
    on segment (abc); |||
......

from pydantic import BaseModel
from langchain_core.embeddings import Embeddings

from util.constant import MODEL_BASE_URL

from volcenginesdkarkruntime import Ark

# embedding model
EMBEDDING_MODEL_NAME = "****"


# 由于 langchain 没有对应的 Embedding 模型实现, 这里自己实现了一个
class DoubaoEmbeddings(BaseModel, Embeddings):
    chunk_size: int = 1000

    def embed_documents(self, texts: list[str], chunk_size: int | None = None) -> list[list[float]]:
        client = Ark(base_url=MODEL_BASE_URL)
        text_in_chunks = [
            texts[i: i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        embeddings = []
        for chunk in text_in_chunks:
            resp = client.embeddings.create(
                model=EMBEDDING_MODEL_NAME,
                input=chunk
            )
            embeddings.extend([res.embedding for res in resp.data])

        return embeddings

    def embed_query(self, text: str) -> list[float]:
        return self.embed_documents([text])[0]
import re


def extract_table_name_and_description(ddl: str) -> (str, str):
    # 抽取表名
    table_name_match = re.search(r'create table (\S+)', ddl)
    table_name = table_name_match.group(1) if table_name_match else None

    # 抽取表描述
    table_desc_match = re.search(r'comment \'([^\']+)\';', ddl)
    table_desc_match2 = re.search(r'comment \'([^\']+)\' charset', ddl)
    table_desc = table_desc_match.group(1) if table_desc_match else table_desc_match2.group(
        1) if table_desc_match2 else None
    return table_name, table_desc


# 处理数据
with open("./data/schema.txt", "r") as f:
    schemas = f.read()
schema_arr = [s.strip().rstrip("\n").lstrip("\n") for s in schemas.split("|||")]

# 表名到 ddl 的映射
tb_ddl = {}
tb_desc = []
for schema in schema_arr:
    tbn, tbd = extract_table_name_and_description(schema)
    tb_ddl[tbn] = schema
    tb_desc.append(f"表名: {tbn}, 描述: {tbd}")

# 写入文件保存
with open("./data/table_desc.txt", "w") as f:
    f.writelines("\n".join(tb_desc))
  1. 解析数据,向量化存储

把表名和表描述信息做向量化处理,用于后续从用户问题中检索相关的表。

# 加载数据 切分
loader = TextLoader("./data/table_desc.txt")
tb_desc = loader.load()
text_spliter = CharacterTextSplitter(
    chunk_size=50,
    chunk_overlap=0,
    separator="\n",
    length_function=len
)
split_td = text_spliter.split_documents(tb_desc)
# embedding 保存向量数据库
embedding_save_path = "./data/table_desc_embedding"
db = FAISS.from_documents(split_td, DoubaoEmbeddings()) # 后续可以注释
# embedding 后的数据保存在本地
db.save_local(embedding_save_path) # 后续可以注释
# 加载保存在本地的 embedding 数据
db = FAISS.load_local(embedding_save_path, DoubaoEmbeddings(), allow_dangerous_deserialization=True)

retriever = db.as_retriever(
    search_kwargs={'k': 5}
)

测试结果:

image.png

  1. 使用 langchain 定义 react agent

首先定义 agent 要用到的 tools,然后实例化 react_agent

from langchain_core.tools import Tool


def obtain_table_ddl(table_names: str) -> str:
    print("table_names:", table_names)
    tb_names = [s.strip() for s in table_names.split(",")]
    return "\n".join([tb_ddl[tn] if tn in tb_ddl else "" for tn in tb_names])

def obtain_similar_table_information(query: str) -> str:
    print("query:", query)
    res = retriever.invoke(query)
    return "\n".join([d.page_content for d in res])


# 定义交互的 tool
tools = [
    Tool(
        name="Obtain similar table information",
        func=obtain_similar_table_information,
        description="根据表名,从保存了表名和描述信息的向量数据库中查询相似信息",
    ),
    Tool(
        name="Obtain table DDL",
        func=obtain_table_ddl,
        description=(
            "根据表名,获取表的 DDL。\n"
            "入参:表名字符串,以','分隔,例如 tag,user;\n"
            "出参格式:DDL 语句字符串"
        ),
    ),
]
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.prebuilt import create_react_agent
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model_name=MODEL_NAME,
    temperature=0,  # 确定性更高
)

text_to_sql_system_prompt = """你是一个 tidb 专家。给定一个输入问题,创建一个语法正确的 TiDB SQL 来运行,并仅返回生成的查询,不返回其他内容。在创建的查询前后不要包含```sql ```。除非另有说明,否则返回的行数不得超过 10 行。

让我们一步一步分析思考:
第一步:分析用户提出的问题,找到需要查询的表,作为[查询表]
第二步:找到[查询表]对应的 DDL 信息,作为[查询表 DDL]
第三步:分析[查询表 DDL]中跟问题关联的信息,作为[context]
第四步:根据问题结合[context],做出最后的回答

在生成 SQL 的时候注意以下原则:
1. 如果信息有缺失,直接告诉用户,不要生成 SQL。
2. 仅使用[context],来创建正确的 SQL 查询,并密切关注哪个列在哪个表中。
3. 在生成 SQL 的时候关注[查询表 DDL]中的索引信息,尽量使 SQL 命中索引,提高 SQL 性能。
4. 确保不要查询表中不存在的列,仅在需要时使用别名。
5. 考虑为表和列使用别名以提高查询的可读性,特别是在复杂连接或子查询的情况下。
6. 如果有必要,使用子查询或公用表表达式(CTE)将问题分解为更小、更易于管理的部分。"""

text_to_sql_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", text_to_sql_system_prompt),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

graph = create_react_agent(llm, tools=tools, prompt=text_to_sql_prompt)
graph
  1. 问题测试
inputs = {"messages": [("user", "关联了****的****有多少个?")]}

for s in graph.stream(inputs, stream_mode="values"):
    message = s["messages"][-1]
    if isinstance(message, tuple):
        print(message)
    else:
        message.pretty_print()

image.png

image.png

3. 问题

  1. 没有记忆能力

没有把历史对话保存,因此只能针对当前这一次对话生成 SQL,后续可以使用 langchain 的 Memory 模块,增加记忆能力,提升用户体验。

  1. 有数据安全问题

现在仍然需要把所有的表名和描述信息发送到 Embedding 模型进行向量化处理,有一定的数据安全问题,后续可以考虑使用本地化部署 Embedding 模型,进行向量化处理后存储起来。