Spring AI + Super-SQL 实现NL2SQL

1,486 阅读7分钟

序言

最近公司项目当中需要整合SpringAI来实现NL2SQL弟弟功能,刚好 Gitee 上刚开源一个框架实现了这个功能,由于框架刚刚开源我在尝试的过程中碰到过很多问题,所以写这篇记录一下,帮助大家的同时也为以后做参考

Super-SQL

Super-Sql 是一个基于国内外先进生成式大模型的Java框架,专注于将数据库表结构通过检索增强生成(RAG, Retrieval-Augmented Generation)技术进行训练,从而实现从自然语言文本到SQL查询的智能转换(Text to SQL)。该框架旨在简化复杂的数据库查询过程,使开发者和用户能够通过简单的自然语言描述获取所需数据。

主要特性包括:

  • 生成式SQL利用强大的生成式大模型,自动将自然语言问题转化为精确的SQL查询语句。
  • RAG训练通过检索增强生成技术对数据库表结构进行深度学习训练,提高SQL生成的准确性和效率。
  • 类型安全与灵活易用结合Java的泛型机制确保编译期类型检查,同时提供简洁直观的API设计,易于集成到现有项目中。
  • 多数据库支持兼容多种主流数据库系统,满足不同应用场景的需求。
  • 性能优化经过精心设计与调优,在保证高效执行的同时保持良好的可读性。

Super-Sql 适用于希望在Java应用程序中快速、安全地进行复杂数据库操作,并且希望通过自然语言处理技术为传统企业应用快速实现AI赋能的场景。

工作原理

Super-Sql 的工作原理基于 RAG 技术,通过检索增强生成技术对数据库表结构进行深度学习训练,从而实现从自然语言文本到SQL查询的智能转换。具体来说,当用户提供一段自然语言描述时,框架会首先解析这段描述,并根据预训练的模型生成相应的SQL查询语句。这些查询语句可以进一步用于实际的数据库操作。

image.png

官网地址

gitee.com/guocjsh/sup…

快速开始

导入Super-Sql的工程

git clone https://gitee.com/guocjsh/supersql-open.git

配置Super-Sql的配置文件

配置init-train配置项,默认为false,表示不进行训练,如果为true,则自动根据数据库连接配置进行全表训练。

super-sql:
  init-train: false

大语言模型配置

这里选择阿里通义模型,在 super-sql 模块的 pom 文件中加入

<repositories>
    <repository>
        <id>spring-milestones</id>
        <name>Spring Milestones</name>
        <url>https://repo.spring.io/milestone</url>
        <snapshots>
            <enabled>false</enabled>
        </snapshots>
    </repository>
</repositories>

然后将 spring-ai.version 版本修改为 1.0.0-M3。这个地方一定要修改,使用 M5 版本会报错

<properties>
    <java.version>17</java.version>
    <revision>1.0.0-M1-SNAPSHOT</revision>
    <spring-boot.version>3.3.0</spring-boot.version>
    <hutool.version>5.8.35</hutool.version>
    <fastjson.version>2.0.31</fastjson.version>
    <lombok.version>1.18.32</lombok.version>
    <mybatis.plus-version>3.5.8</mybatis.plus-version>
    <mysql.version>8.0.32</mysql.version>
    <!-- Spring AI -->
    <spring-ai.version>1.0.0-M3</spring-ai.version>
</properties>

在 super-sql-console 模块中加入阿里巴巴依赖

<dependency>
    <groupId>com.alibaba.cloud.ai</groupId>
    <artifactId>spring-ai-alibaba-starter</artifactId>
    <version>1.0.0-M3.1</version>
</dependency>

阿里模型配置如下,建议模型不要更改,其他模型可以能出现请求不成功的情况

spring:
    ai:
      dashscope:
        api-key: #{key 替换为自己的阿里云大模型key}
        chat:
          options:
            model: qwen-plus

请求调用Text To SQL示例:


private final DashScopeChatModel dashScopeChatModel;

private final SpringSqlEngine sqlEngine;

private final SpringVectorStore store;

@GetMapping("getSuperSql")
public Object getSuperSql(@RequestParam String question) {
    String sql = sqlEngine.setChatModel(dashScopeChatModel).generateSql(question);
    return sql;
}

向量数据库配置

Chroma

安装Chroma环境

我本地 Python 版本为 3.13.2。执行下面命令安装Chroma

pip install chromadb=0.5.23
chroma run

这里建议下载 0.5.23 的版本,下载其他的版本连接时会报错。 出现以下界面表示启动成功

image.png

<!--spring ai chroma 的向量数据库-->
<dependency>    
    <groupId>org.springframework.ai</groupId>    
    <artifactId>spring-ai-chroma-store-spring-boot-starter</artifactId>
</dependency>

安装完成后,启动项目,这个地方大家在启动的时候可能会报错提示token过长,这里我初步是将超出提示词限制的文档进行切割,分步添加到向量数据库中,修改后就可以正常启动了。用到的类如下:

public class DataSplitter {

    public static List<String> splitData(String data, int maxLength) {
        List<String> chunks = new ArrayList<>();
        int start = 0;
        while (start < data.length()) {
            int end = Math.min(start + maxLength, data.length());
            chunks.add(data.substring(start, end));
            start = end;
        }
        return chunks;
    }
}

image.png 启动成功后,测试前面的 TXT2SQL 的功能,这里假设有以下表:

-- 创建 hospitals 表并添加表注释
CREATE TABLE hospitals (
    -- 医院 ID,主键,自增,注释:医院的唯一标识
    hospital_id INT AUTO_INCREMENT PRIMARY KEY COMMENT '医院的唯一标识',
    -- 医院名称,最长 255 个字符,非空,注释:医院的全称
    hospital_name VARCHAR(255) NOT NULL COMMENT '医院的全称',
    -- 医院地址,最长 255 个字符,注释:医院所在的详细地址
    address VARCHAR(255) COMMENT '医院所在的详细地址',
    -- 联系电话,最长 20 个字符,注释:医院的对外联系电话号码
    phone_number VARCHAR(20) COMMENT '医院的对外联系电话号码',
    -- 医院评级,例如一级甲等、二级乙等,最长 20 个字符,注释:医院的等级评定信息
    rating VARCHAR(20) COMMENT '医院的等级评定信息',
    -- 医院成立年份,注释:医院成立的具体年份
    established_year YEAR COMMENT '医院成立的具体年份',
    -- 医院描述信息,文本类型,注释:关于医院的详细描述内容
    description TEXT COMMENT '医院的详细描述内容'
) COMMENT = '医院信息表';

请求结果返回示例:

image.png 可以看到这里的表名和医院名称都和数据库表是对应的,但是医院地址的字段名称是不对的,我猜测是由于上面将文档切割后不完整导致的。但是为了解决提示词过长的问题,不得不对文档进行切割。所以为了应对这种情况,可以采用单独训练的方式来保证生成 SQL 的精确性。

训练指定内容

强化训练数据库的DDL语句

@GetMapping("trainDdl")
public String trainDDl() {
    String ddl = """
                CREATE TABLE hospitals (
                hospital_id INT AUTO_INCREMENT PRIMARY KEY COMMENT '医院的唯一标识',
                hospital_name VARCHAR(255) NOT NULL COMMENT '医院的全称',
                address VARCHAR(255) COMMENT '医院所在的详细地址',
                phone_number VARCHAR(20) COMMENT '医院的对外联系电话号码',
                rating VARCHAR(20) COMMENT '医院的等级评定信息',
                established_year YEAR COMMENT '医院成立的具体年份',
                description TEXT COMMENT '医院的详细描述内容'
                ) COMMENT = '医院信息表';
            """;
    sqlEngine.setChatModel(dashScopeChatModel).train(TrainBuilder.builder().content(ddl).policy(TrainPolicyType.DDL).build());
    return "successful training";
}

单独训练指定SQL

@GetMapping("trainSql")
public String trainSql() {
    String sql = "SELECT * FROM hospitals WHERE address LIKE '%黄浦区%'";
    String question = "在黄浦区的医院有哪些?";
    sqlEngine.setChatModel(dashScopeChatModel).train(TrainBuilder.builder().content(sql).question(question).policy(TrainPolicyType.SQL).build());
    return "successful training";
}

训练完成后,再次调用上文的 Text To SQL 接口,效果基本已经和数据库高度吻合。

image.png

业务流程

现在来说一下为什么经过训练后 AI 所给出的答案就已经非常准确了。既然是测试 AI 生产 SQL 的准确性,那就看一下生成 SQL 的具体业务逻辑。

@Override
public String generateSql(String question) {
    if (question == null || question.trim().isEmpty()) {
        throw new IllegalArgumentException("Question cannot be null or empty");
    }
    // 生成 SQL
    FilterExpressionBuilder expression = new FilterExpressionBuilder();
    List<Document> questionSqlList = this.searchVectorByTag(question, TrainPolicyType.SQL);
    List<Document> ddlList = this.searchVectorByTag(question, TrainPolicyType.DDL);
    List<Document> documentList = this.searchVectorByTag(question, TrainPolicyType.DOCUMENTATION);
    SqlpromptBuilder sqlprompt = SqlpromptBuilder.builder().question(question).questionSqlList(questionSqlList).ddlList(ddlList).documentList(documentList).build();
    Prompt prompt = SqlAssistantPrompt.getSqlPrompt(sqlprompt);
    log.info("Generating SQL Prompt for first:\n {}", prompt.getContents());
    // 调用 LLM 生成 SQL
    String llmResponse = ChatClientFactory.buildChatClient(this.chatModel).prompt(prompt).call().content();
    log.info("Generating SQL From LLM {}", llmResponse);

    // 验证 SQL 并返回
    if (llmResponse.contains("intermediate_sql")) {
        String intermediateSql = SqlExtractorUtils.extractSql(llmResponse);
        List<Map<String, Object>> executed = executeSql(intermediateSql);
        sqlprompt.getDocumentList().add(
                new Document(String.format("""
                    The following is a pandas DataFrame with the results of the intermediate SQL query %s:\n%s
                    """,intermediateSql, executed.toString()
                )));
        prompt = SqlAssistantPrompt.getSqlPrompt(sqlprompt);
        llmResponse = ChatClientFactory.buildChatClient(this.chatModel).prompt(prompt).call().content();
    }
    String sql = SqlExtractorUtils.extractSql(llmResponse);
    return validSql(sql) ? sql : null;
}

看这三行

List<Document> questionSqlList = this.searchVectorByTag(question, TrainPolicyType.SQL);
List<Document> ddlList = this.searchVectorByTag(question, TrainPolicyType.DDL);
List<Document> documentList = this.searchVectorByTag(question, TrainPolicyType.DOCUMENTATION);

这三行就是从向量数据库中查询已经训练过的数据,然后将训练过的数据构建成SqlpromptBuilder对象,最后根据SqlpromptBuilder对象创建Prompt对象传递给 LLM 大模型生成 SQL。

疑问

为了保证 SQL 的准确性而去检索向量数据库,当向量数据库中的数据庞大时,所生成的Prompt提示词对象大小一定会超出 AI 模型的限制,出现这种情况应该如何解决?

总结

如果只是单纯实现文本2SQL的功能其实不难,目前市面上的 AI 都可以做到这一步,但是在公司里开发这是远远不够的,在项目中集成 AI 是为了让 AI 和项目耦合,以公司内部数据为准来回答问题,这些就需要数据铺垫来保证生成 SQL 的准确性。