Python修改文本数据出现编码问题

709 阅读4分钟

作者:老九—技术大黍

社交:知乎

公众号:老九学堂(新人有惊喜)

特别声明:原创不易,未经授权不得转载或抄袭,如需转载可联系笔者授权

前言

实现决策树的时候遇到了一个问题:使用python修改的文本数据后,Java运行报错

先上代码——实现决策树

1、树节点


public class treeNode{
    private String sname;//节点名
    public treeNode(String str) {
        sname=str;
    }
    public String getsname() {
        return sname;
    }
    ArrayList<String> label=new ArrayList<String>();//和子节点间的边标签
    ArrayList<treeNode> node=new ArrayList<treeNode>();//对应子节点
}

2、 实现决策树


public class ID3 {
    private ArrayList<String> label = new ArrayList<String>();//特征标签
    private ArrayList<ArrayList<String>> date = new ArrayList<ArrayList<String>>();//数据集
    private ArrayList<ArrayList<String>> test = new ArrayList<ArrayList<String>>();//测试数据集
    private ArrayList<String> sum = new ArrayList<String>();//分类种类数
    private String kind;

    public static ArrayList<Output> outputs = new ArrayList<>();//用来存储输出的数组

    public ID3(String path, String path0) throws FileNotFoundException {     
        getDate(path); //初始化训练数据并得到分类种数
        gettestDate(path0);//获取测试数据集
        init(date);
    }

    public void init(ArrayList<ArrayList<String>> date) {    
        sum.add(date.get(0).get(date.get(0).size() - 1));//得到种类数
        for (int i = 0; i < date.size(); i++) {
            if (sum.contains(date.get(i).get(date.get(0).size() - 1)) == false) {
                sum.add(date.get(i).get(date.get(0).size() - 1));
            }
        }
    }

    /* 获取测试数据集 */
    public void gettestDate(String path) throws FileNotFoundException {
        String str;
        int i = 0;
        try {
            //BufferedReader in=new BufferedReader(new FileReader(path));
            FileInputStream fis = new FileInputStream(path);
            InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
            BufferedReader in = new BufferedReader(isr);
            while ((str = in.readLine()) != null) {
                String[] strs = str.split(",");
                ArrayList<String> line = new ArrayList<String>();
                boolean isFinished=true;//判断是否有add进output数组过
                for (int j = 0; j < strs.length; j++) {
                    line.add(strs[j]);
                    if(!isFinished){
                        Output output=new Output(strs[j/(label.size()-1)],"null");
                        outputs.add(output);
                        isFinished=true;
                    }
                    if(j%(label.size()-1)==0){
                        isFinished=false;
                    }
                }
                test.add(line);
                i++;
            }

            in.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    //获取训练数据集
    public void getDate(String path) throws FileNotFoundException {
        String str;
        int i = 0;
        try {
            FileInputStream fis = new FileInputStream(path);
            InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
            BufferedReader in = new BufferedReader(isr);
            while ((str = in.readLine()) != null) {
                if (i == 0) {
                    String[] strs = str.split(",");
                    for (int j = 0; j < strs.length; j++) {
                        label.add(strs[j]);
                    }
                    i++;
                    continue;
                }
                String[] strs = str.split(",");
                ArrayList<String> line = new ArrayList<String>();
                for (int j = 0; j < strs.length; j++) {
                    line.add(strs[j]);
                }
                date.add(line);
                i++;
            }
            in.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public double Ent(ArrayList<ArrayList<String>> dat) {
        //计算总的信息熵
        int all = 0;
        double amount = 0.0;
        for (int i = 0; i < sum.size(); i++) {
            for (int j = 0; j < dat.size(); j++) {
                if (sum.get(i).equals(dat.get(j).get(dat.get(0).size() - 1))) {
                    all++;
                }
            }
            if ((double) all / dat.size() == 0.0) {
                continue;
            }
            amount += ((double) all / dat.size()) * (Math.log(((double) all / dat.size())) / Math.log(2.0));
            all = 0;
        }
        if (amount == 0.0) {
            return 0.0;
        }
        return -amount;//计算信息熵
    }

    /* 计算条件熵并返回信息增益值 */
    public double condtion(int a, ArrayList<ArrayList<String>> dat) {
        ArrayList<String> all = new ArrayList<String>();
        double c = 0.0;
        all.add(dat.get(0).get(a));
        //得到属性种类
        for (int i = 0; i < dat.size(); i++) {
            if (all.contains(dat.get(i).get(a)) == false) {
                all.add(dat.get(i).get(a));
            }
        }
        ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
        //部分分组
        ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
        //分组总和
        for (int i = 0; i < all.size(); i++) {
            for (int j = 0; j < dat.size(); j++) {
                if (true == all.get(i).equals(dat.get(j).get(a))) {
                    plus.add(dat.get(j));
                }
            }
            count.add(plus);
            c += ((double) count.get(i).size() / dat.size()) * Ent(count.get(i));
            plus.removeAll(plus);
        }
        return (Ent(dat) - c);
        //返回条件熵
    }

    /* 计算信息增益最大属性 */
    public int Gain(ArrayList<ArrayList<String>> dat) {
        ArrayList<Double> num = new ArrayList<Double>();
        //保存各信息增益值
        for (int i = 0; i < dat.get(0).size() - 1; i++) {
            num.add(condtion(i, dat));
        }
        int index = 0;
        double max = num.get(0);
        for (int i = 1; i < num.size(); i++) {
            if (max < num.get(i)) {
                max = num.get(i);
                index = i;
            }
        }
        return index;
    }

    //构建决策树
    public treeNode creattree(ArrayList<ArrayList<String>> dat) {
        int index = Gain(dat);
        treeNode node = new treeNode(label.get(index));
        ArrayList<String> s = new ArrayList<String>();//属性种类
        s.add(dat.get(0).get(index));
        for (int i = 1; i < dat.size(); i++) {
            if (s.contains(dat.get(i).get(index)) == false) {
                s.add(dat.get(i).get(index));
            }
        }
        ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
        //部分分组
        ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
        //分组总和
        //得到节点下的边标签并分组
        for (int i = 0; i < s.size(); i++) {
            node.label.add(s.get(i));//添加边标签
            for (int j = 0; j < dat.size(); j++) {
                if (true == s.get(i).equals(dat.get(j).get(index))) {
                    plus.add(dat.get(j));
                }
            }
            count.add(plus);

            //以下添加结点
            int k;
            String str = count.get(i).get(0).get(count.get(i).get(0).size() - 1);
            for (k = 1; k < count.get(i).size(); k++) {
                if (false == str.equals(count.get(i).get(k).get(count.get(i).get(k).size() - 1))) {
                    break;
                }
            }
            if (k == count.get(i).size()) {
                treeNode dd = new treeNode(str);
                node.node.add(dd);
            } else {
                node.node.add(creattree(count.get(i)));
            }
            plus.removeAll(plus);
        }
        return node;
    }

    //输出决策树
    public void print(ArrayList<ArrayList<String>> dat) {
        System.out.println("构建的决策树如下:");
        treeNode node = null;
        node = creattree(dat);//类
        put(node);//递归调用
    }

    //用于递归的函数
    public void put(treeNode node) {
        System.out.println("结点:" + node.getsname() + "\n");
        for (int i = 0; i < node.label.size(); i++) {
            System.out.println(node.getsname() + "的标签属性:" + node.label.get(i));
            if (node.node.get(i).node.isEmpty() == true) {
                System.out.println("叶子结点:" + node.node.get(i).getsname());
            } else {
                put(node.node.get(i));
            }
        }
    }

    /* 用于对待决策数据进行预测并将结果保存在指定路径 */
    public void testdate(ArrayList<ArrayList<String>> test, String path) throws IOException {
        treeNode node = null;
        int count = 0;
        node = creattree(this.date);//类
        try {
            BufferedWriter out = new BufferedWriter(new FileWriter(path));
            for (int i = 0; i < test.size(); i++) {
                testput(node, test.get(i));//递归调用
                for (int j = 0; j < test.get(i).size(); j++) {
                    out.write(test.get(i).get(j) + ",");
                }
                if (kind.equals(date.get(i).get(date.get(i).size() - 1)) == true) {
                    count++;
                }
                out.write(kind);
                outputs.get(i).kind=kind;
                out.newLine();
            }
            out.flush();
            out.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    //用于测试的递归调用
    public void testput(treeNode node, ArrayList<String> t) {
        int index = 0;
        for (int i = 0; i < this.label.size(); i++) {
            if (this.label.get(i).equals(node.getsname()) == true) {
                index = i;
                break;
            }
        }
        for (int i = 0; i < node.label.size(); i++) {
            if (t.get(index).equals(node.label.get(i)) == false) {
                continue;
            }	
            if (node.node.get(i).node.isEmpty() == true) {
                this.kind = node.node.get(i).getsname();//取出分类结果
            } else {
                testput(node.node.get(i), t);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        String data = "src\\com\\xuetang9\\data.txt";//训练数据集
        String test = "src\\com\\xuetang9\\test.txt";//测试数据集
        String result = "src\\com\\xuetang9\\result.txt";//预测结果集
        ID3 id = new ID3(data, test);//初始化数据
        id.print(id.date);//构建并输出决策树
        id.testdate(id.test,result);//预测数据并输出结果
        System.out.println("请输入你想查询的地点");
        Scanner scanner= new Scanner(System.in);
        String input=scanner.next();
        for (Output output: outputs) {
            if(output.countryName.equals(input)){
                System.out.println("查询结果为:"+output.kind);
            }
        }
    }
}

3、用来存储输出结果的工具类

public class Output{
    public String countryName;
    public String kind;
    public Output(String countryName,String kind){
        this.countryName=countryName;
        this.kind=kind;
    }

    @Override
    public String toString() {
        return "Output{" +
                "countryName='" + countryName + '\'' +
                ", kind='" + kind + '\'' +
                '}';
    }
}

4、data文本文件

地点,阳光,人数,交通,风力,空气,雨水,适合旅游
英国,明媚,拥挤,拥挤,微风,稍浑,少量,适合旅游
美国,微暗,拥挤,拥挤,微风,稍浑,少量,适合旅游
日本,灰暗,拥挤,稍拥,微风,清新,少量,适合旅游
韩国,明媚,拥挤,不拥,微风,清新,少量,适合旅游
英国,微暗,拥挤,稍拥,微风,清新,少量,适合旅游
美国,明媚,拥挤,稍拥,微风,稍浑,中量,适合旅游
日本,灰暗,稍拥,稍拥,中风,稍浑,少量,适合旅游
韩国,灰暗,稍拥,稍拥,微风,稍浑,少量,不适旅游
英国,灰暗,稍拥,不拥,中风,稍浑,少量,不适旅游
美国,明媚,不拥,拥挤,微风,浑浊,中量,不适旅游
日本,微暗,不拥,拥挤,中风,浑浊,大量,不适旅游
韩国,微暗,拥挤,稍拥,中风,浑浊,大量,不适旅游
英国,明媚,稍拥,稍拥,中风,清新,少量,不适旅游
美国,微暗,稍拥,不拥,中风,清新,少量,不适旅游
日本,灰暗,稍拥,稍拥,微风,稍浑,大量,不适旅游
韩国,微暗,拥挤,稍拥,中风,浑浊,少量,不适旅游
英国,明媚,拥挤,不拥,中风,稍浑,少量,不适旅游
美国,明媚,不拥,拥挤,微风,稍浑,少量,适合旅游

5、test文本文件

英国,灰暗,稍拥,稍拥,微风,稍浑,大量
美国,微暗,拥挤,拥挤,微风,稍浑,少量
日本,明媚,不拥,拥挤,微风,浑浊,中量
韩国,明媚,拥挤,不拥,微风,清新,少量

在Java中完美运行:

image-20210414112724819.png

但是问题来了

使用python更改test数据为

英国,明媚,拥挤,拥挤,微风,稍浑,少量

Java中运行结果为

image-20210414112849878.png

调试发现

//循环中kind成员变量为null
if (kind.equals(date.get(i).get(date.get(i).size() - 1)) == true) {
	count++;
}

追溯kind初始化位置

for (int i = 0; i < node.label.size(); i++) {
    if (t.get(index).equals(node.label.get(i)) == false) {
        continue;
    }	
    if (node.node.get(i).node.isEmpty() == true) {
        this.kind = node.node.get(i).getsname();//取出分类结果
    } else {
        testput(node.node.get(i), t);
    }
}

发现修改后的数据没有在决策树中找到相同数据

image-20210414113704354.png

打印日志

image-20210414113704354.png

System.out.println("T:"+t.get(index)+",L:"+node.label.get(i));

调试发现读入的数据出现乱码导致该问题

image-20210414113800525.png

检查文本编码格式,使用python修改前test是utf-8,使用后变为了GBK(大坑啊~~~)。然后修改一下文本的编码格式,问题解决。

总结

文本编程格式常常是我们编程中最经常见的问题。

最后

记得给大黍❤️关注+点赞+收藏+评论+转发❤️

作者:老九学堂—技术大黍

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。