Decision Trees in Practice: Iris Classification on KingbaseES, with Visual Model Output


Don't leave "Hello World" sitting in memory. Let it take root in a domestic database.

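Hi everyone. I'm that veteran architect who keeps getting asked in demo meetings, "Can you draw this tree for us?", and who then goes back to hand-concatenating feature + ' <= ' + threshold in KES tables. Today we're going to do something that looks simple but carries real symbolic weight: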

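Classify the classic Iris dataset with a decision tree: read the data out of Kingbase (电科金仓) KingbaseES (KES), train the tree, and then render the whole tree visually.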

Many people say, "Iris? That's toy data." But the real value of toy data is that it verifies whether your technology stack forms a closed loop.

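If you can't get Iris running end to end (creating the table, loading the data, training, visualizing), why would you trust yourself to handle more complex models in real business scenarios?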

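Today we use Java, a self-written CART implementation, and Graphviz to rehearse the full AI workflow on a stack centered on a domestic database. No Python and no sklearn anywhere, all to prove one thing: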

On Kingbase (电科金仓) soil, we too can grow an AI tree that is clear, interpretable, and presentable.


### 1. Why Iris? Because It Exposes Every Detail

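The Iris dataset has only 150 records, 4 continuous features (sepal and petal length and width), and 3 classes (setosa, versicolor, virginica). It is small enough that you can see clearly: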

  • how the features are split;

  • how deep the tree should be;

  • whether the visualization is accurate.

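More importantly, it has no missing values, little noise, and clean labels, which makes it a touchstone for checking that your engineering pipeline runs end to end.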


### 2. Creating the Table and Loading Data in KES

First, create the schema and table in KingbaseES:

CREATE SCHEMA IF NOT EXISTS ai_demo;

CREATE TABLE ai_demo.iris_data (
    id            SERIAL PRIMARY KEY,
    sepal_length  REAL NOT NULL,
    sepal_width   REAL NOT NULL,
    petal_length  REAL NOT NULL,
    petal_width   REAL NOT NULL,
    species       VARCHAR(20) NOT NULL  
);

Then batch-insert the rows from Java (simulating an import from CSV or an external system):

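import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;

public void loadIrisDataToKES(Connection conn) throws SQLException {
    // One sample per class is shown here; the remaining rows are omitted for brevity
    double[][] features = {
        {5.1, 3.5, 1.4, 0.2},
        {7.0, 3.2, 4.7, 1.4},
        {6.3, 3.3, 6.0, 2.5}
    };
    String[] labels = {"setosa", "versicolor", "virginica"};

    String sql = "INSERT INTO ai_demo.iris_data (sepal_length, sepal_width, petal_length, petal_width, species) VALUES (?, ?, ?, ?, ?)";
    try (PreparedStatement ps = conn.prepareStatement(sql)) {
        for (int i = 0; i < features.length; i++) {
            // Bind the four measurements, then the species label
            for (int j = 0; j < 4; j++) {
                ps.setDouble(j + 1, features[i][j]);
            }
            ps.setString(5, labels[i]);
            ps.addBatch();
        }
        // Send all queued rows in one batch
        ps.executeBatch();
    }
}

> 🔗 Make sure you use the [Kingbase JDBC driver](https://www.kingbase.com.cn/download.html#drive) so that `REAL` values are written correctly.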
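Since the loader above only simulates a CSV import, here is a minimal sketch of what the parsing side could look like. It assumes a header-less file with five comma-separated columns per line (four measurements, then the species); `IrisCsvParser` and `Row` are hypothetical names for illustration, not part of the original project.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: parses lines such as "5.1,3.5,1.4,0.2,setosa". */
class IrisCsvParser {

    /** One parsed row: four measurements plus the species label. */
    static final class Row {
        final double[] features;
        final String species;

        Row(double[] features, String species) {
            this.features = features;
            this.species = species;
        }
    }

    /** Assumes no header line and exactly five comma-separated columns. */
    static List<Row> parse(List<String> lines) {
        List<Row> rows = new ArrayList<>();
        for (String line : lines) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // skip blank lines
            }
            String[] parts = trimmed.split(",");
            if (parts.length != 5) {
                throw new IllegalArgumentException("Expected 5 columns: " + line);
            }
            double[] f = new double[4];
            for (int i = 0; i < 4; i++) {
                f[i] = Double.parseDouble(parts[i].trim());
            }
            rows.add(new Row(f, parts[4].trim()));
        }
        return rows;
    }
}
```

Each `Row` can then be bound to the same `INSERT` statement and queued with `addBatch()`, so the batch logic stays unchanged no matter where the rows come from.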
 

---
 

### 3. Reading Data from KES and Training the CART Tree
 

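We reuse the optimized CART implementation from the previous installment, which supports multi-class classification (Gini impurity extends naturally to more than two classes):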
 

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public List<Instance> loadIrisFromKES(Connection conn) throws SQLException {
    String sql = "SELECT sepal_length, sepal_width, petal_length, petal_width, species FROM ai_demo.iris_data";
    List<Instance> data = new ArrayList<>();

    try (PreparedStatement ps = conn.prepareStatement(sql);
         ResultSet rs = ps.executeQuery()) {
        while (rs.next()) {
            // Wrap each column in a FeatureValue keyed by its feature name
            Map<String, FeatureValue> feats = new HashMap<>();
            feats.put("sepal_length", new FeatureValue("sepal_length", rs.getDouble("sepal_length")));
            feats.put("sepal_width",  new FeatureValue("sepal_width",  rs.getDouble("sepal_width")));
            feats.put("petal_length", new FeatureValue("petal_length", rs.getDouble("petal_length")));
            feats.put("petal_width",  new FeatureValue("petal_width",  rs.getDouble("petal_width")));
            String species = rs.getString("species");
            data.add(new Instance(feats, species));
        }
    }
    return data;
}

 

Training (cap `maxDepth` at 3 to avoid overfitting):
 

List<Instance> irisData = loadIrisFromKES(conn);
Set<String> features = Set.of("sepal_length", "sepal_width", "petal_length", "petal_width");
// Java has no named arguments: depth = 0, minSamplesSplit = 5, maxDepth = 3
TreeNode root = buildCartTree(irisData, features, 0, 5, 3);
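To make the training step less of a black box, here is a minimal sketch of the criterion a CART builder typically evaluates at each node: the multi-class Gini impurity and a search for the threshold that minimizes the weighted impurity of the two children. This is not the `buildCartTree` implementation from the previous installment; the class `GiniSplitSketch` and its method names are assumptions made for illustration, and the midpoint-between-neighbors candidate rule is one common choice among several.

```java
import java.util.Arrays;

/** Hypothetical sketch of Gini-based threshold selection for one numeric feature. */
class GiniSplitSketch {

    /** Gini impurity 1 - sum(p_k^2), computed from per-class counts. */
    static double gini(int[] counts) {
        int total = Arrays.stream(counts).sum();
        if (total == 0) {
            return 0.0;
        }
        double sumSq = 0.0;
        for (int c : counts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /** Size-weighted Gini of the two children produced by "value <= threshold". */
    static double weightedGini(double[] values, int[] labels, int numClasses, double threshold) {
        int[] left = new int[numClasses];
        int[] right = new int[numClasses];
        for (int i = 0; i < values.length; i++) {
            if (values[i] <= threshold) {
                left[labels[i]]++;
            } else {
                right[labels[i]]++;
            }
        }
        int nLeft = Arrays.stream(left).sum();
        int nRight = values.length - nLeft;
        return (nLeft * gini(left) + nRight * gini(right)) / values.length;
    }

    /** Tries midpoints between consecutive distinct values; returns the best threshold, or NaN if none. */
    static double bestThreshold(double[] values, int[] labels, int numClasses) {
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        double best = Double.NaN;
        double bestScore = Double.POSITIVE_INFINITY;
        for (int i = 1; i < sorted.length; i++) {
            if (sorted[i] == sorted[i - 1]) {
                continue; // identical neighbors give no new candidate
            }
            double candidate = (sorted[i] + sorted[i - 1]) / 2.0;
            double score = weightedGini(values, labels, numClasses, candidate);
            if (score < bestScore) {
                bestScore = score;
                best = candidate;
            }
        }
        return best;
    }
}
```

A full tree builder would run this search over every feature, keep the best (feature, threshold) pair, and recurse into both children until `maxDepth` or `minSamplesSplit` stops it.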

 

Typical splits (verified by hand):
 
- Root node: `petal_length <= 2.45`, whose left subtree is entirely setosa;
- Right subtree: `petal_width <= 1.75`, which separates versicolor from virginica.
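The two splits above are already enough to classify a flower by hand. As a sanity check, they can be encoded as a tiny tree and walked for a given sample. The `Node` class below is a hypothetical stand-in (the real `TreeNode` from the previous installment may be shaped differently), and the tree is hard-coded from the two thresholds listed above rather than trained.

```java
import java.util.Map;

/** Hypothetical two-level tree built from the splits listed above. */
class TinyIrisTree {

    static final class Node {
        final String featureName; // null for leaves
        final double threshold;   // unused for leaves
        final Node left;          // taken when value <= threshold
        final Node right;         // taken otherwise
        final String prediction;  // non-null only for leaves

        Node(String featureName, double threshold, Node left, Node right) {
            this.featureName = featureName;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
            this.prediction = null;
        }

        Node(String prediction) {
            this.featureName = null;
            this.threshold = 0.0;
            this.left = null;
            this.right = null;
            this.prediction = prediction;
        }

        boolean isLeaf() {
            return prediction != null;
        }
    }

    /** Walks from the root to a leaf; the sample must contain every feature the tree tests. */
    static String predict(Node node, Map<String, Double> sample) {
        while (!node.isLeaf()) {
            double value = sample.get(node.featureName);
            node = value <= node.threshold ? node.left : node.right;
        }
        return node.prediction;
    }

    /** petal_length <= 2.45 gives setosa; otherwise petal_width <= 1.75 decides. */
    static Node twoLevelTree() {
        Node inner = new Node("petal_width", 1.75, new Node("versicolor"), new Node("virginica"));
        return new Node("petal_length", 2.45, new Node("setosa"), inner);
    }
}
```

Calling `TinyIrisTree.predict(TinyIrisTree.twoLevelTree(), Map.of("petal_length", 4.7, "petal_width", 1.4))` follows the right branch at the root and the left branch at the second split, landing on versicolor.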

 

---
 

### 4. Model Visualization: Generating a Graphviz DOT File
 

So that business stakeholders can "see" the model, we emit standard DOT:
 

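import java.io.PrintWriter;
import java.util.concurrent.atomic.AtomicInteger;

public void exportTreeToDot(TreeNode node, PrintWriter writer, AtomicInteger nodeId) {
    int currentId = nodeId.getAndIncrement();

    if (node.isLeaf()) {
        // Leaf: show the predicted class and the number of training samples
        String label = node.prediction + "\\n(samples=" + node.sampleCount + ")";
        writer.println(currentId + " [label=\"" + label + "\", shape=box, style=filled, fillcolor=\"#e6f7ff\"];");
    } else {
        // Internal node: show the split condition
        String condition = node.featureName + " <= " + node.threshold;
        writer.println(currentId + " [label=\"" + condition + "\"];");

        // Left child (condition true) takes the next free id
        int leftId = nodeId.get();
        exportTreeToDot(node.left, writer, nodeId);
        writer.println(currentId + " -> " + leftId + " [label=\"True\"];");

        // Right child (condition false) takes the next id after the left subtree
        int rightId = nodeId.get();
        exportTreeToDot(node.right, writer, nodeId);
        writer.println(currentId + " -> " + rightId + " [label=\"False\"];");
    }
}

try (PrintWriter writer = new PrintWriter("iris_tree.dot")) {
    writer.println("digraph IrisTree {");
    exportTreeToDot(root, writer, new AtomicInteger(0));
    writer.println("}");
}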

 

A sample of the generated `iris_tree.dot`:
 

digraph IrisTree {
  0 [label="petal_length <= 2.45"];
  0 -> 1 [label="True"];
  2 [label="petal_width <= 1.75"];
  0 -> 2 [label="False"];
  ...
}

 

Render it with Graphviz:
 

dot -Tpng iris_tree.dot -o iris_tree.png

 

![](https://p3-xtjj-sign.byteimg.com/tos-cn-i-73owjymdk6/9d8f3af715674c46b16250912581a793~tplv-73owjymdk6-jj-mark-v1:0:0:0:0:5o6Y6YeR5oqA5pyv56S-5Yy6IEAg55So5oi3MjAyMzA2MzMzNTY2:q75.awebp?rk3s=f64ab15b&x-expires=1772596457&x-signature=LBB6RyykgsNuy3Si%2Faj4PxBXqyY%3D)
 

> This is a **model asset you can deliver, present, and embed in documents**.
 

---
 

### 5. Storing the Tree Back in KES for Online Serving
 

To support real-time prediction, we persist the tree nodes:
 

CREATE TABLE ai_models.iris_tree_nodes (
    node_id       SERIAL PRIMARY KEY,
    parent_id     INT,
    is_leaf       BOOLEAN,
    feature_name  VARCHAR(32),
    threshold     DOUBLE PRECISION,
    prediction    VARCHAR(20),
    sample_count  INT,
    path          TEXT
);

 

Writing from Java (recursive logic omitted):
 

saveNodeToKES(conn, root, null, "root");
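Because the recursion is left out above, here is one way it might look, shown without a database so it can be run on its own: the tree is flattened in pre-order into rows shaped like `ai_models.iris_tree_nodes`. Everything here is an assumption for illustration: the `Node` and `Row` classes, the `/L` and `/R` path suffixes, and the in-memory ids (in the real table, `node_id` is a `SERIAL` column, so the database would normally assign it). This is not the actual `saveNodeToKES`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: flatten a tree into rows matching ai_models.iris_tree_nodes. */
class TreeFlattenSketch {

    static final class Node {
        final String featureName; // null for leaves
        final double threshold;   // unused for leaves
        final Node left;
        final Node right;
        final String prediction;  // null for internal nodes
        final int sampleCount;

        private Node(String featureName, double threshold, Node left, Node right,
                     String prediction, int sampleCount) {
            this.featureName = featureName;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
            this.prediction = prediction;
            this.sampleCount = sampleCount;
        }

        static Node leaf(String prediction, int sampleCount) {
            return new Node(null, 0.0, null, null, prediction, sampleCount);
        }

        static Node split(String featureName, double threshold, Node left, Node right, int sampleCount) {
            return new Node(featureName, threshold, left, right, null, sampleCount);
        }

        boolean isLeaf() {
            return left == null && right == null;
        }
    }

    static final class Row {
        final int nodeId;
        final Integer parentId;   // null for the root
        final boolean isLeaf;
        final String featureName; // null for leaves
        final Double threshold;   // null for leaves
        final String prediction;  // null for internal nodes
        final int sampleCount;
        final String path;        // assumed encoding: "root", "root/L", "root/R/L", ...

        Row(int nodeId, Integer parentId, boolean isLeaf, String featureName,
            Double threshold, String prediction, int sampleCount, String path) {
            this.nodeId = nodeId;
            this.parentId = parentId;
            this.isLeaf = isLeaf;
            this.featureName = featureName;
            this.threshold = threshold;
            this.prediction = prediction;
            this.sampleCount = sampleCount;
            this.path = path;
        }
    }

    static List<Row> flatten(Node root) {
        List<Row> rows = new ArrayList<>();
        walk(root, null, "root", rows);
        return rows;
    }

    private static void walk(Node node, Integer parentId, String path, List<Row> rows) {
        int id = rows.size() + 1; // pre-order id, assigned before visiting children
        boolean leaf = node.isLeaf();
        Double threshold = leaf ? null : Double.valueOf(node.threshold);
        rows.add(new Row(id, parentId, leaf, leaf ? null : node.featureName,
                threshold, leaf ? node.prediction : null, node.sampleCount, path));
        if (!leaf) {
            walk(node.left, id, path + "/L", rows);
            walk(node.right, id, path + "/R", rows);
        }
    }
}
```

Each resulting `Row` maps one-to-one onto the table columns, so persisting the tree becomes a plain batch `INSERT` over the list, and the `path` column gives the online service a stable key for walking from `root` downward.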

 

The online prediction service then only needs a recursive query:
 

WITH RECURSIVE predict(path, feature, threshold, pred) AS (
    SELECT 'root', feature_name, threshold, prediction
    FROM ai_models.iris_tree_nodes
    WHERE path = 'root'
    UNION ALL
    SELECT ..., CASE WHEN input.petal_length