项目概述
核心功能是构建一个基于 AI 的数据库查询助手,它能够理解自然语言的用户查询,并自动生成相应的 SQL 语句进行查询。这里使用chromadb + Ollama的穷人方案实现,安装的软件少,机器配置要求低,完全免费。
chromadb + Ollama可以参考我的这篇文章:chromadb + Ollama 快速实现RAG应用
具体来说,它做了以下几件事:
- 提取数据库结构定义 (DDL): 从目标 MySQL 数据库中提取所有表的 DDL 语句,作为数据库结构的知识库。
- 利用向量数据库 (ChromaDB) 存储 DDL 信息: 使用 Ollama 模型将 DDL 语句转换为嵌入向量,并存储到 ChromaDB 中,以便进行语义搜索。
- 接收用户自然语言查询: 接收用户以自然语言形式提出的数据库查询请求。
- 语义搜索相关 DDL: 使用 Ollama 模型将用户查询转换为嵌入向量,并在 ChromaDB 中搜索语义最相关的 DDL 语句。
- 利用大语言模型 (Ollama) 生成 SQL: 将用户查询和相关的 DDL 语句作为上下文信息,发送给 Ollama 模型,请求其生成能够满足用户查询需求的 SQL 语句。
- 执行 SQL 并展示结果: 执行生成的 SQL 语句,并将查询结果以表格形式展示给用户。
总而言之,这段代码结合了向量数据库和大型语言模型的优势,实现了基于自然语言的数据库查询功能。它能够降低用户使用数据库的门槛,提高数据查询的效率,为数据分析和决策提供便捷的工具。
sequenceDiagram
participant 用户
participant 应用程序
participant Ollama Embeddings
participant ChromaDB
participant Ollama Coder
用户->>应用程序: 输入自然语言查询
activate 应用程序
应用程序->>Ollama Embeddings: 请求查询嵌入向量
activate Ollama Embeddings
Ollama Embeddings-->>应用程序: 返回查询嵌入向量
deactivate Ollama Embeddings
应用程序->>ChromaDB: 查询相关 DDL
activate ChromaDB
ChromaDB-->>应用程序: 返回相关 DDL
deactivate ChromaDB
应用程序->>Ollama Coder: 请求生成 SQL (查询 + DDL)
activate Ollama Coder
Ollama Coder-->>应用程序: 返回 SQL 语句
deactivate Ollama Coder
应用程序->>MySQL 数据库: 执行 SQL 查询
activate MySQL 数据库
MySQL 数据库-->>应用程序: 返回查询结果
deactivate MySQL 数据库
应用程序->>用户: 展示查询结果 (表格)
deactivate 应用程序
技术选型
- Embedding Models:Ollama + bge-large-zh:v1.5
- Vector Databases:chromadb
- Generation Models: Ollama + deepseek-coder-v2
这样的选择主要是因为安装的软件较少,对机器配置要求低,完全免费。只需安装
chromadb + Ollama
即可。我机器是4G的3050也能勉强运行。
代码实现
安装依赖
pip install mysql-connector-python
pip install tabulate
pip install ollama chromadb
-
mysql-connector-python用于
- 获取表的DDL
- 执行SQL
-
tabulate用于把SQL结果显示
-
ollama用于运行本地模型(Embedding Models和Generation Models)
-
chromadb用于Vector Databases
数据库定义
这里是一个大家最讨厌的广告展示点击系统
本次测试的题目是 “请问如何统计各个大区的点击数,请给出SQL语句”
,根据下面的DDL可以看出,实际只和province
和click
表相关
- click中有省份的信息
- province中有省份和大区的关系
比如province表
province_name | region_name |
---|---|
上海 | 华东地区 |
北京 | 华北地区 |
…… | …… |
比如click表
click_id | province | click_price | 其它列 |
---|---|---|---|
129070334 | 浙江 | 1.18 | …… |
129070335 | 北京 | 1.53 | |
…… | …… | …… | …… |
具体定义如下,还有一些乱七八糟的表略过……
当然数据库的DDL的注释质量越高,我们得到的效果就越好
CREATE TABLE `ad_provider` (
`id` int(20) NOT NULL COMMENT 'ID索引',
`provider_name` varchar(255) NOT NULL COMMENT '广告提供商名称',
`call_url` varchar(255) NOT NULL COMMENT '广告API地址',
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='广告提供商'
CREATE TABLE `app` (
`id` int(20) NOT NULL COMMENT 'ID索引',
`app_name` varchar(255) DEFAULT NULL COMMENT '应用名称',
`app_key` varchar(255) DEFAULT NULL COMMENT '应用KEY',
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='应用'
CREATE TABLE `weight` (
`id` int(20) NOT NULL COMMENT 'ID索引',
`app_id` int(20) DEFAULT NULL COMMENT '应用编号',
`ad_provider_id` int(255) DEFAULT NULL COMMENT '广告提供商编号',
`ad_weight` int(4) DEFAULT NULL COMMENT '广告权重',
`ad_type` int(4) DEFAULT NULL COMMENT '广告类别(横幅,开屏,视频)',
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='分配权重'
CREATE TABLE `click` (
`click_id` int(11) NOT NULL,
`ad_id` int(11) NOT NULL COMMENT '广告ID',
`app_id` int(11) NOT NULL COMMENT '广告位ID',
`check_uuid` varchar(200) COLLATE utf8mb4_bin NOT NULL DEFAULT '' COMMENT '用户唯一ID',
`create_dateline` int(11) NOT NULL DEFAULT '0' COMMENT '时间戳',
`create_date` date NOT NULL DEFAULT '2000-01-01' COMMENT '日期',
`create_h` smallint(6) NOT NULL DEFAULT '0' COMMENT '小时',
`flag` smallint(6) NOT NULL DEFAULT '0',
`province` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '省份',
`city` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '城市',
`isp` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '运营商',
`browser` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT '终端浏览器',
`remote_addr` varchar(50) COLLATE utf8mb4_bin NOT NULL DEFAULT '0' COMMENT 'IP地址',
`http_user_agent` text COLLATE utf8mb4_bin NOT NULL COMMENT 'UA',
`click_price` decimal(5,2) NOT NULL DEFAULT '0.00' COMMENT '点击单价',
PRIMARY KEY (`click_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin
CREATE TABLE `province` (
`province_name` varchar(50) NOT NULL COMMENT '省份',
`region_name` varchar(50) DEFAULT NULL COMMENT '大区',
PRIMARY KEY (`province_name`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8
# 还有一些乱七八糟的表略过,比如广告表
源代码
import chromadb
import mysql.connector
import ollama
import re
from tabulate import tabulate
# 初始化 ChromaDB 客户端
client = chromadb.Client()
# 定义 ChromaDB 集合名称
collection_name = "train_tables"
def get_table_ddl(host, user, password, database, table_name):
"""
获取MySQL数据库表的DDL语句。
Args:
host: 数据库主机地址。
user: 数据库用户名。
password: 数据库密码。
database: 数据库名称。
table_name: 表名。
Returns:
表的DDL语句,字符串类型。
"""
try:
mydb = mysql.connector.connect(
host=host, user=user, password=password, database=database
)
mycursor = mydb.cursor()
mycursor.execute(f"SHOW CREATE TABLE {table_name}")
result = mycursor.fetchone()
ddl = result[1]
return ddl
except mysql.connector.Error as err:
print(f"Error: {err}")
return None
finally:
if mydb.is_connected():
mycursor.close()
mydb.close()
def get_all_tables_ddl(host, user, password, database):
"""
获取MySQL数据库中所有表的DDL语句。
Args:
host: 数据库主机地址。
user: 数据库用户名。
password: 数据库密码。
database: 数据库名称。
Returns:
一个字典,键为表名,值为表的DDL语句。
"""
try:
mydb = mysql.connector.connect(
host=host, user=user, password=password, database=database
)
mycursor = mydb.cursor()
mycursor.execute("SHOW TABLES")
tables = mycursor.fetchall()
ddl_dict = {}
for table in tables:
table_name = table[0]
ddl = get_table_ddl(
host, user, password, database, table_name
) # 调用之前定义的函数
ddl_dict[table_name] = ddl
return ddl_dict
except mysql.connector.Error as err:
print(f"Error: {err}")
return None
finally:
if mydb.is_connected():
mycursor.close()
mydb.close()
def extract_sql_from_markdown(markdown_text):
"""
从Markdown文本中提取SQL语句。
Args:
markdown_text: Markdown文本字符串。
Returns:
SQL语句字符串,如果未找到则返回None。
"""
pattern = r"```sql\n(.*?)\n```" # 匹配以```sql开头和结尾的代码块
match = re.search(pattern, markdown_text, re.DOTALL)
if match:
return match.group(1).strip()
else:
return None
def execute_sql_and_display_table(host, user, password, database, sql_statement):
"""执行SQL语句并以表格形式显示结果"""
try:
mydb = mysql.connector.connect(
host=host, user=user, password=password, database=database
)
mycursor = mydb.cursor()
mycursor.execute(sql_statement)
results = mycursor.fetchall()
# 获取列名
column_names = [i[0] for i in mycursor.description]
# 使用tabulate库格式化输出表格
table = tabulate(results, headers=column_names, tablefmt="grid")
print(table)
except mysql.connector.Error as err:
print(f"Error: {err}")
finally:
if mydb.is_connected():
mycursor.close()
mydb.close()
def embed_and_upsert_document(collection, document_id, document_content):
"""
计算文档的嵌入向量并将其添加到 ChromaDB 集合中。
Args:
collection: ChromaDB 集合对象。
document_id: 文档的唯一标识符。
document_content: 文档内容。
"""
# 使用 Ollama 模型计算嵌入向量
response = ollama.embeddings(
model="dztech/bge-large-zh:v1.5", prompt=document_content
)
embedding = response["embedding"]
# 将文档及其嵌入向量添加到 ChromaDB 集合
collection.upsert(
ids=document_id, embeddings=[embedding], documents=[document_content]
)
print(f"Document {document_id} added to ChromaDB.")
def query_chromadb(collection, query_text, n_results=5):
"""
查询 ChromaDB 集合并返回最相关的文档。
Args:
collection: ChromaDB 集合对象。
query_text: 查询文本。
n_results: 返回结果的数量。
Returns:
ChromaDB 查询结果,包含相关文档及其 ID。
"""
# 使用 Ollama 模型计算查询文本的嵌入向量
response = ollama.embeddings(prompt=query_text, model="dztech/bge-large-zh:v1.5")
# 查询 ChromaDB 集合
results = collection.query(
query_embeddings=[response["embedding"]], n_results=n_results
)
return results
# 使用示例:
host = "localhost"
user = "root"
password = ""
database = "train"
if __name__ == "__main__":
# 获取或创建 ChromaDB 集合
collection = client.get_or_create_collection(name=collection_name)
# 获取所有表的 DDL 语句
ddl_dict = get_all_tables_ddl(host, user, password, database)
# 将 DDL 语句嵌入并添加到 ChromaDB 集合
if ddl_dict:
for table_name, ddl in ddl_dict.items():
print(f"Table: {table_name}")
print(ddl)
print("-" * 20)
embed_and_upsert_document(collection, table_name, ddl)
print("All done!")
# 从请求中获取提示
prompt = "请问如何统计各个大区的点击数,请给出SQL语句"
# 查询 ChromaDB
results = query_chromadb(collection, prompt)
print(results)
# 获取查询结果中的第一个文档
if results["documents"]:
data = "\n\n\n\n\n".join(results["documents"][0])
output = ollama.generate(
model="deepseek-coder-v2",
prompt=f"根据这些数据库结构定义:{data}。回答这个问题:{prompt}",
)
sql_statement = extract_sql_from_markdown(output["response"])
if sql_statement:
print(sql_statement)
# 执行 SQL
execute_sql_and_display_table(host, user, password, database, sql_statement)
else:
print("No results found.")
执行效果
各个大区的点击数
题目是 “请问如何统计各个大区的点击数,请给出SQL语句”
AI生成的SQL
SELECT
p.region_name,
COUNT(c.click_id) AS click_count
FROM
province p
JOIN
click c ON p.province_name = c.province
WHERE
p.region_name IS NOT NULL
GROUP BY
p.region_name;
输出的表格
+---------------+---------------+
| region_name | click_count |
+===============+===============+
| 东北地区 | 39404 |
+---------------+---------------+
| 华东地区 | 289745 |
+---------------+---------------+
| 华中地区 | 67297 |
+---------------+---------------+
| 华北地区 | 123153 |
+---------------+---------------+
| 华南地区 | 290853 |
+---------------+---------------+
| 西北地区 | 8130 |
+---------------+---------------+
| 西南地区 | 26264 |
+---------------+---------------+
大区的点击数及平均点击单价
再换一个题目 “请问如何统计各个大区的点击数及平均点击单价,请给出SQL语句”
AI生成的SQL
SELECT
p.region_name AS region,
COUNT(c.click_id) AS click_count,
AVG(c.click_price) AS avg_click_price
FROM
click c
JOIN
province p ON c.province = p.province_name
GROUP BY
p.region_name;
输出的表格
+----------+---------------+-------------------+
| region | click_count | avg_click_price |
+==========+===============+===================+
| | 2479 | 1.05679 |
+----------+---------------+-------------------+
| 东北地区 | 39404 | 1.05224 |
+----------+---------------+-------------------+
| 华东地区 | 289745 | 1.04953 |
+----------+---------------+-------------------+
| 华中地区 | 67297 | 1.04822 |
+----------+---------------+-------------------+
| 华北地区 | 123153 | 1.05077 |
+----------+---------------+-------------------+
| 华南地区 | 290853 | 1.05135 |
+----------+---------------+-------------------+
| 西北地区 | 8130 | 1.04146 |
+----------+---------------+-------------------+
| 西南地区 | 26264 | 1.04751 |
+----------+---------------+-------------------+
APP的点击单价
再换一个题目 “请问如何统计那个APP的点击单价最高,请给出SQL语句”
AI生成的SQL
SELECT app.app_name, AVG(click.click_price) AS avg_click_price
FROM click
JOIN app ON click.app_id = app.id
GROUP BY click.app_id
ORDER BY avg_click_price DESC
LIMIT 1;
输出的表格
+------------------------+-------------------+
| app_name | avg_click_price |
+========================+===================+
| 安卓APP-开屏位-ZX-CS01 | 1.04937 |
+------------------------+-------------------+
产品实现
根据以上的技术研究,自己开发了一套相关的免费产品,有兴趣的朋友可以下载体验:
使用自然语言即可操作数据库,无需编写复杂的 SQL 语句。说出您的需求,数据库智能体就能返回您想要的数据,并生成可视化图表,让数据分析更简单。
🚨请谨慎在生产环境下运行数据库智能体,出现删库等事故本人概不负责😀
自然语言查询
使用自然语言查询,比如:请问如何统计各个大区的点击数
点击执行得到查询结果。
对查询结果不满意,有的表总是命中不了,直接指定全部表或者特定表
查询结果可视化
数据库智能体会自动生成可视化图表,让数据更直观。可以使用专属Prompt对可视化进行详细要求(比如是饼图还是折线图)
数据库优化
发现查询太慢了,居然耗时4060ms,要AI给出数据库优化方案
根据优化建议创建索引后可以看到,查询时间从4060ms降低到262ms
数据库结构分析
帮助分析数据库的关系和结构,比如:找到和客户相关的表
懒得写注释,让AI帮我写
实战检验
我自己试了一下,使用自然语言完成课程中的问题,成功率还是蛮高的😀
-
SQL 大全
- 以一个小型项目为线索,采用一问一答的方式全面讲解 PostgreSQL 中各类 SQL 语句的编写技巧。
- 每章节配备可在 Local Agents 终端和数据库智能体中运行的实战课件。
- github.com/lgc653/cour… ➡ database
题目中故意埋了坑的问题也能够答对
详细介绍
效果评测
下面是我自己针对国内外各个大模型做的评测,大家可以看看,效果还是不错的