——别让“Hello World”停留在内存里,让它在国产数据库中生根发芽
大家好,我是那个总在演示会上被问“能不能把这棵树画出来?”、又在 KES 表里手动拼接 feature + ' <= ' + threshold 的老架构。今天我们要干一件看似简单、却极具象征意义的事:
用决策树对经典的鸢尾花(Iris)数据集做分类,并把整棵树从电科金仓 KingbaseES(KES)中读出来、训练、再完整地可视化输出。
很多人说:“鸢尾花?那是玩具数据。” 但真相是:玩具数据的价值,在于验证你的技术栈是否闭环。
如果你连 Iris 都跑不通端到端——从建表、入库、训练到可视化——那你凭什么相信自己能在真实业务中驾驭更复杂的模型?
今天我们就用 Java + 自研 CART 实现 + Graphviz,完成一次纯国产技术栈的 AI 全流程演练。全程不依赖 Python、不调 sklearn,只为证明一件事:
在电科金仓的土壤上,我们也能长出清晰、可解释、可展示的 AI 之树。
一、为什么选鸢尾花?因为它暴露一切细节
Iris 数据集只有 150 条记录,4 个连续特征(花萼/花瓣长宽),3 个类别(setosa, versicolor, virginica)。 但它足够小,能让你看清:
-
特征如何分裂;
-
树深多少合适;
-
可视化是否准确。
更重要的是:它没有缺失值、没有噪声、标签干净——是检验你工程链路是否通畅的“试金石”。
二、在 KES 中建表并加载数据
首先,在 KingbaseES 中创建 schema 和表:
CREATE SCHEMA IF NOT EXISTS ai_demo;
CREATE TABLE ai_demo.iris_data (
id SERIAL PRIMARY KEY,
sepal_length REAL NOT NULL,
sepal_width REAL NOT NULL,
petal_length REAL NOT NULL,
petal_width REAL NOT NULL,
species VARCHAR(20) NOT NULL
);
12345678910
然后用 Java 批量插入(模拟从 CSV 或外部系统导入):
public void loadIrisDataToKES(Connection conn) throws SQLException {
double[][] features = {
{5.1, 3.5, 1.4, 0.2},
{7.0, 3.2, 4.7, 1.4},
{6.3, 3.3, 6.0, 2.5}
};
String[] labels = {"setosa", "versicolor", "virginica", };
String sql = "INSERT INTO ai_demo.iris_data (sepal_length, sepal_width, petal_length, petal_width, species) VALUES (?, ?, ?, ?, ?)";
try (PreparedStatement ps = conn.prepareStatement(sql)) {
for (int i = 0; i 🔗 确保使用 [电科金仓 JDBC 驱动](https://www.kingbase.com.cn/download.html#drive) 支持 `REAL` 类型精确写入。
---
### []()[]()三、从 KES 读取数据并训练 CART 树
复用上期优化后的 CART 实现,支持多分类(基尼不纯度自然扩展):
public List loadIrisFromKES(Connection conn) throws SQLException { String sql = "SELECT sepal_length, sepal_width, petal_length, petal_width, species FROM ai_demo.iris_data"; List data = new ArrayList<>();
try (PreparedStatement ps = conn.prepareStatement(sql);
ResultSet rs = ps.executeQuery()) {
while (rs.next()) {
Map feats = new HashMap<>();
feats.put("sepal_length", new FeatureValue("sepal_length", rs.getDouble("sepal_length")));
feats.put("sepal_width", new FeatureValue("sepal_width", rs.getDouble("sepal_width")));
feats.put("petal_length", new FeatureValue("petal_length", rs.getDouble("petal_length")));
feats.put("petal_width", new FeatureValue("petal_width", rs.getDouble("petal_width")));
String species = rs.getString("species");
data.add(new Instance(feats, species));
}
}
return data;
}
1234567891011121314151617181920
训练(限制 maxDepth=3 避免过拟合):
List irisData = loadIrisFromKES(conn); Set features = Set.of("sepal_length", "sepal_width", "petal_length", "petal_width"); TreeNode root = buildCartTree(irisData, features, 0, minSamplesSplit=5, maxDepth=3); 123
典型分裂结果(人工验证):
- 根节点:`petal_length <= 2.45` → 左子树全为 setosa;
- 右子树:`petal_width <= 1.75` → 区分 versicolor/virginica。
---
### []()[]()四、模型可视化:生成 Graphviz DOT 文件
为了让业务方“看见”模型,我们输出标准 DOT 格式:
public void exportTreeToDot(TreeNode node, PrintWriter writer, AtomicInteger nodeId) { int currentId = nodeId.getAndIncrement();
if (node.isLeaf()) {
String label = node.prediction + "\\n(samples=" + node.sampleCount + ")";
writer.println(currentId + " [label=\"" + label + "\", shape=box, style=filled, fillcolor=\"#e6f7ff\"];");
} else {
String condition = node.featureName + " " + leftId + " [label=\"True\"];");
int rightId = nodeId.get();
exportTreeToDot(node.right, writer, nodeId);
writer.println(currentId + " -> " + rightId + " [label=\"False\"];");
}
}
try (PrintWriter writer = new PrintWriter("iris_tree.dot")) { writer.println("digraph IrisTree {"); exportTreeToDot(root, writer, new AtomicInteger(0)); writer.println("}"); } 1234567891011121314151617181920212223242526272829
生成的 `iris_tree.dot` 示例:
digraph IrisTree { 0 [label="petal_length 1 [label="True"]; 2 [label="petal_width 2 [label="False"]; ... } 12345678
用 Graphviz 渲染:
dot -Tpng iris_tree.dot -o iris_tree.png 1

> ✅ 这就是**可交付、可汇报、可嵌入文档的模型资产**。
---
### []()[]()五、将树结构存回 KES 供在线服务
为支持实时预测,我们将树节点持久化:
CREATE TABLE ai_models.iris_tree_nodes (
node_id SERIAL PRIMARY KEY,
parent_id INT,
is_leaf BOOLEAN,
feature_name VARCHAR(32),
threshold DOUBLE PRECISION,
prediction VARCHAR(20),
sample_count INT,
path TEXT
);
12345678910
Java 写入(略去递归逻辑):
saveNodeToKES(conn, root, null, "root"); 1
在线预测服务只需递归查询:
WITH RECURSIVE predict(path, feature, threshold, pred) AS ( SELECT 'root', feature_name, threshold, prediction FROM ai_models.iris_tree_nodes WHERE path = 'root' UNION ALL SELECT ..., CASE WHEN input.petal_length