Decision Tree Course Experiment | 豆包MarsCode AI刷题


1. Objective

Implement the ID3 decision-tree algorithm and understand how the algorithm works.

2. Principle

(1) The ID3 algorithm

The core idea of ID3 is to use information gain as the attribute-selection measure: at each split, choose the attribute whose split yields the largest information gain. A few definitions first. Let D be the training set partitioned by class label; the entropy of D is

$$\mathrm{info}(D) = -\sum_{i=1}^{m} p_i \log_2 p_i$$

where $p_i$ is the probability that a tuple in D belongs to class $i$, estimated as the number of tuples of that class divided by the total number of tuples. The entropy is the average amount of information needed to identify the class label of a tuple in D. Now suppose D is partitioned on attribute A into subsets $D_1, \dots, D_v$; the expected information of this partition is

$$\mathrm{info}_A(D) = \sum_{j=1}^{v} \frac{|D_j|}{|D|}\,\mathrm{info}(D_j)$$

and the information gain is the difference between the two:

$$\mathrm{gain}(A) = \mathrm{info}(D) - \mathrm{info}_A(D)$$

Whenever a split is needed, ID3 computes the information gain of every attribute and splits on the attribute with the largest gain.

For a continuous-valued attribute, ID3 can be applied as follows: sort the tuples of D by that attribute; the midpoint between each pair of adjacent values is a candidate split point. Starting from the first candidate, split D and compute the expected information of the two resulting subsets; the candidate with the smallest expected information is that attribute's best split point, and its expected information is used as the attribute's expected information.
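The midpoint scan described above can be sketched as follows. This is a Python cross-check rather than part of the MATLAB submission; `best_split` and the toy data are illustrative names, not from the lab code:

```python
from collections import Counter
from math import log2

def entropy(labels):
    # H(D) = -sum p_i log2 p_i over the classes present in `labels`
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())

def best_split(values, labels):
    # Scan midpoints between adjacent sorted values; return the split
    # with the smallest expected information (weighted child entropy).
    pairs = sorted(zip(values, labels))
    best_info, best_t = float("inf"), None
    for i in range(len(pairs) - 1):
        if pairs[i][0] == pairs[i + 1][0]:
            continue  # no midpoint between equal values
        t = (pairs[i][0] + pairs[i + 1][0]) / 2
        left = [l for v, l in pairs if v <= t]
        right = [l for v, l in pairs if v > t]
        info = (len(left) * entropy(left) + len(right) * entropy(right)) / len(pairs)
        if info < best_info:
            best_info, best_t = info, t
    return best_t, best_info

t, info = best_split([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1])
# the perfectly separating midpoint 2.5 has expected information 0
```

For this toy data the scan considers midpoints 1.5, 2.5 and 3.5 and keeps 2.5, which separates the two classes exactly.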

3. Task

Train a decision tree on the data in traindata.txt (75×5; the 5th column is the label); then use the constructed tree to classify the data in testdata.txt and report the classification accuracy.

4. Procedure

(1) Implementation in MATLAB, with brief notes

1. Data loading and preprocessing: load the training and test data, and split each into a feature matrix (train_X / test_X) and a label vector (train_y / test_y).

function main
    % Load the training and test data
    train_file_path = 'C:\Users\Lenovo\Desktop\实验课材料\实验二数据\traindata.txt';
    test_file_path = 'C:\Users\Lenovo\Desktop\实验课材料\实验二数据\testdata.txt';

    train_data = readmatrix(train_file_path);
    test_data = readmatrix(test_file_path);

    % Split into feature matrices and label vectors
    train_X = train_data(:, 1:end-1);
    train_y = train_data(:, end);

    test_X = test_data(:, 1:end-1);
    test_y = test_data(:, end);

2. Training and post-pruning: train the decision tree with ID3, then post-prune it to reduce overfitting and improve classification accuracy.

    % Train the decision tree
    tree = id3(train_X, train_y);

    % Post-pruning
    tree = prune_tree(tree, train_X, train_y);

    % Print the tree
    disp('Constructed decision tree:');
    disp(tree);

3. Prediction and accuracy: classify the test data with the trained tree, then compute and print the classification accuracy.

    % Predict on the test set and compute accuracy
    predictions = predict(tree, test_X);
    accuracy = calculate_accuracy(test_y, predictions);

    fprintf('Classification accuracy: %.2f%%\n', accuracy * 100);

4. Tree visualisation: draw the decision tree, showing the feature and split value at each branch.

    % Visualise the decision tree

    visualize_tree(tree);

end

 

function tree = id3(X, y, min_gain)

    if nargin < 3

        min_gain = 0.01; % default information-gain threshold

    end

    tree = build_tree(X, y, min_gain);

end

 

function tree = build_tree(X, y, min_gain)

    if length(unique(y)) == 1

        tree = struct('type', 'leaf', 'class', y(1));

        return;

    end

   

    if isempty(X)

        tree = struct('type', 'leaf', 'class', mode(y));

        return;

    end

   

    [best_feature, best_threshold, best_gain] = choose_best_feature(X, y);

   

    if best_gain < min_gain

        tree = struct('type', 'leaf', 'class', mode(y));

        return;

    end

   

    tree = struct('type', 'node', 'feature', best_feature, 'threshold', best_threshold, 'branches', []);
5. Recursive splitting: partition the samples on the chosen threshold and build the left and right subtrees recursively.

    left_indices = X(:, best_feature) <= best_threshold;

    right_indices = X(:, best_feature) > best_threshold;

   

    left_subtree = build_tree(X(left_indices, :), y(left_indices), min_gain);

    right_subtree = build_tree(X(right_indices, :), y(right_indices), min_gain);

   

    tree.branches = [struct('value', '<=', 'subtree', left_subtree); struct('value', '>', 'subtree', right_subtree)];

end

 

function [best_feature, best_threshold, best_gain] = choose_best_feature(X, y)

    num_features = size(X, 2);

    base_entropy = entropy(y);

    best_gain = 0;

    best_feature = -1;

    best_threshold = NaN;

   

    for i = 1:num_features

        thresholds = unique(X(:, i));

       

        for j = 1:length(thresholds)

            threshold = thresholds(j);

            left_indices = X(:, i) <= threshold;

            right_indices = X(:, i) > threshold;

           

            left_entropy = entropy(y(left_indices));

            right_entropy = entropy(y(right_indices));

           

            new_entropy = (sum(left_indices) / length(y)) * left_entropy + (sum(right_indices) / length(y)) * right_entropy;

           

            info_gain = base_entropy - new_entropy;

           

            if info_gain > best_gain

                best_gain = info_gain;

                best_feature = i;

                best_threshold = threshold;

            end

        end

    end

end

 

function e = entropy(y)
    % Entropy of the label vector y: -sum p_i * log2(p_i)

    classes = unique(y);

    e = 0;

   

    for i = 1:length(classes)

        p = sum(y == classes(i)) / length(y);

        e = e - p * log2(p);

    end

end
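As a sanity check of the entropy routine above, the same formula in Python (an illustrative cross-check, not part of the MATLAB code):

```python
from collections import Counter
from math import log2

def entropy(labels):
    # Same formula as the MATLAB entropy(): -sum p_i log2 p_i
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())

balanced = entropy([1, 1, 2, 2])  # two balanced classes -> 1 bit
pure = entropy([1, 1, 1, 1])      # pure node -> 0 bits
```

A node split evenly between two classes carries one bit of uncertainty; a pure node carries none, which is exactly why the pure-node stopping criterion below makes it a leaf.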

 

function predictions = predict(tree, X)

    predictions = zeros(size(X, 1), 1);

   

    for i = 1:size(X, 1)

        predictions(i) = predict_instance(tree, X(i, :));

    end

end

 

function class = predict_instance(tree, instance)

    if strcmp(tree.type, 'leaf')

        class = tree.class;

    else

        feature_value = instance(tree.feature);

        if feature_value <= tree.threshold

            branch = tree.branches(1).subtree;

        else

            branch = tree.branches(2).subtree;

        end

        class = predict_instance(branch, instance);

    end

end

 

function accuracy = calculate_accuracy(y_true, y_pred)

    accuracy = sum(y_true == y_pred) / length(y_true);

end

 

function visualize_tree(tree)

    figure;

    hold on;

    axis off;

    draw_tree(tree, 0.5, 1, 0.1, 0);

    hold off;

end

 

function [x, y] = draw_tree(tree, x, y, x_offset, depth)

    if strcmp(tree.type, 'leaf')

        rectangle('Position', [x-0.025, y-0.02, 0.05, 0.04], 'Curvature', [0.2, 0.2], 'EdgeColor', 'k', 'LineWidth', 1.5, 'FaceColor', [0.9, 0.9, 0.9]);

        text(x, y, num2str(tree.class), 'HorizontalAlignment', 'center', 'FontSize', 12, 'FontWeight', 'bold');

    else

        rectangle('Position', [x-0.025, y-0.02, 0.05, 0.04], 'Curvature', [0.2, 0.2], 'EdgeColor', 'k', 'LineWidth', 1.5, 'FaceColor', [0.8, 0.8, 1]);

        text(x, y, ['X', num2str(tree.feature), ' <= ', num2str(tree.threshold)], 'HorizontalAlignment', 'center', 'FontSize', 12, 'FontWeight', 'bold');

        num_branches = length(tree.branches);

       

        for i = 1:num_branches

            new_x = x + (i - (num_branches + 1) / 2) * x_offset;

            new_y = y - 0.15;

            [child_x, child_y] = draw_tree(tree.branches(i).subtree, new_x, new_y, x_offset / 2, depth + 1);

            plot([x, child_x], [y - 0.02, child_y + 0.02], 'k-', 'LineWidth', 1.5);

            text((x + child_x) / 2, (y + child_y) / 2, tree.branches(i).value, 'HorizontalAlignment', 'center', 'FontSize', 10, 'FontWeight', 'bold', 'BackgroundColor', 'w');

        end

    end

end

 

function pruned_tree = prune_tree(tree, X, y)

    % Cross-validation to select the best pruned tree
    n = length(y);
    k = 5; % 5-fold cross-validation
    indices = crossvalind('Kfold', y, k); % crossvalind requires the Bioinformatics Toolbox

   

    best_tree = tree;

    best_accuracy = 0;

   

    for i = 1:k

        test_idx = (indices == i);

        train_idx = ~test_idx;

       

        validation_tree = prune_subtree(tree, X(train_idx, :), y(train_idx));

        predictions = predict(validation_tree, X(test_idx, :));

        accuracy = calculate_accuracy(y(test_idx), predictions);

       

        if accuracy > best_accuracy

            best_accuracy = accuracy;

            best_tree = validation_tree;

        end

    end

   

    pruned_tree = best_tree;

end

 

function pruned_tree = prune_subtree(tree, X, y)

    if strcmp(tree.type, 'leaf')

        pruned_tree = tree;

    else

        for i = 1:length(tree.branches)

            value = tree.branches(i).value;

            if strcmp(value, '<=')

                subset_indices = X(:, tree.feature) <= tree.threshold;

            else

                subset_indices = X(:, tree.feature) > tree.threshold;

            end

            tree.branches(i).subtree = prune_subtree(tree.branches(i).subtree, X(subset_indices, :), y(subset_indices));

        end

       

        % Compare accuracy before and after pruning

        original_predictions = predict(tree, X);

        original_accuracy = calculate_accuracy(y, original_predictions);

       

        % Candidate: collapse this subtree into a majority-class leaf

        class = mode(y);

        leaf_tree = struct('type', 'leaf', 'class', class);

       

        leaf_predictions = predict(leaf_tree, X);

        leaf_accuracy = calculate_accuracy(y, leaf_predictions);

       

        if leaf_accuracy >= original_accuracy

            pruned_tree = leaf_tree;

        else

            pruned_tree = tree;

        end

    end

end
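The pruning test in prune_subtree (replace a subtree by a majority-class leaf whenever the leaf is at least as accurate on the node's data) can be illustrated in Python with hypothetical label/prediction vectors:

```python
from collections import Counter

def should_prune(y_true, subtree_pred):
    # Prune when a single majority-class leaf is at least as accurate
    # as the subtree on this node's data (ties favour the simpler tree).
    majority = Counter(y_true).most_common(1)[0][0]
    subtree_acc = sum(t == p for t, p in zip(y_true, subtree_pred)) / len(y_true)
    leaf_acc = sum(t == majority for t in y_true) / len(y_true)
    return leaf_acc >= subtree_acc

# The subtree's extra split only chases one noisy point: both score 7/8,
# so the subtree is collapsed.
noisy = should_prune([0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 1, 1])
# A genuinely useful split is kept: subtree 4/4 beats the 2/4 leaf.
useful = should_prune([0, 0, 1, 1], [0, 0, 1, 1])
```

Using `>=` rather than `>` means ties are broken toward the smaller tree, matching the `leaf_accuracy >= original_accuracy` test in the MATLAB code.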

 

(2) Branching features and attributes

1. Feature selection: branches are chosen by information gain; the feature (and threshold) with the highest gain is used to split the node.

[best_feature, best_threshold, best_gain] = choose_best_feature(X, y);

2. Attribute handling: the samples are partitioned on the chosen threshold, and a subtree is built recursively for each side.

left_indices = X(:, best_feature) <= best_threshold;

right_indices = X(:, best_feature) > best_threshold;

 

left_subtree = build_tree(X(left_indices, :), y(left_indices), min_gain);

right_subtree = build_tree(X(right_indices, :), y(right_indices), min_gain);

 

(3) Stopping criteria

1. Pure node:

if length(unique(y)) == 1

    tree = struct('type', 'leaf', 'class', y(1));

    return;

end

When all samples at a node belong to the same class, splitting stops and the node becomes a leaf.

2. Empty node:

if isempty(X)

    tree = struct('type', 'leaf', 'class', mode(y));

    return;

end

When no data remains to split on, splitting stops and the node becomes a leaf labelled with the majority class.

3. Information gain below threshold:

if best_gain < min_gain

    tree = struct('type', 'leaf', 'class', mode(y));

    return;

end

When the best information gain falls below the threshold min_gain, splitting stops and the node becomes a leaf labelled with the majority class.

 

5. Results

Running the code prints the constructed decision tree and the classification accuracy. The result screenshot shows an accuracy of 96%, and the tree visualisation appears in a MATLAB figure window: each node represents a feature test or a class, and each edge represents one outcome of the test.

 

Constructed decision tree: the printed structure shows, for each branch, the feature tested and its split value.

Decision-tree figure: the root node is the first splitting feature and the leaf nodes are the predicted classes; the visualisation makes the splitting process and final classification easy to follow.