决策树之ID3算法的实现(Java版)

2,289 阅读5分钟

ID3算法是什么?

ID3算法是一种贪心算法,用来构造决策树。ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样例。

ID3算法的伪代码

if 样本S全部属于同一个类别C then
   创建一个叶结点,并标记类标号为C;
   returnelse
   计算属性集F中每一个属性的信息增益,假定增益值最大的属性为A;
   创建结点,取属性A为该结点的决策属性;
   for 结点属性A的每个可能的取值V  do
      为该结点添加一个新的分支,假设SV为属性A取值为V的样本子集;
      if 样本SV全部属于同一个类别C then
          为该分支添加一个叶结点,并标记类标号为C;
     else
          递归调用DT(SV, F-{A}),为该分支创建子树;
     end if
   end for
end if

ID3算法的Java实现

public class Id3Util {

    /**
     * 存储属性名称
     */
    private ArrayList<String> attribute = new ArrayList<>();
    /**
     * 存储每个属性的取值
     */
    private ArrayList<ArrayList<Integer>> attributeValue = new ArrayList<>();
    /**
     * 原始数据
     */
    private ArrayList<Integer[]> data = new ArrayList<>();
    /**
     * 决策变量在属性集中的索引
     */
    int decatt = 0;
    
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";

    private Document document;
    private static Element root;

    public Id3Util(){
        document = DocumentHelper.createDocument();
        root = document.addElement("root");
        root.addElement("DecisionTree").addAttribute("value", "null");
    }

    /**
     * 初始化决策树
     */
    public void init(){
        // 读取arff文件
        readArff(new File("C:\\1.arff"));
        // 设置决策变量
        setDecatt("purchase");
        ArrayList<Integer> attributeIndexList =new ArrayList<>(attribute.size());
        for(int i = 0; i < attribute.size(); i++){
            if(i != decatt) {
                attributeIndexList.add(i);
            }
        }
        ArrayList<Integer> dataIndexList = new ArrayList<>(data.size());
        for(int i = 0; i < data.size(); i++){
            dataIndexList.add(i);
        }
        // 建立决策树
        this.buildDecisionTree("DecisionTree", null, dataIndexList, attributeIndexList, root);
        // 将document对象
        this.writeXML("C:\\decision_tree.xml");
    }

    /**
     * 读取arff文件
     * @param file arff文件
     */
    private void readArff(File file){
        try {
            FileReader fileReader = new FileReader(file);
            BufferedReader bufferedReader = new BufferedReader(fileReader);
            String line = null;
            Pattern pattern = Pattern.compile(patternString);
            while((line = bufferedReader.readLine()) != null){
                Matcher matcher = pattern.matcher(line);
                if(matcher.find()){
                    attribute.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<Integer> array = new ArrayList<>(values.length);
                    for(String value: values){
                        array.add(Integer.parseInt(value.trim()));
                    }
                    attributeValue.add(array);
                }else if(line.startsWith("@data")){
                    while((line = bufferedReader.readLine()) != null){
                        if(line == ""){
                            continue;
                        }
                        String[] array = line.split(",");
                        Integer[] row = new Integer[array.length];
                        for(int i = 0; i < array.length; i++){
                             row[i] = Integer.parseInt(array[i]);
                        }
                        data.add(row);
                    }
                }else {
                    continue;
                }
            }
            bufferedReader.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 设置决策变量
     * @param n 决策变量索引
     */
    private void setDecatt(int n) {
        if (n < 0 || n >= attribute.size()) {
            System.out.println("决策变量设置失败");
            System.exit(2);
        }
        decatt = n;
    }

    /**
     * 设置决策变量
     * @param name 决策变量名称
     */
    private void setDecatt(String name) {
        int n = attribute.indexOf(name);
        setDecatt(n);
    }

    /**
     * 计算给定数组的信息熵
     * @param array 数组
     * @return
     */
    private double getEntropy(int[] array){
        int sum = 0;
        for(int i = 0; i < array.length; i++){
            sum += array[i];
        }
        return getEntropy(array, sum);
    }

    /**
     * 计算给定数组的信息熵
     * @param array 数组
     * @param sum 给定数组的算术和
     * @return
     */
    private double getEntropy(int[] array, int sum) {
        double entropy = 0.0;
        for (int i = 0; i < array.length; i++) {
            if(array[i] == 0){
                continue;
            }
            entropy -= ((double) array[i] / sum) * Utils.log2((double) array[i] / sum);
        }
        return entropy;
    }

    /**
     * 检查给定数据集是否同属于一个类别
     * @param subset 数据集
     * @return 检查结果
     */
    private boolean isClassesUnanimous(ArrayList<Integer> subset){
        int count = 1;
        Integer value = data.get(subset.get(0))[decatt];
        for(int i = 1; i < subset.size(); i++){
            Integer next = data.get(subset.get(i))[decatt];
            if(value.equals(next)){
                ++count;
            }
        }
        // 计算下数据集中同属一个类别的概率,如果此概率大于我们设置的一个值
        // 可以认为此数据集同属于一个类别
        double ratio = (double) count / subset.size();
        return ratio >= 0.95 ? true : false;
    }

    /**
     * 计算原始数据的子集以第index个属性为节点时计算它的信息熵
     * @param subset 数据集子集索引集合
     * @param index 属性索引
     * @return
     */
    private double calNodeEntropy(ArrayList<Integer> subset, int index){
        // 数据集剩余个数
        int sum = subset.size();
        double entropy = 0.0;
        int[][] info = new int[attributeValue.get(index).size()][];
        for(int i = 0; i < info.length; i++){
            info[i] = new int[attributeValue.get(decatt).size()];
        }
        int[] count = new int[attributeValue.get(index).size()];
        for(int i = 0; i < sum; i++){
            int n = subset.get(i);
            Integer nodeValue = data.get(n)[index];
            int nodeIndex = attributeValue.get(index).indexOf(nodeValue);
            count[nodeIndex]++;
            Integer decattValue = data.get(n)[decatt];
            int decattIndex = attributeValue.get(decatt).indexOf(decattValue);
            info[nodeIndex][decattIndex]++;
        }
        for(int i = 0; i < info.length; i++){
            entropy += getEntropy(info[i]) * count[i] / sum;
        }
        return entropy;
    }

    /**
     * 建立决策树
     * @param name 节点名
     * @param value 值
     * @param subset 数据集的索引集合
     * @param selatt 属性集的索引集合
     */
    private void buildDecisionTree(String name, String value, ArrayList<Integer> subset, ArrayList<Integer> selatt, Element parent){
        Element element = null;
        List<Element> list = parent.selectNodes(name);
        Iterator<Element> iterator = list.iterator();
        // 确定element的位置
        while(iterator.hasNext()){
            element = iterator.next();
            if(element.attributeValue("value").equals(value)){
                break;
            }
        }
        // 如果数据集同属于一个类别,则创建叶子结点
        if(isClassesUnanimous(subset)) {
            element.setText(String.valueOf(data.get(subset.get(0))[decatt]));
            return;
        }
        int minIndex = -1;
        int minEntropySelatt = -1;
        double minEntropy = Double.MAX_VALUE;
        // 获取最低属性熵的索引并赋值给minIndex
        for(int i = 1; i < selatt.size(); i++){
            if(i == decatt){
                continue;
            }
            // 计算每个属性的信息熵
            double entropy = calNodeEntropy(subset, selatt.get(i));
            if(entropy < minEntropy){
                minEntropySelatt = selatt.get(i);
                minIndex = i;
                minEntropy = entropy;
            }
        }
        String nodeName = attribute.get(minEntropySelatt);
        // 从属性表中去除此属性
        ArrayList<Integer> remainSelatt = removeAttribute(selatt, minIndex);
        // 获取该属性的取值范围
        ArrayList<Integer> attributeValues = attributeValue.get(minEntropySelatt);
        for(Integer attValue : attributeValues){
            Element newElement = element.addElement(nodeName).addAttribute("value", String.valueOf(attValue));
            ArrayList<Integer> remainSubset = new ArrayList<>();
            for (int i = 0; i < subset.size(); i++){
                if(data.get(subset.get(i))[minEntropySelatt].equals(attValue)){
                    remainSubset.add(subset.get(i));
                }
            }
            // 样本子集为空,删除该结点
            if(remainSubset.size() == 0){
                element.remove(newElement);
                continue;
            }
            // 样本子集全部属于同一个类别,则创建叶子结点并标号
            if(isClassesUnanimous(remainSubset)){
                newElement.setText(String.valueOf(data.get(remainSubset.get(0))[decatt]));
                continue;
            }
            // 递归调用建立树
            buildDecisionTree(nodeName, String.valueOf(attValue), remainSubset, remainSelatt, element);
        }
    }

    /**
     * 删除属性表中指定索引的属性
     * @param selatt 属性表
     * @param index 指定索引
     * @return
     */
    private ArrayList<Integer> removeAttribute(ArrayList<Integer> selatt, int index) {
        ArrayList<Integer> arrayList = new ArrayList<>(selatt);
        arrayList.remove(index);
        return arrayList;
    }

    /**
     * 将xml写入文件
     * @param fileName 文件名
     */
    private void writeXML(String fileName){
            try {
                File file = new File(fileName);
                if(!file.exists()) {
                    file.createNewFile();
                }
                FileWriter fileWriter = new FileWriter(file);
                OutputFormat outputFormat = OutputFormat.createPrettyPrint();
                XMLWriter writer = new XMLWriter(fileWriter, outputFormat);
                writer.write(document);
                writer.close();
            } catch (IOException e) {
                System.out.println("写入xml文件失败");
            }
    }
}

参考:归纳决策树ID3(Java实现)