作者:老九—技术大黍
社交:知乎
公众号:老九学堂(新人有惊喜)
特别声明:原创不易,未经授权不得转载或抄袭,如需转载可联系笔者授权
前言
实现决策树的时候遇到了一个问题:使用python修改的文本数据后,Java运行报错
先上代码——实现决策树
1、树节点
public class treeNode{
private String sname;//节点名
public treeNode(String str) {
sname=str;
}
public String getsname() {
return sname;
}
ArrayList<String> label=new ArrayList<String>();//和子节点间的边标签
ArrayList<treeNode> node=new ArrayList<treeNode>();//对应子节点
}
2、 实现决策树
public class ID3 {
private ArrayList<String> label = new ArrayList<String>();//特征标签
private ArrayList<ArrayList<String>> date = new ArrayList<ArrayList<String>>();//数据集
private ArrayList<ArrayList<String>> test = new ArrayList<ArrayList<String>>();//测试数据集
private ArrayList<String> sum = new ArrayList<String>();//分类种类数
private String kind;
public static ArrayList<Output> outputs = new ArrayList<>();//用来存储输出的数组
public ID3(String path, String path0) throws FileNotFoundException {
getDate(path); //初始化训练数据并得到分类种数
gettestDate(path0);//获取测试数据集
init(date);
}
public void init(ArrayList<ArrayList<String>> date) {
sum.add(date.get(0).get(date.get(0).size() - 1));//得到种类数
for (int i = 0; i < date.size(); i++) {
if (sum.contains(date.get(i).get(date.get(0).size() - 1)) == false) {
sum.add(date.get(i).get(date.get(0).size() - 1));
}
}
}
/* 获取测试数据集 */
public void gettestDate(String path) throws FileNotFoundException {
String str;
int i = 0;
try {
//BufferedReader in=new BufferedReader(new FileReader(path));
FileInputStream fis = new FileInputStream(path);
InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
BufferedReader in = new BufferedReader(isr);
while ((str = in.readLine()) != null) {
String[] strs = str.split(",");
ArrayList<String> line = new ArrayList<String>();
boolean isFinished=true;//判断是否有add进output数组过
for (int j = 0; j < strs.length; j++) {
line.add(strs[j]);
if(!isFinished){
Output output=new Output(strs[j/(label.size()-1)],"null");
outputs.add(output);
isFinished=true;
}
if(j%(label.size()-1)==0){
isFinished=false;
}
}
test.add(line);
i++;
}
in.close();
} catch (Exception e) {
e.printStackTrace();
}
}
//获取训练数据集
public void getDate(String path) throws FileNotFoundException {
String str;
int i = 0;
try {
FileInputStream fis = new FileInputStream(path);
InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
BufferedReader in = new BufferedReader(isr);
while ((str = in.readLine()) != null) {
if (i == 0) {
String[] strs = str.split(",");
for (int j = 0; j < strs.length; j++) {
label.add(strs[j]);
}
i++;
continue;
}
String[] strs = str.split(",");
ArrayList<String> line = new ArrayList<String>();
for (int j = 0; j < strs.length; j++) {
line.add(strs[j]);
}
date.add(line);
i++;
}
in.close();
} catch (Exception e) {
e.printStackTrace();
}
}
public double Ent(ArrayList<ArrayList<String>> dat) {
//计算总的信息熵
int all = 0;
double amount = 0.0;
for (int i = 0; i < sum.size(); i++) {
for (int j = 0; j < dat.size(); j++) {
if (sum.get(i).equals(dat.get(j).get(dat.get(0).size() - 1))) {
all++;
}
}
if ((double) all / dat.size() == 0.0) {
continue;
}
amount += ((double) all / dat.size()) * (Math.log(((double) all / dat.size())) / Math.log(2.0));
all = 0;
}
if (amount == 0.0) {
return 0.0;
}
return -amount;//计算信息熵
}
/* 计算条件熵并返回信息增益值 */
public double condtion(int a, ArrayList<ArrayList<String>> dat) {
ArrayList<String> all = new ArrayList<String>();
double c = 0.0;
all.add(dat.get(0).get(a));
//得到属性种类
for (int i = 0; i < dat.size(); i++) {
if (all.contains(dat.get(i).get(a)) == false) {
all.add(dat.get(i).get(a));
}
}
ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
//部分分组
ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
//分组总和
for (int i = 0; i < all.size(); i++) {
for (int j = 0; j < dat.size(); j++) {
if (true == all.get(i).equals(dat.get(j).get(a))) {
plus.add(dat.get(j));
}
}
count.add(plus);
c += ((double) count.get(i).size() / dat.size()) * Ent(count.get(i));
plus.removeAll(plus);
}
return (Ent(dat) - c);
//返回条件熵
}
/* 计算信息增益最大属性 */
public int Gain(ArrayList<ArrayList<String>> dat) {
ArrayList<Double> num = new ArrayList<Double>();
//保存各信息增益值
for (int i = 0; i < dat.get(0).size() - 1; i++) {
num.add(condtion(i, dat));
}
int index = 0;
double max = num.get(0);
for (int i = 1; i < num.size(); i++) {
if (max < num.get(i)) {
max = num.get(i);
index = i;
}
}
return index;
}
//构建决策树
public treeNode creattree(ArrayList<ArrayList<String>> dat) {
int index = Gain(dat);
treeNode node = new treeNode(label.get(index));
ArrayList<String> s = new ArrayList<String>();//属性种类
s.add(dat.get(0).get(index));
for (int i = 1; i < dat.size(); i++) {
if (s.contains(dat.get(i).get(index)) == false) {
s.add(dat.get(i).get(index));
}
}
ArrayList<ArrayList<String>> plus = new ArrayList<ArrayList<String>>();
//部分分组
ArrayList<ArrayList<ArrayList<String>>> count = new ArrayList<ArrayList<ArrayList<String>>>();
//分组总和
//得到节点下的边标签并分组
for (int i = 0; i < s.size(); i++) {
node.label.add(s.get(i));//添加边标签
for (int j = 0; j < dat.size(); j++) {
if (true == s.get(i).equals(dat.get(j).get(index))) {
plus.add(dat.get(j));
}
}
count.add(plus);
//以下添加结点
int k;
String str = count.get(i).get(0).get(count.get(i).get(0).size() - 1);
for (k = 1; k < count.get(i).size(); k++) {
if (false == str.equals(count.get(i).get(k).get(count.get(i).get(k).size() - 1))) {
break;
}
}
if (k == count.get(i).size()) {
treeNode dd = new treeNode(str);
node.node.add(dd);
} else {
node.node.add(creattree(count.get(i)));
}
plus.removeAll(plus);
}
return node;
}
//输出决策树
public void print(ArrayList<ArrayList<String>> dat) {
System.out.println("构建的决策树如下:");
treeNode node = null;
node = creattree(dat);//类
put(node);//递归调用
}
//用于递归的函数
public void put(treeNode node) {
System.out.println("结点:" + node.getsname() + "\n");
for (int i = 0; i < node.label.size(); i++) {
System.out.println(node.getsname() + "的标签属性:" + node.label.get(i));
if (node.node.get(i).node.isEmpty() == true) {
System.out.println("叶子结点:" + node.node.get(i).getsname());
} else {
put(node.node.get(i));
}
}
}
/* 用于对待决策数据进行预测并将结果保存在指定路径 */
public void testdate(ArrayList<ArrayList<String>> test, String path) throws IOException {
treeNode node = null;
int count = 0;
node = creattree(this.date);//类
try {
BufferedWriter out = new BufferedWriter(new FileWriter(path));
for (int i = 0; i < test.size(); i++) {
testput(node, test.get(i));//递归调用
for (int j = 0; j < test.get(i).size(); j++) {
out.write(test.get(i).get(j) + ",");
}
if (kind.equals(date.get(i).get(date.get(i).size() - 1)) == true) {
count++;
}
out.write(kind);
outputs.get(i).kind=kind;
out.newLine();
}
out.flush();
out.close();
} catch (IOException e) {
e.printStackTrace();
}
}
//用于测试的递归调用
public void testput(treeNode node, ArrayList<String> t) {
int index = 0;
for (int i = 0; i < this.label.size(); i++) {
if (this.label.get(i).equals(node.getsname()) == true) {
index = i;
break;
}
}
for (int i = 0; i < node.label.size(); i++) {
if (t.get(index).equals(node.label.get(i)) == false) {
continue;
}
if (node.node.get(i).node.isEmpty() == true) {
this.kind = node.node.get(i).getsname();//取出分类结果
} else {
testput(node.node.get(i), t);
}
}
}
public static void main(String[] args) throws IOException {
String data = "src\\com\\xuetang9\\data.txt";//训练数据集
String test = "src\\com\\xuetang9\\test.txt";//测试数据集
String result = "src\\com\\xuetang9\\result.txt";//预测结果集
ID3 id = new ID3(data, test);//初始化数据
id.print(id.date);//构建并输出决策树
id.testdate(id.test,result);//预测数据并输出结果
System.out.println("请输入你想查询的地点");
Scanner scanner= new Scanner(System.in);
String input=scanner.next();
for (Output output: outputs) {
if(output.countryName.equals(input)){
System.out.println("查询结果为:"+output.kind);
}
}
}
}
3、用来存储输出结果的工具类
public class Output{
public String countryName;
public String kind;
public Output(String countryName,String kind){
this.countryName=countryName;
this.kind=kind;
}
@Override
public String toString() {
return "Output{" +
"countryName='" + countryName + '\'' +
", kind='" + kind + '\'' +
'}';
}
}
4、data文本文件
地点,阳光,人数,交通,风力,空气,雨水,适合旅游
英国,明媚,拥挤,拥挤,微风,稍浑,少量,适合旅游
美国,微暗,拥挤,拥挤,微风,稍浑,少量,适合旅游
日本,灰暗,拥挤,稍拥,微风,清新,少量,适合旅游
韩国,明媚,拥挤,不拥,微风,清新,少量,适合旅游
英国,微暗,拥挤,稍拥,微风,清新,少量,适合旅游
美国,明媚,拥挤,稍拥,微风,稍浑,中量,适合旅游
日本,灰暗,稍拥,稍拥,中风,稍浑,少量,适合旅游
韩国,灰暗,稍拥,稍拥,微风,稍浑,少量,不适旅游
英国,灰暗,稍拥,不拥,中风,稍浑,少量,不适旅游
美国,明媚,不拥,拥挤,微风,浑浊,中量,不适旅游
日本,微暗,不拥,拥挤,中风,浑浊,大量,不适旅游
韩国,微暗,拥挤,稍拥,中风,浑浊,大量,不适旅游
英国,明媚,稍拥,稍拥,中风,清新,少量,不适旅游
美国,微暗,稍拥,不拥,中风,清新,少量,不适旅游
日本,灰暗,稍拥,稍拥,微风,稍浑,大量,不适旅游
韩国,微暗,拥挤,稍拥,中风,浑浊,少量,不适旅游
英国,明媚,拥挤,不拥,中风,稍浑,少量,不适旅游
美国,明媚,不拥,拥挤,微风,稍浑,少量,适合旅游
5、test文本文件
英国,灰暗,稍拥,稍拥,微风,稍浑,大量
美国,微暗,拥挤,拥挤,微风,稍浑,少量
日本,明媚,不拥,拥挤,微风,浑浊,中量
韩国,明媚,拥挤,不拥,微风,清新,少量
在Java中完美运行:
但是问题来了
使用python更改test数据为
英国,明媚,拥挤,拥挤,微风,稍浑,少量
Java中运行结果为
调试发现
//循环中kind成员变量为null
if (kind.equals(date.get(i).get(date.get(i).size() - 1)) == true) {
count++;
}
追溯kind初始化位置
for (int i = 0; i < node.label.size(); i++) {
if (t.get(index).equals(node.label.get(i)) == false) {
continue;
}
if (node.node.get(i).node.isEmpty() == true) {
this.kind = node.node.get(i).getsname();//取出分类结果
} else {
testput(node.node.get(i), t);
}
}
发现修改后的数据没有在决策树中找到相同数据
打印日志
System.out.println("T:"+t.get(index)+",L:"+node.label.get(i));
调试发现读入的数据出现乱码导致该问题
检查文本编码格式,使用python修改前test是utf-8,使用后变为了GBK(大坑啊~~~)。然后修改一下文本的编码格式,问题解决。
总结
文本编程格式常常是我们编程中最经常见的问题。
最后
记得给大黍❤️关注+点赞+收藏+评论+转发❤️
作者:老九学堂—技术大黍
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。