chromadb + Ollam - 穷人的自然语言写SQL方案

1,769 阅读10分钟

项目概述

核心功能是构建一个基于 AI 的数据库查询助手,它能够理解自然语言的用户查询,并自动生成相应的 SQL 语句进行查询。这里使用chromadb + Ollama的穷人方案实现,安装的软件少,机器配置要求低,完全免费。

chromadb + Ollama可以参考我的这篇文章:chromadb + Ollama 快速实现RAG应用

具体来说,它做了以下几件事:

  1. 提取数据库结构定义 (DDL): 从目标 MySQL 数据库中提取所有表的 DDL 语句,作为数据库结构的知识库。
  2. 利用向量数据库 (ChromaDB) 存储 DDL 信息: 使用 Ollama 模型将 DDL 语句转换为嵌入向量,并存储到 ChromaDB 中,以便进行语义搜索。
  3. 接收用户自然语言查询: 接收用户以自然语言形式提出的数据库查询请求。
  4. 语义搜索相关 DDL: 使用 Ollama 模型将用户查询转换为嵌入向量,并在 ChromaDB 中搜索语义最相关的 DDL 语句。
  5. 利用大语言模型 (Ollama) 生成 SQL: 将用户查询和相关的 DDL 语句作为上下文信息,发送给 Ollama 模型,请求其生成能够满足用户查询需求的 SQL 语句。
  6. 执行 SQL 并展示结果: 执行生成的 SQL 语句,并将查询结果以表格形式展示给用户。

总而言之,这段代码结合了向量数据库和大型语言模型的优势,实现了基于自然语言的数据库查询功能。它能够降低用户使用数据库的门槛,提高数据查询的效率,为数据分析和决策提供便捷的工具。

sequenceDiagram
    participant 用户
    participant 应用程序
    participant Ollama Embeddings
    participant ChromaDB
    participant Ollama Coder

    用户->>应用程序: 输入自然语言查询
    activate 应用程序
    应用程序->>Ollama Embeddings: 请求查询嵌入向量
    activate Ollama Embeddings
    Ollama Embeddings-->>应用程序: 返回查询嵌入向量
    deactivate Ollama Embeddings
    应用程序->>ChromaDB: 查询相关 DDL
    activate ChromaDB
    ChromaDB-->>应用程序: 返回相关 DDL
    deactivate ChromaDB
    应用程序->>Ollama Coder: 请求生成 SQL (查询 + DDL)
    activate Ollama Coder
    Ollama Coder-->>应用程序: 返回 SQL 语句
    deactivate Ollama Coder
    应用程序->>MySQL 数据库: 执行 SQL 查询
    activate MySQL 数据库
    MySQL 数据库-->>应用程序: 返回查询结果
    deactivate MySQL 数据库
    应用程序->>用户: 展示查询结果 (表格)
    deactivate 应用程序

技术选型

  • Embedding Models:Ollama + bge-large-zh:v1.5
  • Vector Databases:chromadb
  • Generation Models: Ollama + deepseek-coder-v2

这样的选择主要是因为安装的软件较少,对机器配置要求低,完全免费。只需安装chromadb + Ollama即可。我机器是4G的3050也能勉强运行。

代码实现

安装依赖

pip install mysql-connector-python
pip install tabulate
pip install ollama chromadb
  • mysql-connector-python用于

    • 获取表的DDL
    • 执行SQL
  • tabulate用于把SQL结果显示

  • ollama用于运行本地模型(Embedding Models和Generation Models)

  • chromadb用于Vector Databases

数据库定义

这里是一个大家最讨厌的广告展示点击系统

本次测试的题目是 “请问如何统计各个大区的点击数,请给出SQL语句” ,根据下面的DDL可以看出,实际只和provinceclick表相关

  • click中有省份的信息
  • province中有省份和大区的关系

比如province

province_nameregion_name
上海华东地区
北京华北地区
…………

比如click

click_idprovinceclick_price其它列
129070334浙江1.18……
129070335北京1.53
……………………

具体定义如下,还有一些乱七八糟的表略过……

当然数据库的DDL的注释质量越高,我们得到的效果就越好

CREATE TABLE `ad_provider` (
  `id` int(20) NOT NULL COMMENT 'ID索引',
  `provider_name` varchar(255) NOT NULL COMMENT '广告提供商名称',
  `call_url` varchar(255) NOT NULL COMMENT '广告API地址',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='广告提供商'CREATE TABLE `app` (
  `id` int(20) NOT NULL COMMENT 'ID索引',
  `app_name` varchar(255) DEFAULT NULL COMMENT '应用名称',
  `app_key` varchar(255) DEFAULT NULL COMMENT '应用KEY',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='应用'CREATE TABLE `weight` (
  `id` int(20) NOT NULL COMMENT 'ID索引',
  `app_id` int(20) DEFAULT NULL COMMENT '应用编号',
  `ad_provider_id` int(255) DEFAULT NULL COMMENT '广告提供商编号',
  `ad_weight` int(4) DEFAULT NULL COMMENT '广告权重',
  `ad_type` int(4) DEFAULT NULL COMMENT '广告类别(横幅,开屏,视频)',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='分配权重'CREATE TABLE `click` (
  `click_id` int(11) NOT NULL,
  `ad_id` int(11) NOT NULL COMMENT '广告ID',
  `app_id` int(11) NOT NULL COMMENT '广告位ID',
  `check_uuid` varchar(200) COLLATE utf8mb4_bin NOT NULL DEFAULT '' COMMENT '用户唯一ID',
  `create_dateline` int(11) NOT NULL DEFAULT '0' COMMENT '时间戳',
  `create_date` date NOT NULL DEFAULT '2000-01-01' COMMENT '日期',
  `create_h` smallint(6) NOT NULL DEFAULT '0' COMMENT '小时',
  `flag` smallint(6) NOT NULL DEFAULT '0',
  `province` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '省份',
  `city` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '城市',
  `isp` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '运营商',
  `browser` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '终端浏览器',
  `remote_addr` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT 'IP地址',
  `http_user_agent` text COLLATE utf8mb4_bin NOT NULL COMMENT 'UA',
  `click_price` decimal(5,2) NOT NULL DEFAULT '0.00' COMMENT '点击单价',
  PRIMARY KEY (`click_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin
​
CREATE TABLE `province` (
  `province_name` varchar(50) NOT NULL COMMENT '省份',
  `region_name` varchar(50) DEFAULT NULL COMMENT '大区',
  PRIMARY KEY (`province_name`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8
​
# 还有一些乱七八糟的表略过,比如广告表

源代码

import chromadb
import mysql.connector
import ollama
import re
from tabulate import tabulate
​
# 初始化 ChromaDB 客户端
client = chromadb.Client()
​
# 定义 ChromaDB 集合名称
collection_name = "train_tables"
​
​
def get_table_ddl(host, user, password, database, table_name):
    """
    获取MySQL数据库表的DDL语句。
​
    Args:
      host: 数据库主机地址。
      user: 数据库用户名。
      password: 数据库密码。
      database: 数据库名称。
      table_name: 表名。
​
    Returns:
      表的DDL语句,字符串类型。
    """
​
    try:
        mydb = mysql.connector.connect(
            host=host, user=user, password=password, database=database
        )
​
        mycursor = mydb.cursor()
​
        mycursor.execute(f"SHOW CREATE TABLE {table_name}")
​
        result = mycursor.fetchone()
​
        ddl = result[1]
​
        return ddl
​
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
​
    finally:
        if mydb.is_connected():
            mycursor.close()
            mydb.close()
​
​
def get_all_tables_ddl(host, user, password, database):
    """
    获取MySQL数据库中所有表的DDL语句。
​
    Args:
      host: 数据库主机地址。
      user: 数据库用户名。
      password: 数据库密码。
      database: 数据库名称。
​
    Returns:
      一个字典,键为表名,值为表的DDL语句。
    """
​
    try:
        mydb = mysql.connector.connect(
            host=host, user=user, password=password, database=database
        )
​
        mycursor = mydb.cursor()
​
        mycursor.execute("SHOW TABLES")
​
        tables = mycursor.fetchall()
​
        ddl_dict = {}
​
        for table in tables:
            table_name = table[0]
            ddl = get_table_ddl(
                host, user, password, database, table_name
            )  # 调用之前定义的函数
            ddl_dict[table_name] = ddl
​
        return ddl_dict
​
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
​
    finally:
        if mydb.is_connected():
            mycursor.close()
            mydb.close()
​
​
def extract_sql_from_markdown(markdown_text):
    """
    从Markdown文本中提取SQL语句。
​
    Args:
      markdown_text: Markdown文本字符串。
​
    Returns:
      SQL语句字符串,如果未找到则返回None。
    """
    pattern = r"```sql\n(.*?)\n```"  # 匹配以```sql开头和结尾的代码块
    match = re.search(pattern, markdown_text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return None
​
​
def execute_sql_and_display_table(host, user, password, database, sql_statement):
    """执行SQL语句并以表格形式显示结果"""
    try:
        mydb = mysql.connector.connect(
            host=host, user=user, password=password, database=database
        )
​
        mycursor = mydb.cursor()
        mycursor.execute(sql_statement)
        results = mycursor.fetchall()
​
        # 获取列名
        column_names = [i[0] for i in mycursor.description]
​
        # 使用tabulate库格式化输出表格
        table = tabulate(results, headers=column_names, tablefmt="grid")
        print(table)
​
    except mysql.connector.Error as err:
        print(f"Error: {err}")
​
    finally:
        if mydb.is_connected():
            mycursor.close()
            mydb.close()
​
​
def embed_and_upsert_document(collection, document_id, document_content):
    """
    计算文档的嵌入向量并将其添加到 ChromaDB 集合中。
​
    Args:
      collection: ChromaDB 集合对象。
      document_id: 文档的唯一标识符。
      document_content: 文档内容。
    """
    # 使用 Ollama 模型计算嵌入向量
    response = ollama.embeddings(
        model="dztech/bge-large-zh:v1.5", prompt=document_content
    )
    embedding = response["embedding"]
​
    # 将文档及其嵌入向量添加到 ChromaDB 集合
    collection.upsert(
        ids=document_id, embeddings=[embedding], documents=[document_content]
    )
    print(f"Document {document_id} added to ChromaDB.")
​
​
def query_chromadb(collection, query_text, n_results=5):
    """
    查询 ChromaDB 集合并返回最相关的文档。
​
    Args:
      collection: ChromaDB 集合对象。
      query_text: 查询文本。
      n_results: 返回结果的数量。
​
    Returns:
      ChromaDB 查询结果,包含相关文档及其 ID。
    """
    # 使用 Ollama 模型计算查询文本的嵌入向量
    response = ollama.embeddings(prompt=query_text, model="dztech/bge-large-zh:v1.5")
​
    # 查询 ChromaDB 集合
    results = collection.query(
        query_embeddings=[response["embedding"]], n_results=n_results
    )
    return results
​
​
# 使用示例:
host = "localhost"
user = "root"
password = ""
database = "train"if __name__ == "__main__":
    # 获取或创建 ChromaDB 集合
    collection = client.get_or_create_collection(name=collection_name)
​
    # 获取所有表的 DDL 语句
    ddl_dict = get_all_tables_ddl(host, user, password, database)
​
    # 将 DDL 语句嵌入并添加到 ChromaDB 集合
    if ddl_dict:
        for table_name, ddl in ddl_dict.items():
            print(f"Table: {table_name}")
            print(ddl)
            print("-" * 20)
            embed_and_upsert_document(collection, table_name, ddl)
​
    print("All done!")
​
    # 从请求中获取提示
    prompt = "请问如何统计各个大区的点击数,请给出SQL语句"
​
    # 查询 ChromaDB
    results = query_chromadb(collection, prompt)
    print(results)
​
    # 获取查询结果中的第一个文档
    if results["documents"]:
        data = "\n\n\n\n\n".join(results["documents"][0])
        output = ollama.generate(
            model="deepseek-coder-v2",
            prompt=f"根据这些数据库结构定义:{data}。回答这个问题:{prompt}",
        )
        sql_statement = extract_sql_from_markdown(output["response"])
        if sql_statement:
            print(sql_statement)
            # 执行 SQL
            execute_sql_and_display_table(host, user, password, database, sql_statement)
​
    else:
        print("No results found.")

执行效果

各个大区的点击数

题目是 “请问如何统计各个大区的点击数,请给出SQL语句”

AI生成的SQL

SELECT 
    p.region_name,
    COUNT(c.click_id) AS click_count
FROM
    province p
JOIN
    click c ON p.province_name = c.province
WHERE
    p.region_name IS NOT NULL
GROUP BY
    p.region_name;

输出的表格

+---------------+---------------+
| region_name   |   click_count |
+===============+===============+
| 东北地区      |         39404 |
+---------------+---------------+
| 华东地区      |        289745 |
+---------------+---------------+
| 华中地区      |         67297 |
+---------------+---------------+
| 华北地区      |        123153 |
+---------------+---------------+
| 华南地区      |        290853 |
+---------------+---------------+
| 西北地区      |          8130 |
+---------------+---------------+
| 西南地区      |         26264 |
+---------------+---------------+
大区的点击数及平均点击单价

再换一个题目 “请问如何统计各个大区的点击数及平均点击单价,请给出SQL语句”

AI生成的SQL

SELECT 
    p.region_name AS region,
    COUNT(c.click_id) AS click_count,
    AVG(c.click_price) AS avg_click_price
FROM
    click c
JOIN
    province p ON c.province = p.province_name
GROUP BY
    p.region_name;

输出的表格

+----------+---------------+-------------------+
| region   |   click_count |   avg_click_price |
+==========+===============+===================+
|          |          2479 |           1.05679 |
+----------+---------------+-------------------+
| 东北地区 |         39404 |           1.05224 |
+----------+---------------+-------------------+
| 华东地区 |        289745 |           1.04953 |
+----------+---------------+-------------------+
| 华中地区 |         67297 |           1.04822 |
+----------+---------------+-------------------+
| 华北地区 |        123153 |           1.05077 |
+----------+---------------+-------------------+
| 华南地区 |        290853 |           1.05135 |
+----------+---------------+-------------------+
| 西北地区 |          8130 |           1.04146 |
+----------+---------------+-------------------+
| 西南地区 |         26264 |           1.04751 |
+----------+---------------+-------------------+
APP的点击单价

再换一个题目 “请问如何统计那个APP的点击单价最高,请给出SQL语句”

AI生成的SQL

SELECT app.app_name, AVG(click.click_price) AS avg_click_price
FROM click
JOIN app ON click.app_id = app.id
GROUP BY click.app_id
ORDER BY avg_click_price DESC
LIMIT 1;

输出的表格

+------------------------+-------------------+
| app_name               |   avg_click_price |
+========================+===================+
| 安卓APP-开屏位-ZX-CS01 |           1.04937 |
+------------------------+-------------------+

产品实现

根据以上的技术研究,自己开发了一套相关的免费产品,有兴趣的朋友可以下载体验:

使用自然语言即可操作数据库,无需编写复杂的 SQL 语句。说出您的需求,数据库智能体就能返回您想要的数据,并生成可视化图表,让数据分析更简单。

🚨请谨慎在生产环境下运行数据库智能体,出现删库等事故本人概不负责😀

自然语言查询

使用自然语言查询,比如:请问如何统计各个大区的点击数

数据库智能体详解

点击执行得到查询结果。

执行SQL查询

对查询结果不满意,有的表总是命中不了,直接指定全部表或者特定表

数据库智能体详解-指定表

查询结果可视化

数据库智能体会自动生成可视化图表,让数据更直观。可以使用专属Prompt对可视化进行详细要求(比如是饼图还是折线图)

SQL查询结果可视化

数据库优化

发现查询太慢了,居然耗时4060ms,要AI给出数据库优化方案

给出SQL优化方案

根据优化建议创建索引后可以看到,查询时间从4060ms降低到262ms

再次执行SQL查询

数据库结构分析

帮助分析数据库的关系和结构,比如:找到和客户相关的表

数据库结构分析

懒得写注释,让AI帮我写

数据库注释

实战检验

我自己试了一下,使用自然语言完成课程中的问题,成功率还是蛮高的😀

  • SQL 大全

    • 以一个小型项目为线索,采用一问一答的方式全面讲解 PostgreSQL 中各类 SQL 语句的编写技巧。
    • 每章节配备可在 Local Agents 终端和数据库智能体中运行的实战课件。
    • github.com/lgc653/cour… ➡ database

题目中故意埋了坑的问题也能够答对

PostgreSQL课程展示

PostgreSQL课程展示

详细介绍

效果评测

下面是我自己针对国内外各个大模型做的评测,大家可以看看,效果还是不错的