Java 数据科学(二)
四、数据可视化
人类的大脑通常善于在视觉表现中看到模式、趋势和异常值。许多数据科学问题中存在的大量数据可以使用可视化技术进行分析。可视化适合广泛的受众,从分析师到高层管理人员,再到客户。在这一章中,我们将介绍各种可视化技术,并演示 Java 是如何支持这些技术的。
在本章中,我们将说明如何创建不同类型的图形、绘图和图表。大多数例子使用 JavaFX,少数使用名为GRAphing Library(GRAL)的免费库。有几个开源的 Java 绘图库可用。在 github.com/eseifert/gr… 的可以找到这些图书馆的简要对比。我们选择 JavaFX 是因为它被打包成 Java SE 的一部分。
GRAL 用于说明使用 JavaFX 不容易创建的图。GRAL 是一个免费的 Java 库,用于创建各种图表和图形。这个图形库在绘图类型、轴格式和导出选项方面提供了灵活性。trac.erichseifert.de/gral/的 GRAL 资源包括示例代码和有用的操作部分。
可视化是数据分析中的一个重要步骤,因为它允许我们以实用和有意义的方式构想大型数据集。我们可以查看小数据集的值,也许可以从我们看到的模式中得出结论,但这是一个势不可挡且不可靠的过程。使用可视化工具有助于我们识别潜在的问题或意想不到的数据结果,以及构建对好数据的有意义的解释。
数据可视化有用性的一个例子是异常值的出现。可视化数据使我们能够快速看到大大超出我们预期的数据结果,并且我们可以选择如何修改数据来构建干净可用的数据集。这个过程允许我们快速发现错误,并在它们成为问题之前处理它们。此外,可视化使我们能够轻松地对信息进行分类,并帮助分析师以最适合其特定数据集的方式组织他们的查询。
理解图形和图表
有许多类型的视觉表达可用于帮助可视化。我们将简要讨论最常见和最有用的表达式,然后演示几种实现这些表达式类型的 Java 技术。图形或其他可视化工具的选择将取决于数据集和应用程序的需求和约束。
一个条形图是一种非常常见的显示数据关系的技术。在这种类型的图表中,数据用沿 X 和 Y 轴放置的垂直或水平条来表示。数据经过缩放,因此每个条形代表的值可以相互比较。下面是一个简单的条形图示例,我们将在中使用国家作为类别部分创建:
当你想要展示一个与更大的集合相关的值时,一个饼图是最有用的。可以把这想象成一种方式,来想象一块饼相对于整个饼有多大。以下是一个简单的饼图示例,显示了选定欧洲国家的人口分布情况:
时间序列图是一种特殊类型的图表,用于显示与时间相关的值。当数据分析需要了解数据在一段时间内是如何变化的时,这是最合适的。在这些图中,纵轴对应于数值,横轴对应于特定的时间点。特别是,这种类型的图表对于识别不同时间的趋势,或者暗示给定时间段内数据值和特定事件之间的相关性非常有用。
例如,股票价格和房屋价格会变化,但是它们的变化率不同。污染水平和犯罪率也会随着时间而变化。有几种技术可以将这种类型的数据可视化。通常,特定的值没有它们随时间变化的趋势重要。
一个指标图也叫折线图。折线图使用 X 和 Y 轴在网格上绘制点。它们可以用来表示时间序列数据。这些点由线连接,这些线用于同时比较多个数据的值。这种比较通常通过沿着 X 轴绘制独立变量(如时间)以及沿着 Y 轴绘制独立变量(如频率或百分比)来实现。
以下是一个简单的指数图表示例,显示了选定欧洲国家的人口分布情况:
当我们希望以紧凑和有用的方式排列大量数据时,我们可以选择茎和叶图。这种类型的可视化表达式允许您以可读的方式演示一个值与多个值的相关性。茎是指一个数据值,叶是对应的数据点。一个常见的例子是火车时刻表。下表列出了列车的发车时间:
| 06:15 | 06 :20 | 06:25 | 06:30 |
| 06:40 | 06:45 | 06:55 | 07:15 |
| 07:20 | 07:25 | 07:30 | 07:40 |
| 07:45 | 07:55 | 08:00 | 08:12 |
| 08:24 | 08:36 | 08:48 | 09:00 |
| 09:12 | 09:24 | 09:36 | 09:48 |
| 10:00 | 10:12 | 10:24 | 10:36 |
| 10:48 | | | |
然而,这个表可能很难阅读。相反,在下面的部分茎和叶图中,茎代表火车可能出发的小时,而叶代表每小时内的分钟:
| 小时 | 分钟 | | 06 | :15 :20 :25 :30 :40 :45 :55 | | 07 | :15 :20 :25 :30 :40 :45 :55 | | 08 | :00 :12 :24 :36 :48 | | 09 | :00 :12 :24 :36 :48 | | 10 | :00 :12 :24 :36 :48 |
这更容易阅读和处理。
统计分析中一种非常流行的可视化形式是直方图。直方图允许您使用条形显示数据中的频率,类似于条形图。主要区别在于直方图用于识别数据集中的频率和趋势,而条形图用于比较数据集中的特定数据值。以下是我们将在创建直方图部分创建的直方图示例:
一个散点图仅仅是点的集合,分析技术,如相关或回归,可以用来识别这些类型的图表中的趋势。在下面的散点图中,正如在创建散点图中开发的,沿着 X 轴的人口相对于沿着 Y 轴的十年被绘制:
视觉分析目标
每种类型的视觉表达都适用于不同类型的数据和数据分析目的。数据分析的一个常见目的是数据分类。这包括确定特定数据值属于数据集中的哪个子集。这个过程可能发生在数据分析过程的早期,因为将数据分成可管理的和相关的片段简化了分析过程。通常,分类不是最终目标,而是进行进一步分析之前的一个重要的中间步骤。
回归分析是一种复杂而重要的数据分析形式。它包括研究自变量和因变量以及多个自变量之间的关系。这种类型的统计分析允许分析师确定可接受值或期望值的范围,并确定单个值如何适合更大的数据集。回归分析是机器学习的一个重要部分,我们将在第五章、统计数据分析技术中详细讨论。
聚类允许我们识别特定集合或类别中的数据点组。分类将数据分类到相似类型的数据集,而聚类则关注数据集中的数据。例如,我们可能有一个包含世界上所有猫科动物的大型数据集。然后,我们可以将这些猫科动物分为两组,豹亚科(包含大多数大型猫科动物)和猫亚科(所有其他猫科动物)。聚类包括在这些分类中的一个分类内对相似猫的子集进行分组。例如,所有的老虎都可能是豹亚科中的一个集群。
有时,我们的数据分析需要我们从数据集中提取特定类型的信息。选择要提取的数据的过程被称为属性选择或特征选择。这一过程有助于分析师简化数据模型,并使我们能够解决数据集中冗余或不相关信息的问题。
通过对基本绘图和图表类型的介绍,我们将讨论 Java 对创建这些绘图和图表的支持。
创建索引图表
指数图是一种折线图,显示某事物随时间变化的百分比。通常,这样的图表基于单个数据属性。在下面的例子中,我们将使用 60 年的比利时人口。该数据是在ourworldindata.org/grapher/pop…:
| 十年 | 人口 | | One thousand nine hundred and fifty | Eight million six hundred and thirty-nine thousand three hundred and sixty-nine | | One thousand nine hundred and sixty | Nine million one hundred and eighteen thousand seven hundred | | One thousand nine hundred and seventy | Nine million six hundred and thirty-seven thousand eight hundred | | One thousand nine hundred and eighty | Nine million eight hundred and forty-six thousand eight hundred | | One thousand nine hundred and ninety | Nine million nine hundred and sixty-nine thousand three hundred and ten | | Two thousand | Ten million two hundred and sixty-three thousand six hundred and eighteen |
我们从创建扩展了Application的MainApp类开始。我们创建一系列实例变量。XYChart.Series类代表某个图的一系列数据点。在我们的例子中,这将是几十年和人口,我们将很快初始化。下一个声明是针对CategoryAxis和NumberAxis实例的。这些分别代表 X 和 Y 轴。Y 轴的声明包括总体的范围和增量值。这使得图表更具可读性。最后一个声明是国家的字符串变量:
public class MainApp extends Application {
final XYChart.Series<String, Number> series =
new XYChart.Series<>();
final CategoryAxis xAxis = new CategoryAxis();
final NumberAxis yAxis =
new NumberAxis(8000000, 11000000, 1000000);
final static String belgium = "Belgium";
...
}
在 JavaFX 中,main方法通常使用基类launch方法启动应用程序。最终,调用了start方法,我们覆盖了它。在这个例子中,我们调用创建用户界面的simpleLineChart方法:
public static void main(String[] args) {
launch(args);
}
public void start(Stage stage) {
simpleIndexChart (stage);
}
simpleLineChart跟在后面,并被传递了一个Stage类的实例。这表示应用程序窗口的客户区。我们首先为应用程序和折线图设置一个标题。 Y 轴的标签设置完毕。使用 X 和 Y 轴实例初始化LineChart类的实例。此类表示折线图:
public void simpleIndexChart (Stage stage) {
stage.setTitle("Index Chart");
lineChart.setTitle("Belgium Population");
yAxis.setLabel("Population");
final LineChart<String, Number> lineChart
= new LineChart<>(xAxis, yAxis);
...
}
给该系列一个名称,然后使用addDataItem辅助方法将每十年的人口添加到该系列中:
series.setName("Population");
addDataItem(series, "1950", 8639369);
addDataItem(series, "1960", 9118700);
addDataItem(series, "1970", 9637800);
addDataItem(series, "1980", 9846800);
addDataItem(series, "1990", 9969310);
addDataItem(series, "2000", 10263618);
接下来是addDataItem方法,它使用传递给它的String和Number值创建一个XYChart.Data类实例。然后,它将实例添加到系列中:
public void addDataItem(XYChart.Series<String, Number> series,
String x, Number y) {
series.getData().add(new XYChart.Data<>(x, y));
}
simpleLineChart方法的最后一部分创建了一个代表stage内容的Scene类实例。JavaFX 使用舞台和场景的概念来处理应用程序 GUI 的内部。
使用折线图创建scene,应用程序的大小通过600像素设置为800。然后将系列添加到折线图中,并将scene添加到stage。show方法显示应用程序:
Scene scene = new Scene(lineChart, 800, 600);
lineChart.getData().add(series);
stage.setScene(scene);
stage.show();
当应用程序执行时,将显示以下窗口:
创建条形图
条形图使用两个带矩形条的轴,可以垂直或水平放置。条形的长度与它所代表的数值成正比。条形图可用于显示时间序列数据。
在下面的一系列示例中,我们将使用一组欧洲国家三十年的人口,如下表所示。该数据是在ourworldindata.org/grapher/pop…:
| 国家 | 1950 年 | 1960 年 | 1970 年 | | 比利时 | Eight million six hundred and thirty-nine thousand three hundred and sixty-nine | Nine million one hundred and eighteen thousand seven hundred | Nine million six hundred and thirty-seven thousand eight hundred | | 法国 | Forty-two million five hundred and eighteen thousand | Forty-six million five hundred and eighty-four thousand | Fifty-one million nine hundred and eighteen thousand | | 德国 | Sixty-eight million three hundred and seventy-four thousand five hundred and seventy-two | Seventy-two million four hundred and eighty thousand eight hundred and sixty-nine | Seventy-seven million seven hundred and eighty-three thousand one hundred and sixty-four | | 荷兰 | Ten million one hundred and thirteen thousand five hundred and twenty-seven | Eleven million four hundred and eighty-six thousand | Thirteen million thirty-two thousand three hundred and thirty-five | | 瑞典 | Seven million fourteen thousand and five | Seven million four hundred and eighty thousand three hundred and ninety-five | Eight million forty-two thousand eight hundred and three | | 联合王国 | Fifty million one hundred and twenty-seven thousand | Fifty-two million three hundred and seventy-two thousand | Fifty-five million six hundred and thirty-two thousand |
三个条形图中的第一个将使用 JavaFX 构建。我们从一系列国家声明开始,作为扩展Application类的一部分:
public class MainApp extends Application {
final static String belgium = "Belgium";
final static String france = "France";
final static String germany = "Germany";
final static String netherlands = "Netherlands";
final static String sweden = "Sweden";
final static String unitedKingdom = "United Kingdom";
...
}
接下来,我们声明了一系列表示图形各部分的实例变量。第一个是CategoryAxis和NumberAxis实例:
final CategoryAxis xAxis = new CategoryAxis();
final NumberAxis yAxis = new NumberAxis();
人口和国家数据存储在一系列XYChart.Series实例中。这里,我们声明了六个不同的系列,它们使用了一个字符串和数字对。第一个示例没有使用所有六个系列,但后面的示例会使用。我们首先将一个国家字符串及其相应的人口分配给三个系列。这些序列将代表未来几十年1950、1960和1970的人口:
final XYChart.Series<String, Number> series1 =
new XYChart.Series<>();
final XYChart.Series<String, Number> series2
new XYChart.Series<>();
final XYChart.Series<String, Number> series3 =
new XYChart.Series<>();
final XYChart.Series<String, Number> series4 =
new XYChart.Series<>();
final XYChart.Series<String, Number> series5 =
new XYChart.Series<>();
final XYChart.Series<String, Number> series6 =
new XYChart.Series<>();
我们将从两个简单的条形图开始。第一个将在类别中显示国家,在该类别中, X 轴显示年份变化,在 Y 轴显示人口。第二个将几十年显示为包含县的类别。最后一个例子是堆积条形图。
使用国家作为类别
条形图的元素在simpleBarChartByCountry方法中设置。设置图表的标题,并使用两个轴创建一个BarChart类实例。该图表及其 X 轴和 Y 轴也有在此初始化的标签:
public void simpleBarChartByCountry(Stage stage) {
stage.setTitle("Bar Chart");
final BarChart<String, Number> barChart
= new BarChart<>(xAxis, yAxis);
barChart.setTitle("Country Summary");
xAxis.setLabel("Country");
yAxis.setLabel("Population");
...
}
接下来,用一个名称初始化前三个系列,然后是该系列的国家和人口数据。上一节中介绍的助手方法addDataItem用于向每个系列添加数据:
series1.setName("1950");
addDataItem(series1,belgium, 8639369);
addDataItem(series1,france, 42518000);
addDataItem(series1,germany, 68374572);
addDataItem(series1,netherlands, 10113527);
addDataItem(series1,sweden, 7014005);
addDataItem(series1,unitedKingdom, 50127000);
series2.setName("1960");
addDataItem(series2,belgium, 9118700);
addDataItem(series2,france, 46584000);
addDataItem(series2,germany, 72480869);
addDataItem(series2,netherlands, 11486000);
addDataItem(series2,sweden, 7480395);
addDataItem(series2,unitedKingdom, 52372000);
series3.setName("1970");
addDataItem(series3,belgium, 9637800);
addDataItem(series3,france, 51918000);
addDataItem(series3,germany, 77783164);
addDataItem(series3,netherlands, 13032335);
addDataItem(series3,sweden, 8042803);
addDataItem(series3,unitedKingdom, 55632000);
该方法的最后一部分创建了一个scene实例。三个系列被添加到scene上,并且使用setScene方法将scene连接到stage上。一个stage是一个本质上代表窗口客户区的类:
Scene scene = new Scene(barChart, 800, 600);
barChart.getData().addAll(series1, series2, series3);
stage.setScene(scene);
stage.show();
两个方法中的最后一个是start方法,当窗口显示时自动调用。它被传递给Stage实例。在这里,我们称之为simpleBarChartByCountry法:
public void start(Stage stage) {
simpleBarChartByCountry(stage);
}
main方法由对Application类的launch方法的调用组成:
public static void main(String[] args) {
launch(args);
}
执行应用程序时,会显示以下图形:
以十年为范畴
在下面的例子中,我们将演示如何显示相同的信息,但是我们将按年份组织 X 轴类别。我们将使用simpleBarChartByYear方法,如下所示。轴和标题的设置方式与之前相同,但标题和标签的值不同:
public void simpleBarChartByYear(Stage stage) {
stage.setTitle("Bar Chart");
final BarChart<String, Number> barChart
= new BarChart<>(xAxis, yAxis);
barChart.setTitle("Year Summary");
xAxis.setLabel("Year");
yAxis.setLabel("Population");
...
}
以下字符串变量被声明为三十年:
String year1950 = "1950";
String year1960 = "1960";
String year1970 = "1970";
数据系列的创建方式与以前相同,只是国家名称用于系列名称,年份用于类别。此外,还使用六个系列,每个国家一个系列:
series1.setName(belgium);
addDataItem(series1, year1950, 8639369);
addDataItem(series1, year1960, 9118700);
addDataItem(series1, year1970, 9637800);
series2.setName(france);
addDataItem(series2, year1950, 42518000);
addDataItem(series2, year1960, 46584000);
addDataItem(series2, year1970, 51918000);
series3.setName(germany);
addDataItem(series3, year1950, 68374572);
addDataItem(series3, year1960, 72480869);
addDataItem(series3, year1970, 77783164);
series4.setName(netherlands);
addDataItem(series4, year1950, 10113527);
addDataItem(series4, year1960, 11486000);
addDataItem(series4, year1970, 13032335);
series5.setName(sweden);
addDataItem(series5, year1950, 7014005);
addDataItem(series5, year1960, 7480395);
addDataItem(series5, year1970, 8042803);
series6.setName(unitedKingdom);
addDataItem(series6, year1950, 50127000);
addDataItem(series6, year1960, 52372000);
addDataItem(series6, year1970, 55632000);
scene被创建并附加到stage:
Scene scene = new Scene(barChart, 800, 600);
barChart.getData().addAll(series1, series2,
series3, series4, series5, series6);
stage.setScene(scene);
stage.show();
main方法没有改变,但是start方法调用了simpleBarChartByYear方法:
public void start(Stage stage) {
simpleBarChartByYear(stage);
}
执行应用程序时,会显示以下图形:
创建堆叠图
面积图通过为较大的值分配更多的空间来描述信息。通过将面积图堆叠在一起,我们创建了一个堆叠图,有时称为流图。但是,堆积图不能很好地处理负值,也不能用于求和没有意义的数据,例如温度。如果堆叠了太多图表,那么解释起来会变得很困难。
接下来,我们将展示如何创建堆叠条形图。stackedGraphExample方法包含创建条形图的代码。我们从熟悉的代码开始设置标题和标签。但是,对于 X 轴,setCategories方法FXCollections。<String>observableArrayList实例用于设置类别。这个构造函数的参数是由Arrays类的asList方法创建的字符串数组和国家名称:
public void stackedGraphExample(Stage stage) {
stage.setTitle("Stacked Bar Chart");
final StackedBarChart<String, Number> stackedBarChart
= new StackedBarChart<>(xAxis, yAxis);
stackedBarChart.setTitle("Country Population");
xAxis.setLabel("Country");
xAxis.setCategories(
FXCollections.<String>observableArrayList(
Arrays.asList(belgium, germany, france,
netherlands, sweden, unitedKingdom)));
yAxis.setLabel("Population");
...
}
使用年份作为系列名称和国家对系列进行初始化,并使用 helper 方法addDataItem添加它们的人口。然后创建了scene:
series1.setName("1950");
addDataItem(series1, belgium, 8639369);
addDataItem(series1, france, 42518000);
addDataItem(series1, germany, 68374572);
addDataItem(series1, netherlands, 10113527);
addDataItem(series1, sweden, 7014005);
addDataItem(series1, unitedKingdom, 50127000);
series2.setName("1960");
addDataItem(series2, belgium, 9118700);
addDataItem(series2, france, 46584000);
addDataItem(series2, germany, 72480869);
addDataItem(series2, netherlands, 11486000);
addDataItem(series2, sweden, 7480395);
addDataItem(series2, unitedKingdom, 52372000);
series3.setName("1970");
addDataItem(series3, belgium, 9637800);
addDataItem(series3, france, 51918000);
addDataItem(series3, germany, 77783164);
addDataItem(series3, netherlands, 13032335);
addDataItem(series3, sweden, 8042803);
addDataItem(series3, unitedKingdom, 55632000);
Scene scene = new Scene(stackedBarChart, 800, 600);
stackedBarChart.getData().addAll(series1, series2, series3);
stage.setScene(scene);
stage.show();
main方法没有改变,但是start方法调用了stackedGraphExample方法:
public void start(Stage stage) {
stackedGraphExample(stage);
}
执行应用程序时,会显示以下图形:
创建饼图
以下饼图示例基于 2000 年选定欧洲国家的人口,如下所示:
| 国家 | 人口 | 百分比 | | 比利时 | Ten million two hundred and sixty-three thousand six hundred and eighteen | three | | 法国 | Sixty-one million one hundred and thirty-seven thousand | Twenty-six | | 德国 | Eighty-two million one hundred and eighty-seven thousand nine hundred and nine | Thirty-five | | 荷兰 | Fifteen million nine hundred and seven thousand eight hundred and fifty-three | seven | | 瑞典 | Eight million eight hundred and seventy-two thousand | four | | 联合王国 | Fifty-nine million five hundred and twenty-two thousand four hundred and sixty-eight | Twenty-five |
JavaFX 实现使用与前面示例中相同的Application基类和main方法。我们不会使用单独的方法来创建 GUI,而是将这段代码放在start方法中,如下所示:
public class PieChartSample extends Application {
public void start(Stage stage) {
Scene scene = new Scene(new Group());
stage.setTitle("Europian Country Population");
stage.setWidth(500);
stage.setHeight(500);
...
}
public static void main(String[] args) {
launch(args);
}
}
饼图由PieChart类表示。我们可以使用饼图数据的ObservableList在构造函数中创建并初始化饼图。该数据由一系列PieChart.Data实例组成,每个实例包含一个文本标签和一个百分比值。
下一个序列基于前面给出的欧洲人口数据创建了一个ObservableList实例。FXCollections类的observableArrayList方法返回一个带有饼图数据列表的ObservableList实例:
ObservableList<PieChart.Data> pieChartData =
FXCollections.observableArrayList(
new PieChart.Data("Belgium", 3),
new PieChart.Data("France", 26),
new PieChart.Data("Germany", 35),
new PieChart.Data("Netherlands", 7),
new PieChart.Data("Sweden", 4),
new PieChart.Data("United Kingdom", 25));
然后,我们创建饼图并设置其标题。然后饼状图被添加到scene,scene与stage相关联,然后显示窗口:
final PieChart pieChart = new PieChart(pieChartData);
pieChart.setTitle("Country Population");
((Group) scene.getRoot()).getChildren().add(pieChart);
stage.setScene(scene);
stage.show();
执行应用程序时,会显示以下图形:
创建散点图
散点图也使用 JavaFX 中的XYChart.Series类。在这个例子中,我们将使用一组欧洲数据,其中包括 1500 年到 2000 年这几十年中以前的欧洲国家及其人口数据。这些信息存储在一个名为EuropeanScatterData.csv的文件中。该文件的第一部分如下所示:
1500 1400000
1600 1600000
1650 1500000
1700 2000000
1750 2250000
1800 3250000
1820 3434000
1830 3750000
1840 4080000
...
我们从 JavaFX MainApp类的声明开始,如下所示。main方法启动应用程序,start方法创建用户界面:
public class MainApp extends Application {
@Override
public void start(Stage stage) throws Exception {
...
}
public static void main(String[] args) {
launch(args);
}
}
在start方法中,我们设置标题,创建轴,并创建代表散点图的ScatterChart的实例。NumberAxis类的构造函数使用的值比其默认构造函数使用的默认值更匹配数据范围:
stage.setTitle("Scatter Chart Sample");
final NumberAxis yAxis = new NumberAxis(1400, 2100, 100);
final NumberAxis xAxis = new NumberAxis(500000, 90000000,
1000000);
final ScatterChart<Number, Number> scatterChart = new
ScatterChart<>(xAxis, yAxis);
接下来,轴的标签与散点图的标题一起设置:
xAxis.setLabel("Population");
yAxis.setLabel("Decade");
scatterChart.setTitle("Population Scatter Graph");
创建了一个XYChart.Series类的实例,并命名为:
XYChart.Series series = new XYChart.Series();
使用一个CSVReader类实例和文件EuropeanScatterData.csv填充该系列。这个过程在第三章、数据清理中讨论:
try (CSVReader dataReader = new CSVReader(new FileReader("EuropeanScatterData.csv"), ',')) {
String[] nextLine;
while ((nextLine = dataReader.readNext()) != null) {
int decade = Integer.parseInt(nextLine[0]);
int population = Integer.parseInt(nextLine[1]);
series.getData().add(new XYChart.Data(
population, decade));
out.println("Decade: " + decade +
" Population: " + population);
}
}
scatterChart.getData().addAll(series);
JavaFX scene和stage被创建,然后显示绘图:
Scene scene = new Scene(scatterChart, 500, 400);
stage.setScene(scene);
stage.show();
执行应用程序时,会显示以下图形:
创建直方图
直方图虽然在外观上类似于条形图,但用于显示数据集中数据项相对于其他项的频率。下面每个使用 GRAL 的例子都将使用DataTable类来最初保存要显示的数据。在本例中,我们将从名为AgeofMarriage.csv的样本文件中读取数据。这个以逗号分隔的文件保存了人们第一次结婚的年龄列表。
我们将创建一个名为HistogramExample的新类,它扩展了JFrame类,并在其构造函数中包含以下代码。我们首先创建一个DataReader对象来指定数据是 CSV 格式的。然后我们使用一个 try-catch 块来处理 IO 异常,并调用DataReader类的read方法将数据直接放入DataTable对象中。read方法的第一个参数是一个FileInputStream对象,第二个参数指定文件中预期的数据类型:
DataReader readType=
DataReaderFactory.getInstance().get("text/csv");
String fileName = "C://AgeofMarriage.csv";
try {
DataTable histData = (DataTable) readType.read(
New FileInputStream(fileName), Integer.class);
...
}
接下来,我们创建一个Number数组来指定我们期望获得数据的年龄。在这种情况下,我们预计结婚年龄将在19到30之间。我们使用这个数组来创建我们的Histogram对象。我们包括了之前的DataTable,并且指定了方向。然后我们创建我们的DataSource,指定我们的开始年龄,并指定沿着我们的 X 轴的间距:
Number ageRange[] = {19,20,21,22,23,24,25,26,27,28,29,30};
Histogram sampleHisto = new Histogram1D(
histData, Orientation.VERTICAL, ageRange);
DataSource sampleHistData = new EnumeratedData(sampleHisto, 19,
1.0);
我们使用BarPlot类从前面读入的数据中创建直方图:
BarPlot testPlot = new BarPlot(sampleHistData);
接下来的几个步骤用于格式化直方图的各个方面。我们使用setInsets方法来指定在窗口内图表的每一边放置多少空间。我们可以为图表提供一个标题,并指定条形宽度:
testPlot.setInsets(new Insets2D.Double(20.0, 50.0, 50.0, 20.0));
testPlot.getTitle().setText("Average Age of Marriage");
testPlot.setBarWidth(0.7);
我们还需要格式化我们的 X 和 Y 轴。我们已经选择将我们的范围设置为 X 轴,以紧密匹配我们的预期年龄范围,但在图表的一侧提供一些空间。因为我们知道样本数据的数量,所以我们将我们的 Y 轴设置为从0到10的范围。在业务应用程序中,这些范围将通过检查实际数据集来计算。我们还可以指定是否希望显示刻度线,以及希望轴相交的位置:
testPlot.getAxis(BarPlot.AXIS_X).setRange(18, 30.0);
testPlot.getAxisRenderer(BarPlot.AXIS_X).setTickAlignment(0.0);
testPlot.getAxisRenderer(BarPlot.AXIS_X).setTickSpacing(1);
testPlot.getAxisRenderer(BarPlot.AXIS_X).setMinorTicksVisible(false );
testPlot.getAxis(BarPlot.AXIS_Y).setRange(0.0, 10.0);
testPlot.getAxisRenderer(BarPlot.AXIS_Y).setTickAlignment(0.0);
testPlot.getAxisRenderer(BarPlot.AXIS_Y).setMinorTicksVisible(false );
testPlot.getAxisRenderer(BarPlot.AXIS_Y).setIntersection(0);
我们在图形上显示的颜色和值也有很大的灵活性。在本例中,我们选择显示每个年龄的频率值,并将图表颜色设置为black:
PointRenderer renderHist =
testPlot.getPointRenderers(sampleHistData).get(0);
renderHist.setColor(GraphicsUtils.deriveWithAlpha(Color.black,
128));
renderHist.setValueVisible(true);
最后,我们为窗口的显示方式设置了几个属性:
InteractivePanel pan = new InteractivePanel(testPlot);
pan.setPannable(false);
pan.setZoomable(false);
add(pan);
setSize(1500, 700);
this.setVisible(true);
执行应用程序时,会显示以下图形:
创建圆环图
圆环图类似于饼图,但它们缺少中间部分(因此得名圆环图)。一些分析师更喜欢圆环图而不是饼图,因为它们不强调图表中每个部分的大小,并且更容易与其他圆环图进行比较。它们还提供了占用更少空间的额外优势,允许在显示中有更多的格式化选项。
在这个例子中,我们将假设我们的数据已经被填充到一个名为ageCount的二维数组中。数组的第一行包含可能的年龄值,范围也是从19到30(包括 T1 和)。第二行包含等于每个年龄的数据值的数量。例如,在我们的数据集中,有六个数据值等于19,因此ageCount[0][1]包含数字 6。
我们创建一个DataTable并使用add方法将数组中的值相加。请注意,我们正在测试特定年龄的值是否为零。在我们的测试案例中,将有零个数据值等于23。如果该点没有数据值,我们将选择在圆环图中添加一个空白区域。这是通过使用负数作为add方法中的第一个参数来实现的。这将设置一个大小为3的空白空间:
DataTable donutData = new DataTable(Integer.class, Integer.class);
for(int Y = 0; Y < ageCount[0].length; y++){
if(ageCount[1][y] == 0){
donutData.add(-3, ageCount[0][y]);
}else{
donutData.add(ageCount[1][y], ageCount[0][y]);
}
}
接下来,我们使用PiePlot类创建我们的圆环图。我们设置绘图的基本属性,包括指定图例的值。在这种情况下,我们希望我们的图例反映我们的年龄可能性,所以我们使用setLabelColumn方法来更改默认标签。我们也像在前面的例子中一样设置我们的 insets:
PiePlot testPlot = new PiePlot(donutData);
((ValueLegend) testPlot.getLegend()).setLabelColumn(1);
testPlot.getTitle().setText("Donut Plot Example");
testPlot.setRadius(0.9);
testPlot.setLegendVisible(true);
testPlot.setInsets(new Insets2D.Double(20.0, 20.0, 20.0, 20.0));
接下来,我们创建一个PieSliceRenderer对象来设置更高级的属性。因为甜甜圈图本质上基本上是一个饼图,我们将通过调用setInnerRadius方法来呈现甜甜圈图。我们还指定饼图扇区之间的间隙、使用的颜色以及标签的样式:
PieSliceRenderer renderPie = (PieSliceRenderer)
testPlot.getPointRenderer(donutData);
renderPie.setInnerRadius(0.4);
renderPie.setGap(0.2);
LinearGradient colors = new LinearGradient(
Color.blue, Color.green);
renderPie.setColor(colors);
renderPie.setValueVisible(true);
renderPie.setValueColor(Color.WHITE);
renderPie.setValueFont(Font.decode(null).deriveFont(Font.BOLD));
最后,我们创建面板并设置其大小:
add(new InteractivePanel(testPlot), BorderLayout.CENTER);
setSize(1500, 700);
setVisible(true);
执行应用程序时,会显示以下图形:
创建气泡图
气泡图类似于散点图,只是它们以三维形式表示数据。前两个维度在 X 和 Y 轴上表示,第三个维度由绘制点的大小表示。这有助于确定数据值之间的关系。
我们将再次使用DataTable类来最初保存要显示的数据。在本例中,我们将从名为MarriageByYears.csv的样本文件中读取数据。这也是一个 CSV 文件,其中一列代表结婚的年份,第二列代表结婚的年龄,第三列代表婚姻满意度的整数,范围从1(最不满意)到10(最满意)。我们创建一个DataSeries来表示我们想要的数据图类型,然后创建一个XYPlot对象:
DataReader readType =
DataReaderFactory.getInstance().get("text/csv");
String fileName = "C://MarriageByYears.csv";
try {
DataTable bubbleData = (DataTable) readType.read(
new FileInputStream(fileName), Integer.class,
Integer.class, Integer.class);
DataSeries bubbleSeries = new DataSeries("Bubble", bubbleData);
XYPlot testPlot = new XYPlot(bubbleSeries);
接下来,我们设置图表的基本属性信息。在本例中,我们将设置颜色并关闭垂直和水平网格。在本例中,我们还将使 X 轴和 Y 轴不可见。请注意,我们仍然为轴设置了一个范围,即使它们没有显示:**
testPlot.setInsets(new Insets2D.Double(30.0)); testPlot.setBackground(new Color(0.75f, 0.75f, 0.75f));
XYPlotArea2D areaProp = (XYPlotArea2D) testPlot.getPlotArea();
areaProp.setBorderColor(null);
areaProp.setMajorGridX(false);
areaProp.setMajorGridY(false);
areaProp.setClippingArea(null);
testPlot.getAxisRenderer(XYPlot.AXIS_X).setShapeVisible(false);
testPlot.getAxisRenderer(XYPlot.AXIS_X).setTicksVisible(false);
testPlot.getAxisRenderer(XYPlot.AXIS_Y).setShapeVisible(false);
testPlot.getAxisRenderer(XYPlot.AXIS_Y).setTicksVisible(false);
testPlot.getAxis(XYPlot.AXIS_X).setRange(1940, 2020);
testPlot.getAxis(XYPlot.AXIS_Y).setRange(17, 30);
我们还可以设置与图表上绘制的气泡相关的属性。在这里,我们设置颜色和形状,并指定哪一列数据将用于缩放形状。在这种情况下,将使用第三列,即婚姻满意度等级。我们使用setColumn方法设置它:
Color color = GraphicsUtils.deriveWithAlpha(Color.black, 96);
SizeablePointRenderer renderBubble = new SizeablePointRenderer();
renderBubble.setShape(new Ellipse2D.Double(-3.5, -3.5, 4.0, 4.0));
renderBubble.setColor(color);
renderBubble.setColumn(2);
testPlot.setPointRenderers(bubbleSeries, renderBubble);
最后,我们创建面板并设置其大小:
add(new InteractivePanel(testPlot), BorderLayout.CENTER);
setSize(new Dimension(1500, 700));
setVisible(true);
执行应用程序时,会显示下图。请注意,点的大小和颜色会根据特定数据点的频率而变化:
总结
在这一章中,我们介绍用于可视化数据的基本图形、曲线图和图表。可视化的过程使分析人员能够以图形方式检查被检查的数据。这更加直观,并且通常有助于快速识别数据中难以从原始数据中提取的异常。
检查了几种视觉表示,包括折线图、各种条形图、饼图、散点图、直方图、环形图和气泡图。这些数据的图形描述中的每一个都提供了被分析数据的不同视角。最合适的技术取决于所用数据的性质。虽然我们没有涵盖所有可能的图形技术,但是这个示例很好地概述了可用的技术。
我们还关心如何使用 Java 来绘制这些图形。许多例子都使用了 JavaFX。这是一个与 Java SE 捆绑在一起的现成工具。但是,还有其他几个可用的库。我们用 GRAL 来说明如何生成一些图形。
在概述了可视化技术之后,我们准备继续讨论其他主题,在这些主题中,可视化将用于更好地传达数据科学技术的本质。在下一章,我们将介绍基本的统计过程,包括线性回归,我们将使用本章介绍的技术。
五、统计数据分析技术
本章的目的不是让读者成为统计技术的专家。相反,它是为了让读者熟悉正在使用的基本统计技术,并演示 Java 如何支持统计分析。虽然有各种各样的数据分析技术,在这一章中,我们将把重点放在更常见的任务上。
这些技术包括从相对简单的平均值计算到复杂的回归分析模型。统计分析可能是一个非常复杂的过程,需要进行大量的研究。我们将从介绍基本的统计分析技术开始,包括计算数据集的均值、中值、众数和标准差。有许多方法可以用来计算这些值,我们将使用标准 Java 和第三方 API 来演示这些方法。我们还将简要讨论样本大小和假设检验。
回归分析是一种重要的数据分析技术。该技术会创建一条尝试匹配数据集的线。代表这条线的方程可以用来预测未来的行为。回归分析有几种类型。在本章中,我们将重点介绍简单线性回归和多元回归。通过简单的线性回归,年龄等单一因素被用来预测一些行为,如外出就餐的可能性。通过多元回归,年龄、收入水平和婚姻状况等多个因素可以用来预测一个人外出就餐的频率。
预测分析,或称分析,是关于预测未来事件的。本书中使用的许多技术都与预测有关。具体来说,本章的回归分析部分预测未来的行为。
在我们看到 Java 如何支持回归分析之前,我们需要讨论基本的统计技术。我们从均值、众数和中位数开始。
在本章中,我们将讨论以下主题:
- 使用平均值、众数和中位数
- 标准偏差和样本量的确定
- 假设检验
- 回归分析
使用平均值、众数和中位数
平均值、中值和众数是描述特征或汇总数据集中信息的基本方法。当第一次遇到一个新的大型数据集时,了解它的基本信息会有助于指导进一步的分析。这些值通常用于以后的分析,以生成更复杂的测量结果和结论。当我们使用数据集的平均值来计算标准偏差时,就会发生这种情况,我们将在本章的标准偏差一节中进行演示。
计算平均值
术语 **mean,**也称为平均值,计算方法是将列表中的值相加,然后将总和除以值的个数。这种技术对于确定一组数字的总体趋势很有用。它还可以用来填充缺失的数据元素。我们将研究几种使用标准 Java 库和第三方 API 计算给定数据集平均值的方法。
使用简单的 Java 技术寻找平均值
在我们的第一个例子中,我们将演示使用标准 Java 功能计算平均值的基本方法。我们将使用一个名为testData的double值数组:
double[] testData = {12.5, 18.7, 11.2, 19.0, 22.1, 14.3, 16.9, 12.5,
17.8, 16.9};
我们创建一个double变量来保存所有值的总和,创建一个double变量来保存mean。循环用于遍历数据并将值相加。接下来,总和除以我们数组的length(元素总数)来计算mean:
double total = 0;
for (double element : testData) {
total += element;
}
double mean = total / testData.length;
out.println("The mean is " + mean);
我们的输出如下:
平均值为 16.19
使用 Java 8 技术寻找平均值
Java 8 通过引入可选类提供了额外的功能。在这个例子中,我们将结合使用OptionalDouble类和Arrays类的stream方法。我们将使用与上一个例子中相同的 doubles 数组来创建一个OptionalDouble对象。如果数组中的任何数字,或者数组中数字的和不是实数,那么OptionalDouble对象的值也不是实数:
OptionalDouble mean = Arrays.stream(testData).average();
我们使用isPresent方法来确定我们是否为我们的平均值计算了一个有效的数字。如果我们没有得到一个好的结果,isPresent方法将返回false,我们可以处理任何异常:
if (mean.isPresent()) {
out.println("The mean is " + mean.getAsDouble());
} else {
out.println("The stream was empty");
}
我们的输出如下:
The mean is 16.19
另一种更简洁的使用OptionalDouble类的技术涉及 lambda 表达式和ifPresent方法。如果mean是一个有效的OptionalDouble对象,这个方法执行它的参数:
OptionalDouble mean = Arrays.stream(testData).average();
mean.ifPresent(x-> out.println("The mean is " + x));
我们的输出如下:
The mean is 16.19
最后,如果mean不是有效的OptionalDouble对象,我们可以使用orElse方法打印平均值或替代值:
OptionalDouble mean = Arrays.stream(testData).average();
out.println("The mean is " + mean.orElse(0));
我们的输出是相同的:
The mean is 16.19
在接下来的两个例子中,我们将使用第三方库,并继续使用 doubles 数组,testData。
用谷歌番石榴查找意思
在这个例子中,我们将使用谷歌番石榴库,在第三章中介绍、数据清理。Stats类提供了处理数字数据的功能,包括寻找平均值和标准差,我们将在后面演示。为了计算mean,我们首先使用testData数组创建一个Stats对象,然后执行mean方法:
Stats testStat = Stats.of(testData);
double mean = testStat.mean();
out.println("The mean is " + mean);
请注意本例中输出的默认格式之间的差异。
使用 Apache Commons 查找平均值
在我们最后的例子中,我们使用 Apache Commons 库,也在第 3 章、中介绍了数据清理。我们首先创建一个Mean对象,然后使用我们的testData执行evaluate方法。该方法返回一个double **,**表示数组中值的平均值:
Mean mean = new Mean();
double average = mean.evaluate(testData);
out.println("The mean is " + average);
我们的输出如下:
The mean is 16.19
Apache Commons 还提供了一个有用的DescriptiveStatistics类。稍后我们将使用它来演示中位数和标准差,但首先我们将从计算平均值开始。使用SynchronizedDescriptiveStatistics类是有利的,因为它是同步的,因此是线程安全的。
我们从创建我们的DescriptiveStatistics对象statTest开始。然后,我们循环遍历我们的双数组,并将每一项添加到statTest。然后我们可以调用getMean方法来计算mean:
DescriptiveStatistics statTest =
new SynchronizedDescriptiveStatistics();
for(double num : testData){
statTest.addValue(num);
}
out.println("The mean is " + statTest.getMean());
我们的输出如下:
The mean is 16.19
接下来,我们将讨论相关话题:中位数。
计算中位数
如果数据集包含大量异常值或有偏差,则平均值可能会产生误导。当这种情况发生时,众数和中位数会很有用。术语中值是一系列值中间的值。对于奇数个值,这很容易计算。对于偶数个值,中值计算为中间两个值的平均值。
使用简单的 Java 技术求中位数
在我们的第一个例子中,我们将使用一个基本的 Java 方法来计算中位数。对于这些例子,我们稍微修改了我们的testData数组:
double[] testData = {12.5, 18.3, 11.2, 19.0, 22.1, 14.3, 16.2, 12.5,
17.8, 16.5};
首先,我们使用Arrays类对我们的数据进行排序,因为当数据按数字顺序排列时,寻找中值是很简单的:
Arrays.sort(testData);
然后我们处理三种可能性:
- 我们的列表是空的
- 我们的列表有偶数个值
- 我们的列表有奇数个值
下面的代码可能会被缩短,但是我们已经明确地帮助阐明了这个过程。如果我们的列表有偶数个值,我们用列表的长度除以2。第一个变量mid1将保存两个中间值中的第一个。第二个变量mid2将保存第二个中间值。这两个数字的平均值就是我们的中值。寻找具有奇数个值的列表的中值索引的过程更简单,只需要我们将长度除以2并加上1:
if(testData.length==0){ // Empty list
out.println("No median. Length is 0");
}else if(testData.length%2==0){ // Even number of elements
double mid1 = testData[(testData.length/2)-1];
double mid2 = testData[testData.length/2];
double med = (mid1 + mid2)/2;
out.println("The median is " + med);
}else{ // Odd number of elements
double mid = testData[(testData.length/2)+1];
out.println("The median is " + mid);
}
使用前面包含偶数个值的数组,我们的输出是:
The median is 16.35
为了测试奇数个元素的代码,我们将把 double 12.5添加到数组的末尾。我们的新输出如下:
The median is 16.5
使用 Apache Commons 寻找中间值
我们还可以使用在计算平均值一节中演示的 Apache Commons DescriptiveStatistics类来计算中位数。我们将继续使用具有以下值的testData数组:
double[] testData = {12.5, 18.3, 11.2, 19.0, 22.1, 14.3, 16.2, 12.5,
17.8, 16.5, 12.5};
我们的代码非常类似于我们用来计算平均值的代码。我们只需创建我们的DescriptiveStatistics对象并调用getPercentile方法,该方法返回存储在其参数中指定的百分点值的估计值。为了找到中间值,我们使用50的值:
DescriptiveStatistics statTest =
new SynchronizedDescriptiveStatistics();
for(double num : testData){
statTest.addValue(num);
}
out.println("The median is " + statTest.getPercentile(50));
我们的输出如下:
The median is 16.2
计算模式
术语模式用于表示数据集中出现频率最高的值。这可以被认为是最受欢迎的结果,或直方图中最高的条。在进行统计分析时,它可能是一条有用的信息,但计算起来可能比第一次出现时更复杂。首先,我们将使用下面的testData数组演示一个简单的 Java 技术:
double[] testData = {12.5, 18.3, 11.2, 19.0, 22.1, 14.3, 16.2, 12.5,
17.8, 16.5, 12.5};
我们首先初始化变量来保存模式、模式在列表中出现的次数以及一个tempCnt变量。mode 和modeCount变量分别用于保存模式值和该值在列表中出现的次数。变量tempCnt用于统计一个元素在列表中出现的次数:
int modeCount = 0;
double mode = 0;
int tempCnt = 0;
然后,我们使用嵌套的 for 循环将数组中的每个值与数组中的其他值进行比较。当我们找到匹配的值时,我们增加我们的tempCnt。在比较每个值之后,我们测试看tempCnt是否大于modeCount,如果是,我们改变我们的modeCount和模式以反映新的值:
for (double testValue : testData){
tempCnt = 0;
for (double value : testData){
if (testValue == value){
tempCnt++;
}
}
if (tempCnt > modeCount){
modeCount = tempCnt;
mode = testValue;
}
}
out.println("Mode" + mode + " appears " + modeCount + " times.");
使用这个例子,我们的输出如下:
The mode is 12.5 and appears 3 times.
虽然我们前面的例子看起来简单明了,但它带来了潜在的问题。如下所示修改testData数组,其中最后一个条目更改为11.2:
double[] testData = {12.5, 18.3, 11.2, 19.0, 22.1, 14.3, 16.2, 12.5,
17.8, 16.5, 11.2};
当我们这次执行代码时,我们的输出如下:
The mode is 12.5 and appears 2 times.
问题是我们的testData数组现在包含两个各出现两次的值,12.5和11.2。这就是所谓的多模态数据集。我们可以通过基本的 Java 代码和第三方库来解决这个问题,稍后我们将展示这一点。
然而,首先我们将展示两种使用简单 Java 的方法。第一种方法将使用两个ArrayList实例,第二种方法将使用一个ArrayList和一个HashMap实例。
使用数组列表寻找多种模式
在第一种方法中,我们修改了上一个例子中使用的代码,以使用一个ArrayList类。我们将创建两个ArrayLists,一个保存数据集中的唯一数字,另一个保存每个数字的计数。我们还需要一个tempMode变量,我们接下来会用到它:
ArrayList<Integer> modeCount = new ArrayList<Integer>();
ArrayList<Double> mode = new ArrayList<Double>();
int tempMode = 0;
接下来,我们将遍历数组并测试模式列表中的每个值。如果在列表中没有找到该值,我们将它添加到mode中,并将modeCount中的相同位置设置为1。如果找到该值,我们将在modeCount中的相同位置增加1:
for (double testValue : testData){
int loc = mode.indexOf(testValue);
if(loc == -1){
mode.add(testValue);
modeCount.add(1);
}else{
modeCount.set(loc, modeCount.get(loc)+1);
}
}
接下来,我们遍历我们的modeCount列表来找到最大值。这表示数据集中最常见值的模式或频率。这允许我们选择多种模式:
for(int cnt = 0; cnt < modeCount.size(); cnt++){
if (tempMode < modeCount.get(cnt)){
tempMode = modeCount.get(cnt);
}
}
最后,我们再次遍历我们的modeCount数组,并打印出模式中与包含最大值的modeCount中的元素相对应的任何元素,或者模式:
for(int cnt = 0; cnt < modeCount.size(); cnt++){
if (tempMode == modeCount.get(cnt)){
out.println(mode.get(cnt) + " is a mode and appears " +
modeCount.get(cnt) + " times.");
}
}
当我们的代码被执行时,我们的输出反映了我们的多模态数据集:
12.5 is a mode and appears 2 times.
11.2 is a mode and appears 2 times.
使用散列表寻找多种模式
第二种方法使用HashMap。首先,我们创建ArrayList来保存可能的模式,就像前面的例子一样。我们还创建了我们的HashMap和一个变量来保存模式:
ArrayList<Double> modes = new ArrayList<Double>();
HashMap<Double, Integer> modeMap = new HashMap<Double, Integer>();
int maxMode = 0;
接下来,我们遍历我们的testData数组,并计算数组中每个值出现的次数。然后,我们将每个值的计数和值本身添加到HashMap中。如果值的计数大于我们的maxMode变量,我们将maxMode设置为新的最大值:
for (double value : testData) {
int modeCnt = 0;
if (modeMap.containsKey(value)) {
modeCnt = modeMap.get(value) + 1;
} else {
modeCnt = 1;
}
modeMap.put(value, modeCnt);
if (modeCnt > maxMode) {
maxMode = modeCnt;
}
}
最后,我们遍历我们的HashMap并检索我们的模式,或者计数等于我们的maxMode的所有值:
for (Map.Entry<Double, Integer> multiModes : modeMap.entrySet()) {
if (multiModes.getValue() == maxMode) {
modes.add(multiModes.getKey());
}
}
for(double mode : modes){
out.println(mode + " is a mode and appears " + maxMode + " times.");
}
当我们执行我们的代码时,我们得到与上一个例子相同的输出:
12.5 is a mode and appears 2 times.
11.2 is a mode and appears 2 times.
使用 Apache Commons 查找多种模式
另一种选择是使用 Apache Commons StatUtils类。这个类包含了几种统计分析的方法,包括多种平均值的方法,但是我们在这里只研究模式。该方法被命名为mode,并接受一个 doubles 数组作为其参数。它返回包含数据集所有模式的 doubles 数组:
double[] modes = StatUtils.mode(testData);
for(double mode : modes){
out.println(mode + " is a mode.");
}
一个缺点是我们不能计算我们的模式在这个方法中出现的次数。我们只知道模式是什么,而不知道它出现了多少次。当我们执行我们的代码时,我们得到一个与前面的例子相似的输出:
12.5 is a mode.
11.2 is a mode.
标准偏差
标准偏差是对平均值分布情况的测量。高偏差意味着分布很广,而低偏差意味着值更紧密地围绕平均值分组。如果没有一个焦点或者有许多异常值,这种测量可能会产生误导。
我们首先展示一个使用基本 Java 技术的简单例子。我们使用前面示例中的 testData 数组,在此复制:
double[] testData = {12.5, 18.3, 11.2, 19.0, 22.1, 14.3, 16.2, 12.5,
17.8, 16.5, 11.2};
在计算标准差之前,我们需要找到平均值。我们可以使用在计算平均值部分列出的任何技术,但是为了简单起见,我们将把我们的值相加,然后除以testData的长度:
int sum = 0;
for(double value : testData){
sum += value;
}
double mean = sum/testData.length;
接下来,我们创建一个变量sdSum,来帮助我们计算标准偏差。当我们遍历数组时,我们从每个数据值中减去平均值,对该值求平方,并将其添加到sdSum。最后,我们将sdSum除以数组的长度,然后对结果求平方:
int sdSum = 0;
for (double value : testData){
sdSum += Math.pow((value - mean), 2);
}
out.println("The standard deviation is " +
Math.sqrt( sdSum / ( testData.length ) ));
我们的输出是我们的标准差:
The standard deviation is 3.3166247903554
我们的下一个技术使用 Google Guava 的Stats类来计算标准差。我们首先用我们的testData创建一个Stats对象。我们然后调用populationStandardDeviation方法:
Stats testStats = Stats.of(testData);
double sd = testStats.populationStandardDeviation();
out.println("The standard deviation is " + sd);
输出如下所示:
The standard deviation is 3.3943803826056653
此示例计算整个总体的标准差。有时最好计算总体样本子集的标准差,以纠正可能的偏差。为了实现这一点,我们使用了与之前基本相同的代码,但是用sampleStandardDeviation替换了populationStandardDeviation方法:
Stats testStats = Stats.of(testData);
double sd = testStats.sampleStandardDeviation();
out.println("The standard deviation is " + sd);
在这种情况下,我们的输出是:
The sample standard deviation is 3.560056179332006
我们的下一个例子使用 Apache Commons DescriptiveStatistics类,我们在前面的例子中使用它来计算平均值和中值。记住,这种技术的优点是线程安全和同步。在我们创建了一个SynchronizedDescriptiveStatistics对象之后,我们添加数组中的每个值。我们然后称之为getStandardDeviation方法。
DescriptiveStatistics statTest =
new SynchronizedDescriptiveStatistics();
for(double num : testData){
statTest.addValue(num);
}
out.println("The standard deviation is " +
statTest.getStandardDeviation());
请注意,该输出与我们上一个示例的输出相匹配。默认情况下,getStandardDeviation方法返回为样本调整的标准偏差:
The standard deviation is 3.5600561793320065
然而,我们可以继续使用 Apache Commons 来计算任一形式的标准差。StandardDeviation类允许您计算总体标准偏差或子集标准偏差。为了演示不同之处,请用下面的代码替换前面的代码示例:
StandardDeviation sdSubset = new StandardDeviation(false);
out.println("The population standard deviation is " +
sdSubset.evaluate(testData));
StandardDeviation sdPopulation = new StandardDeviation(true);
out.println("The sample standard deviation is " +
sdPopulation.evaluate(testData));
在第一行,我们创建了一个新的StandardDeviation对象,并将我们的构造函数的参数设置为false,这将产生一个总体的标准偏差。第二部分使用值true,它产生样本的标准偏差。在我们的例子中,我们使用了相同的测试数据集。这意味着我们首先将它视为数据总体的一个子集。在第二个例子中,我们假设数据集是全部数据。实际上,您可能不会对这些方法中的每一种使用相同的数据集。输出如下所示:
The population standard deviation is 3.3943803826056653
The sample standard deviation is 3.560056179332006
首选方案将取决于您的样品和特定的分析需求。
样本量的确定
样本量的确定包括确定进行精确统计分析所需的数据量。处理大型数据集时,并不总是需要使用整个数据集。我们使用样本大小确定来确保我们选择的样本足够小,以方便操作和分析,但又足够大,以准确代表我们的总体数据。
使用数据的一个子集来训练模型,而使用另一个子集来测试模型,这种情况并不少见。这有助于验证数据的准确性和可靠性。样本量确定不当的一些常见后果包括假阳性结果、假阴性结果、在不存在统计显著性的情况下识别统计显著性,或者在实际存在统计显著性的情况下暗示缺乏显著性。网上有很多工具可以用来确定合适的样本量,每种工具的复杂程度都不一样。一个简单的例子是在 www.surveymonkey.com/mp/sample-s…
假设检验
假设检验用于检验关于数据集的某些假设或前提是否不会偶然发生。如果是这种情况,那么测试的结果被认为是有统计学意义的。
进行假设检验不是一项简单的任务。有许多不同的陷阱需要避免,如安慰剂效应或观察者效应。在前一种情况下,参与者将获得他们认为是预期的结果。在观察者效应中,也被称为霍桑效应,结果是有偏差的,因为参与者知道他们正在被观察。由于人类行为分析的复杂性,某些类型的统计分析特别容易出现偏差或讹误。
进行假设检验的具体方法超出了本书的范围,需要在统计过程和最佳实践方面有扎实的背景知识。Apache Commons 提供了一个包org.apache.commons.math3.stat.inference,其中包含执行假设检验的工具。这包括执行学生 T 检验、卡方检验和计算 p 值的工具。
回归分析
回归分析对于确定数据的趋势很有用。它表示因变量和自变量之间的关系。自变量决定因变量的值。每个自变量对因变量的值都有或强或弱的影响。线性回归使用散点图中的线条来显示趋势。非线性回归使用某种曲线来描述这种关系。
比如血压和年龄、体重指数()等各种因素都有关系。血压可视为因变量,其他因素可视为自变量。给定包含一组个体的这些因素的数据集,我们可以执行回归分析来查看趋势。
Java 支持几种类型的回归分析。我们将研究简单线性回归和多元线性回归。这两种方法都采用数据集,并推导出最适合数据的线性方程。简单线性回归使用一个因变量和一个自变量。多元线性回归使用多个因变量。
有几个支持简单线性回归的 API,包括:
- Apache Commons-http://Commons . Apache . org/proper/Commons-math/javadocs/API-3 . 6 . 1/index . html
- Weka-http://Weka . SourceForge . net/doc . dev/Weka/core/matrix/linear regression . html
- JFree-http://www . JFree . org/jfreechart/API/javadoc/org/JFree/data/statistics/regression . html
- **迈克尔·托马斯·弗拉纳根的 Java 科学图书馆-【www.ee.ucl.ac.uk/~mflanaga/j… **
非线性 Java 支持可在以下网址找到:
- **奥丁斯班 / 爪哇最小二乘法-【github.com/odinsbane/l… **
- 非线性最小方 ( 并行 Java 库文档)-https://www . cs . rit . edu/~ ark/pj/doc/edu/rit/numeric/非线性最小方
有几个统计数据可以评估分析的有效性。我们将把重点放在基本统计上。
残差是实际数据值和预测值之间的差值。残差平方和 ( RSS )是残差平方和。本质上,它测量数据和回归模型之间的差异。较小的 RSS 表示模型与数据非常匹配。RSS 也被称为预测的残差平方和** ( SSR )或误差平方和 ( SSE )。**
均方误差 ( MSE )是残差平方和除以自由度。自由度的数量是独立观察的数量( N )减去总体参数估计的数量。对于简单的线性回归,这个 N - 2 因为有两个参数。对于多元线性回归,它取决于使用的独立变量的数量。
较小的 MSE 也表明模型非常适合数据集。在讨论线性回归模型时,您会看到这两种统计数据。
相关系数衡量回归模型中两个变量之间的关联。相关系数从 -1 到 +1 不等。值 +1 意味着两个变量完全相关。当一个增加时,另一个也会增加。相关系数为 -1 意味着两个变量负相关。一个增加,另一个减少。值为 0 表示变量之间没有相关性。该系数通常被指定为 r。它通常是平方,因此忽略了关系的符号。通常使用皮尔逊积矩相关系数。
使用简单线性回归
简单线性回归使用最小二乘法,即计算一条线,使点和线之间距离的平方和最小。有时计算直线时不使用 Y 截距项。回归线是一个估计值。我们可以用这条线的方程来预测其他数据点。当我们想根据过去的表现预测未来的事件时,这是很有用的。
在下面的例子中,我们将 Apache Commons SimpleRegression 类与第 4 章、数据可视化中使用的比利时人口数据集一起使用。为了方便起见,这里复制了数据:
| 十年 | 人口 |
| 1950 | 8639369 |
| 1960 | 9118700 |
| 1970 | 9637800 |
| 1980 | 9846800 |
| 1990 | 9969310 |
| 2000 | 10263618 |
虽然我们将演示的应用程序是一个 JavaFX 应用程序,但我们将重点关注应用程序的线性回归方面。我们使用 JavaFX 程序生成一个图表来显示回归结果。
下面是start方法的主体。输入数据存储在一个二维数组中,如下所示:
double[][] input = {{1950, 8639369}, {1960, 9118700},
{1970, 9637800}, {1980, 9846800}, {1990, 9969310},
{2000, 10263618}};
创建了一个SimpleRegression类的实例,并使用addData方法添加了数据:
SimpleRegression regression = new SimpleRegression();
regression.addData(input);
我们将使用该模型来预测几年的行为,如下面的数组中所声明的:
double[] predictionYears = {1950, 1960, 1970, 1980, 1990, 2000,
2010, 2020, 2030, 2040};
我们还将使用下面的NumberFormat实例格式化我们的输出。一个用于带有 false 参数的setGroupingUsed方法取消逗号的年份。
NumberFormat yearFormat = NumberFormat.getNumberInstance();
yearFormat.setMaximumFractionDigits(0);
yearFormat.setGroupingUsed(false);
NumberFormat populationFormat = NumberFormat.getNumberInstance();
populationFormat.setMaximumFractionDigits(0);
SimpleRegression类拥有一个predict方法,该方法被传递一个值,在本例中是一年,并返回估计的人口。我们在循环中使用该方法,并为每年调用该方法:
for (int i = 0; i < predictionYears.length; i++) {
out.println(nf.format(predictionYears[i]) + "-"
+ nf.format(regression.predict(predictionYears[i])));
}
当程序执行时,我们得到以下输出:
**1950-8,801,975**
**1960-9,112,892**
**1970-9,423,808**
**1980-9,734,724**
**1990-10,045,641**
**2000-10,356,557**
**2010-10,667,474**
**2020-10,978,390**
**2030-11,289,307**
**2040-11,600,223**
为了以图形方式查看结果,我们生成了下面的索引图。该线与实际人口值相当吻合,并显示了未来的预测人口。
**
简单线性回归**
SimpleRegession类支持许多提供回归附加信息的方法。这些方法总结如下:
| 方法 | 意为 |
| getR | 返回皮尔逊的乘积矩相关系数 |
| getRSquare | 返回决定系数(R 平方) |
| getMeanSquareError | 返回 MSE |
| getSlope | 直线的斜率 |
| getIntercept | 截击 |
我们使用助手方法displayAttribute来显示各种属性值,如下所示:
displayAttribute(String attribute, double value) {
NumberFormat numberFormat = NumberFormat.getNumberInstance();
numberFormat.setMaximumFractionDigits(2);
out.println(attribute + ": " + numberFormat.format(value));
}
我们为我们的模型调用了前面的方法,如下所示:
displayAttribute("Slope",regression.getSlope());
displayAttribute("Intercept", regression.getIntercept());
displayAttribute("MeanSquareError",
regression.getMeanSquareError());
displayAttribute("R", + regression.getR());
displayAttribute("RSquare", regression.getRSquare());
输出如下:
**Slope: 31,091.64**
**Intercept: -51,826,728.48**
**MeanSquareError: 24,823,028,973.4**
**R: 0.97**
**RSquare: 0.94**
如您所见,模型与数据吻合得很好。
使用多元回归
我们的目的不是提供多元线性回归的详细解释,因为这超出了本节的范围。更彻底的治疗可以在 www.biddle.com/documents/b… 找到。相反,我们将解释该方法的基础,并展示我们如何使用 Java 来执行多元回归。
多元回归处理存在多个独立变量的数据。这种情况经常发生。考虑到汽车的燃油效率可能取决于所使用的汽油的辛烷值、发动机的大小、平均巡航速度和环境温度。所有这些因素都会影响燃油效率,有些因素的影响程度比其他因素更大。
自变量通常表示为 Y,其中多个因变量使用不同的 X 表示。使用三个因变量进行回归的简化方程如下,其中每个变量都有一个系数。第一项是截距。这些系数并不代表真实值,而仅用于说明目的。
Y = 11+0.75 X1+0.25 X2 2 X3
截距和系数是使用基于样本数据的多元回归模型生成的。一旦我们有了这些值,我们就可以创建一个方程来预测其他值。
我们将使用 Apache Commons OLSMultipleLinearRegression类来使用香烟数据执行多元回归。数据改编自http://www . amstat . org/publications/jse/v2 n1/datasets . McIntyre . html。该数据由不同品牌香烟的 25 个条目组成,包含以下信息:
- 商标名称
- 焦油含量(毫克)
- 尼古丁含量(毫克)
- 重量(克)
- 一氧化碳含量(毫克)
数据存储在名为data.csv的文件中,如以下部分内容列表所示,其中列值与之前列表的顺序相匹配:
Alpine,14.1,.86,.9853,13.6
Benson&Hedges,16.0,1.06,1.0938,16.6
BullDurham,29.8,2.03,1.1650,23.5
CamelLights,8.0,.67,.9280,10.2
...
以下是显示数据关系的散点图:
**
多元回归散点图**
我们将使用 JavaFX 程序来创建散点图并执行分析。我们从如下所示的MainApp类开始。在本例中,我们将重点关注多元回归代码,不包括用于创建散点图的 JavaFX 代码。完整的程序可以从 www.packtpub.com/support 下载。
数据保存在一维数组中,一个NumberFormat实例将用于格式化这些值。数组大小反映了每个条目的 25 条目和 4 值。在本例中,我们不会使用品牌名称。
public class MainApp extends Application {
private final double[] data = new double[100];
private final NumberFormat numberFormat =
NumberFormat.getNumberInstance();
...
public static void main(String[] args) {
launch(args);
}
}
使用如下所示的CSVReader实例将数据读入数组:
int i = 0;
try (CSVReader dataReader = new CSVReader(
new FileReader("data.csv"), ',')) {
String[] nextLine;
while ((nextLine = dataReader.readNext()) != null) {
String brandName = nextLine[0];
double tarContent = Double.parseDouble(nextLine[1]);
double nicotineContent = Double.parseDouble(nextLine[2]);
double weight = Double.parseDouble(nextLine[3]);
double carbonMonoxideContent =
Double.parseDouble(nextLine[4]);
data[i++] = carbonMonoxideContent;
data[i++] = tarContent;
data[i++] = nicotineContent;
data[i++] = weight;
...
}
}
Apache Commons 拥有两个执行多元回归的类:
OLSMultipleLinearRegression- 普通最小二乘(OLS) 回归GLSMultipleLinearRegression- 广义最小二乘(GLS) 回归
当使用后一种技术时,模型元素之间的相关性会对结果产生负面影响。我们将使用OLSMultipleLinearRegression类,并从它的实例化开始:
OLSMultipleLinearRegression ols =
new OLSMultipleLinearRegression();
我们将使用newSampleData方法来初始化模型。这种方法需要数据集中的观测值个数和自变量个数。它可能抛出一个需要处理的IllegalArgumentException异常。
int numberOfObservations = 25;
int numberOfIndependentVariables = 3;
try {
ols.newSampleData(data, numberOfObservations,
numberOfIndependentVariables);
} catch (IllegalArgumentException e) {
// Handle exceptions
}
接下来,我们将小数点后的位数设置为 2,并调用estimateRegressionParameters方法。这将为我们的等式返回一组值,然后显示这些值:
numberFormat.setMaximumFractionDigits(2);
double[] parameters = ols.estimateRegressionParameters();
for (int i = 0; i < parameters.length; i++) {
out.println("Parameter " + i +": " +
numberFormat.format(parameters[i]));
}
当执行时,我们将得到以下输出,这为我们的回归方程提供了所需的参数:
**Parameter 0: 3.2**
**Parameter 1: 0.96**
**Parameter 2: -2.63**
**Parameter 3: -0.13**
为了根据一组独立变量预测一个新的依赖值,声明了getY方法,如下所示。parameters参数包含生成的方程系数。arguments参数包含因变量的值。这些用于计算返回的新从属值:
public double getY(double[] parameters, double[] arguments) {
double result = 0;
for(int i=0; i<parameters.length; i++) {
result += parameters[i] * arguments[i];
}
return result;
}
我们可以通过创建一系列独立的值来测试这种方法。这里我们使用了与数据文件中的SalemUltra条目相同的值:
double arguments1[] = {1, 4.5, 0.42, 0.9106};
out.println("X: " + 4.9 + " y: " +
numberFormat.format(getY(parameters,arguments1)));
这将为我们提供以下值:
**X: 4.9 y: 6.31**
6.31的返回值与4.9的实际值不同。然而,使用VirginiaSlims的值:
double arguments2[] = {1, 15.2, 1.02, 0.9496};
out.println("X: " + 13.9 + " y: " +
numberFormat.format(getY(parameters,arguments2)));
我们得到以下结果:
**X: 13.9 y: 15.03**
这接近于13.9的实际值。接下来,我们使用一组不同于数据集中的值:
double arguments3[] = {1, 12.2, 1.65, 0.86};
out.println("X: " + 9.9 + " y: " +
numberFormat.format(getY(parameters,arguments3)));
结果如下:
**X: 9.9 y: 10.49**
这些值不同,但仍然很接近。下图显示了与原始数据相关的预测数据:
**
多重回归预测**
OLSMultipleLinearRegression类还拥有几种方法来评估模型与数据的吻合程度。然而,由于多元回归的复杂性,我们在这里没有讨论它们。
**# 总结
在本章中,我们简要介绍了在数据科学应用中可能遇到的基本统计分析技术。我们从计算一组数字数据的平均值、中值和众数的简单技术开始。标准 Java 和第三方 Java APIs 都用来展示如何计算这些属性。虽然这些技术相对简单,但在计算时需要考虑一些问题。
接下来,我们研究了线性回归。这种技术本质上更具预测性,并试图根据样本数据集计算未来或过去的其他值。我们研究了简单线性回归和多元回归,并使用 Apache Commons 类来执行回归,使用 JavaFX 来绘制图形。
简单线性回归使用单个自变量来预测因变量。多元回归使用一个以上的自变量。这两种技术都有用于评估它们与数据匹配程度的统计属性。
我们演示了如何使用 Apache Commons OLSMultipleLinearRegression类来使用香烟数据执行多元回归。我们能够使用多种属性来创建一个预测一氧化碳排放量的方程。
有了这些统计技术,我们现在可以在下一章检查基本的机器学习技术。这将包括多层感知器和各种其他神经网络的详细讨论。**
六、机器学习
机器学习是一个广泛的话题,有许多不同的支持算法。它通常关注开发一些技术,这些技术允许应用程序学习,而不需要显式地编程来解决问题。通常,建立模型是为了解决一类问题,然后使用来自问题域的样本数据进行训练。在这一章中,我们将讨论一些数据科学中更常见的问题和模型。
这些技术中的许多使用训练数据来教导模型。数据由问题空间的各种代表性元素组成。一旦该模型被训练,就使用测试数据对其进行测试和评估。然后,使用该模型和输入数据进行预测。
例如,商店顾客的购买可以用来训练模型。随后,可以对具有相似特征的客户进行预测。由于预测客户行为的能力,有可能提供特殊的交易或服务来吸引客户返回或促进他们的访问。
有几种对机器学习技术进行分类的方法。一种方法是根据学习风格对他们进行分类:
- 监督学习:通过监督学习,用将输入特征值与正确输出值相匹配的数据来训练模型
- 无监督学习:在无监督学习中,数据不包含结果,但是模型被期望自己确定关系。
- 半监督:该技术使用少量包含正确答案的标记数据和大量未标记数据。这种结合可以带来更好的结果。
- 强化学习:这类似于监督学习,但是对好的结果提供奖励。
- 深度学习:这种方法使用包含多个处理级别的图来建模高级抽象。
在这一章中,我们将只能触及其中的一些技术。具体来说,我们将举例说明使用监督学习的三种技术:
- 决策树:使用问题的特征作为内部节点,结果作为叶子来构建一棵树
- 支持向量机:通常用于分类,通过创建一个分离数据集的超平面,然后进行预测
- 贝叶斯网络:用于描述环境中事件之间概率关系的模型
对于无监督学习,我们将展示如何使用关联规则学习来发现数据集元素之间的关系。然而,我们不会在这一章中讨论无监督学习。
我们将讨论强化学习的要素,并讨论这种技术的一些具体变化。我们还将提供进一步探索的资源链接。
深度学习的讨论推迟到第八章、深度学习。这项技术建立在神经网络的基础上,这将在第 7 章、神经网络中讨论。
在本章中,我们将讨论以下具体主题:
- 决策树
- 支持向量机
- 贝叶斯网络
- 关联规则学习
- 强化学习
监督学习技术
有大量的监督机器学习算法可用。我们将研究其中的三种:决策树、支持向量机和贝叶斯网络。它们都使用包含属性和正确答案的带注释的数据集。通常,使用训练和测试数据集。
我们从讨论决策树开始。
决策树
机器学习决策树是一种用于进行预测的模型。它有效地将某些观察映射到关于目标的结论。术语树来自反映不同状态或价值的分支。树叶代表结果,树枝代表导致结果的特征。在数据挖掘中,决策树是用于分类的数据描述。例如,我们可以使用决策树来根据收入水平和邮政编码等特定属性来确定个人是否可能购买某件商品。
我们希望创建一个决策树,根据其他变量来预测结果。当目标变量取连续值,如实数时,该树被称为回归树。
树由内部节点和叶子组成。每个内部节点代表模型的一个特征,例如受教育的年数或者一本书是平装本还是精装本。从内部节点引出的边表示这些特征的值。每片叶子被称为一个类,并且有一个相关的概率分布。
例如,我们将使用一个数据集,该数据集根据书籍的装订类型、颜色使用和流派来处理书籍的成功与否。基于该数据集的一个可能的决策树如下:
决策图表
决策树很有用,也很容易理解。即使对于大型数据集,为模型准备数据也很简单。
决策树类型
通过将输入数据集除以特征,可以对树进行训练。这通常以递归方式完成,被称为递归划分或决策树的自顶向下归纳 ( TDIDT )。当节点的值与目标的值都是同一类型或者递归不再增加值时,递归是有界的。
分类回归树 ( 大车)分析是指两种不同类型的决策树类型:
- 分类树分析:叶子对应一个目标特征
- 回归树分析:叶子拥有一个代表特征的实数
在分析过程中,可能会创建多个树。有几种技术可以用来创建树。这些技术被称为集成方法:
- Bagging 决策树:数据被重新采样并经常用于获得基于共识的预测
- 随机森林分类器:用于提高分类率
- 提升树:这可用于回归或分类问题
- 旋转森林:使用一种叫做主成分分析 ( PCA )的技术
对于给定的一组数据,有可能不止一棵树对数据进行建模。例如,树的根可以指定银行是否有 ATM 机,随后的内部节点可以指定出纳员的数量。然而,可以创建这样的树,其中出纳员的数量是根,ATM 的存在是内部节点。树的结构差异可以决定树的效率有多高。
有许多方法可以确定树中节点的顺序。一种技术是选择提供最多信息增益的属性;也就是说,选择一个能更好地帮助快速缩小可能决策范围的属性。
决策树库
有几个 Java 库支持决策树:
- 韦卡:【www.cs.waikato.ac.nz/ml/weka/】T2
- Apache Spark:https://Spark . Apache . org/docs/1 . 2 . 0/ml lib-decision-tree . html
- JBoss:【jboost.sourceforge.net】T2
- 机器学习语言工具包 ( 木槌):【mallet.cs.umass.edu】T4
我们将使用怀卡托知识分析环境 ( Weka )来演示如何用 Java 创建决策树。Weka 是一个具有 GUI 界面的工具,允许对数据进行分析。也可以从命令行或通过我们将使用的 Java API 调用它。
在构建树时,选择一个变量来分割树。有几种方法可以用来选择变量。我们使用哪一个取决于通过选择一个变量获得了多少信息。具体来说,我们将使用 Weka 的J48类支持的 C4.5 算法。
Weka 使用一个.arff文件来保存数据集。这个文件是可读的,由两部分组成。第一个是标题部分;它描述了文件中的数据。本节使用&符号来指定数据的关系和属性。第二段是数据段;它由一组逗号分隔的数据组成。
对图书数据集使用决策树
对于这个例子,我们将使用一个名为books.arff的文件。接下来显示了它,它使用了四个称为属性的特性。这些功能指定了一本书是如何装订的,它是否使用多种颜色,它的流派,以及表明该书是否被购买的结果。标题部分如下所示:
@RELATION book_purchases
@ATTRIBUTE Binding {Hardcover, Paperback, Leather}
@ATTRIBUTE Multicolor {yes, no}
@ATTRIBUTE Genre {fiction, comedy, romance, historical}
@ATTRIBUTE Result {Success, Failure}
数据部分如下,由 13 个书条目组成:
@DATA
Hardcover,yes,fiction,Success
Hardcover,no,comedy,Failure
Hardcover,yes,comedy,Success
Leather,no,comedy,Success
Leather,yes,historical,Success
Paperback,yes,fiction,Failure
Paperback,yes,romance,Failure
Leather,yes,comedy,Failure
Paperback,no,fiction,Failure
Paperback,yes,historical,Failure
Hardcover,yes,historical,Success
Paperback,yes,comedy,Success
Hardcover,yes,comedy,Success
我们将使用下面定义的BookDecisionTree类来处理这个文件。它使用一个构造函数和三个方法:
BookDecisionTree:读入教练数据并创建一个用于处理数据的Instance对象main:驱动应用程序performTraining:使用数据集训练模型getTestInstance:创建一个测试用例
Instances类保存代表单个数据集元素的元素:
public class BookDecisionTree {
private Instances trainingData;
public static void main(String[] args) {
...
}
public BookDecisionTree(String fileName) {
...
}
private J48 performTraining() {
...
}
private Instance getTestInstance(
...
}
}
构造函数打开一个文件并使用BufferReader实例创建一个Instances类的实例。数据集的每个元素要么是要素,要么是结果。setClassIndex方法指定了结果类的索引。在这种情况下,它是数据集的最后一个索引,对应于成功或失败:
public BookDecisionTree(String fileName) {
try {
BufferedReader reader = new BufferedReader(
new FileReader(fileName));
trainingData = new Instances(reader);
trainingData.setClassIndex(
trainingData.numAttributes() - 1);
} catch (IOException ex) {
// Handle exceptions
}
}
我们将使用J48类来生成一个决策树。这个类使用 C4.5 决策树算法来生成修剪或未修剪的树。方法指定使用未修剪的树。buildClassifier方法实际上是基于所使用的数据集创建分类器:
private J48 performTraining() {
J48 j48 = new J48();
String[] options = {"-U"};
try {
j48.setOptions(options);
j48.buildClassifier(trainingData);
} catch (Exception ex) {
ex.printStackTrace();
}
return j48;
}
我们想要测试这个模型,所以我们将为每个测试用例创建一个实现Instance接口的对象。一个getTestInstance helper 方法被传递了三个参数,代表一个数据元素的三个特性。DenseInstance类是一个实现Instance接口的类。传递的值被分配给实例,并返回实例:
private Instance getTestInstance(
String binding, String multicolor, String genre) {
Instance instance = new DenseInstance(3);
instance.setDataset(trainingData);
instance.setValue(trainingData.attribute(0), binding);
instance.setValue(trainingData.attribute(1), multicolor);
instance.setValue(trainingData.attribute(2), genre);
return instance;
}
main方法使用前面所有的方法来处理和测试我们的图书数据集。首先,使用图书数据集文件的名称创建一个BookDecisionTree实例:
public static void main(String[] args) {
try {
BookDecisionTree decisionTree =
new BookDecisionTree("books.arff");
...
} catch (Exception ex) {
// Handle exceptions
}
}
接下来,调用performTraining方法来训练模型。我们还显示了树:
J48 tree = decisionTree.performTraining();
System.out.println(tree.toString());
执行时,将显示以下内容:
J48 unpruned tree
------------------
Binding = Hardcover: Success (5.0/1.0)
Binding = Paperback: Failure (5.0/1.0)
Binding = Leather: Success (3.0/1.0)
Number of Leaves : 3
Size of the tree : 4
测试图书决策树
我们将用两个不同的测试用例来测试这个模型。两者都使用相同的代码来设置实例。我们使用带有测试用例特定值的getTestInstance方法,然后使用带有classifyInstance的实例来获得结果。为了获得更具可读性的内容,我们生成一个字符串,然后显示如下:
Instance testInstance = decisionTree.
getTestInstance("Leather", "yes", "historical");
int result = (int) tree.classifyInstance(testInstance);
String results = decisionTree.trainingData.attribute(3).value(result);
System.out.println(
"Test with: " + testInstance + " Result: " + results);
testInstance = decisionTree.
getTestInstance("Paperback", "no", "historical");
result = (int) tree.classifyInstance(testInstance);
results = decisionTree.trainingData.attribute(3).value(result);
System.out.println(
"Test with: " + testInstance + " Result: " + results);
执行这段代码的结果如下:
Test with: Leather,yes,historical Result: Success
Test with: Paperback,no,historical Result: Failure
这符合我们的预期。这种技术是基于在做出排序决定之前和之后获得的信息量。这可以基于如下计算的熵来测量:
Entropy = -portionPos * log2(portionPos) - portionNeg* log2(portionNeg)
在这个例子中,portionPos是正的数据部分,portionNeg是负的数据部分。基于 books 文件,我们可以计算绑定的熵,如下表所示。通过从 1.0 中减去用于结合的熵来计算信息增益:
我们可以用类似的方式计算颜色和类型使用的熵。颜色的信息增益为 0.05 ,流派的信息增益为 0.15 。因此,对树的第一层使用绑定类型更有意义。
由于 C4.5 算法确定剩余的特征不提供任何额外的信息增益,因此该示例的结果树由两层组成。
当选择具有大量值的特征时,例如客户的信用卡号,信息获取可能会有问题。使用这种类型的属性会迅速缩小范围,但它的选择性太强,没有多大价值。
支持向量机
一个支持向量机 ( SVM )是一个监督机器学习算法,用于分类和回归问题。它主要用于分类问题。该方法创建超平面来对训练数据进行分类。超平面可以被想象成分隔两个区域的几何平面。在二维空间中,它将是一条线。在三维空间中,它将是一个二维平面。对于更高的维度,更难概念化,但它们确实存在。
考虑下图,该图描述了两种类型的数据点的分布。这些线代表分隔这些点的可能的超平面。SVM 过程的一部分是为问题数据集寻找最佳超平面。我们将在编码示例中详细阐述这个数字。
超平面示例
支持向量是位于超平面附近的数据点。SVM 模型使用核的概念将输入数据映射到更高阶的维度空间,以使数据更容易结构化。这样做的映射函数可能导致无限维空间;也就是说,可能存在无限数量的可能映射。
然而,所谓的内核技巧,内核函数是一种避免这种映射并避免可能发生的不可行计算的方法。支持向量机支持不同类型的内核。内核列表可以在http://crsouza . com/2010/03/kernel-functions-for-machine-learning-applications/找到。选择合适的内核取决于问题。常用的内核包括:
- 线性:使用一个线性超平面
- 多项式:使用超平面的多项式方程
- 径向基函数(RBF) :使用非线性超平面
- Sigmoid:Sigmoid 核,也称为双曲正切核,来自神经网络领域,相当于一个两层感知器神经网络
这些内核支持不同的数据分析算法。
支持向量机对于人类难以想象的高维空间非常有用。在上图中,两个属性用于预测第三个属性。当存在更多属性时,可以使用 SVM。需要对 SVM 进行训练,对于较大的数据集,这可能需要更长的时间。
我们将使用 Weka 类SMO来演示 SVM 分析。该类支持 John Platt 的顺序最小优化算法。关于这个算法的更多信息可以在https://www . Microsoft . com/en-us/research/publication/fast-training-of-support-vector-machines-using-sequential-minimal-optimization/找到。
SMO类支持以下内核,可以在使用该类时指定:
- Puk :基于皮尔逊 VII 函数的通用核
- 多内核:多项式内核
- RBF kernel:RBF 内核
该算法使用训练数据来创建分类模型。然后,测试数据可用于评估模型。我们还可以评估单个数据元素。
使用 SVM 获取露营数据
为了便于说明,我们将使用一个由年龄、收入和某人是否露营组成的数据集。我们希望能够根据年龄和收入预测某人是否倾向于露营。我们使用的数据以.arff格式存储,并非基于调查,而是为了解释 SVM 进程而创建的。输入数据在camping.txt文件中找到,如下所示。文件扩展名不必是.arff:
@relation camping
@attribute age numeric
@attribute income numeric
@attribute camps {1, 0}
@data
23,45600,1
45,65700,1
72,55600,1
24,28700,1
22,34200,1
28,32800,1
32,24600,1
25,36500,1
26,91000,0
29,85300,0
67,76800,0
86,58900,0
56,125300,0
25,125000,0
22,43600,1
78,125700,1
73,56500,1
29,87600,0
65,79300,0
下图显示了数据的分布情况。注意右上角的异常值。生成此图的 JavaFX 代码位于www.packtpub.com/support:
野营图
我们将从读入数据和处理异常开始:
try {
BufferedReader datafile;
datafile = readDataFile("camping.txt");
...
} catch (Exception ex) {
// Handle exceptions
}
readDataFile方法如下:
public BufferedReader readDataFile(String filename) {
BufferedReader inputReader = null;
try {
inputReader = new BufferedReader(
new FileReader(filename));
} catch (FileNotFoundException ex) {
// Handle exceptions
}
return inputReader;
}
Instances类保存一系列数据实例,其中每个实例都是年龄、收入和露营值。setClassIndex方法指出哪个属性将被预测。在本例中,它是camps属性:
Instances data = new Instances(datafile);
data.setClassIndex(data.numAttributes() - 1);
为了训练模型,我们将把数据集分成两组。第一个14实例用于训练模型,最后一个5实例用于测试模型。Instances构造函数的第二个参数指定数据集中的起始索引,最后一个参数指定要包含多少个实例:
Instances trainingData = new Instances(data, 0, 14);
Instances testingData = new Instances(data, 14, 5);
创建一个Evaluation类实例来评估模型。还创建了一个SMO类的实例。SMO类的buildClassifier方法使用数据集构建分类器:
Evaluation evaluation = new Evaluation(trainingData);
Classifier smo = new SMO();
smo.buildClassifier(data);
evaluateModel方法使用测试数据评估模型。然后显示结果:
evaluation.evaluateModel(smo, testingData);
System.out.println(evaluation.toSummaryString());
输出如下。请注意一个错误分类的实例。这对应于前面提到的异常值:
Correctly Classified Instances 4 80 %
Incorrectly Classified Instances 1 20 %
Kappa statistic 0.6154
Mean absolute error 0.2
Root mean squared error 0.4472
Relative absolute error 41.0256 %
Root relative squared error 91.0208 %
Coverage of cases (0.95 level) 80 %
Mean rel. region size (0.95 level) 50 %
Total Number of Instances 5
测试个别实例
我们还可以使用classifyInstance方法测试一个单独的实例。在下面的序列中,我们使用DenseInstance类创建一个新的实例。然后使用露营数据集的属性对其进行填充:
Instance instance = new DenseInstance(3);
instance.setValue(data.attribute("age"), 78);
instance.setValue(data.attribute("income"), 125700);
instance.setValue(data.attribute("camps"), 1);
需要使用setDataset方法将实例与数据集相关联:
instance.setDataset(data);
然后将classifyInstance方法应用于smo实例,并显示结果:
System.out.println(smo.classifyInstance(instance));
执行时,我们得到以下输出:
1.0
也有替代的测试方法。常见的一种叫做交叉验证折叠。这种方法将数据集分为*褶皱、*褶皱,这些褶皱是数据集的分区。通常会创建 10 个分区。九个分区用于训练,一个用于测试。每次使用数据集的不同分区重复 10 次,并使用结果的平均值。这个技巧在https://WEKA . wikispaces . com/Generating+cross-validation+folds+(Java+approach)有描述。
我们现在将检查贝叶斯网络的目的和使用。
贝叶斯网络
贝叶斯网络,也称为贝叶斯网或信念网络,是通过描述世界不同属性的状态及其统计关系来反映特定世界或环境的模型。这些模型可以用来展示各种各样的真实场景。在下图中,我们建立了一个系统模型,描述了各种因素与我们上班迟到可能性之间的关系:
贝叶斯网络
图上的每个圆圈代表系统的一个节点或部分,它可以有不同的值和每个值的概率。例如,停电可能是真的或假的——要么停电,要么没有停电。停电的概率会影响你的闹钟不响的概率,你可能会睡过头,从而上班迟到。
图表顶部的节点往往比底部的节点意味着更高层次的因果关系。更高的节点称为父节点,它们可能有一个或多个子节点。贝叶斯网络只涉及具有因果相关性的节点,因此允许更有效地计算概率。与其他模型不同,我们不必存储和分析每个节点的每种可能的状态组合。相反,我们可以计算和存储相关节点的概率。此外,贝叶斯网络很容易适应,并且可以随着关于特定世界的更多知识的获得而增长。
使用贝叶斯网络
为了使用 Java 对这种类型的网络进行建模,我们将使用 JB eyes(github.com/vangj/jbaye…)来创建一个网络。JBayes 是一个开源库,用于创建一个简单的贝叶斯信念网络 ( BBN )。它可以免费用于个人或商业用途。在我们的下一个例子中,我们将执行近似推理,这是一种被认为不太准确但可以减少计算时间的技术。这种技术经常在处理大数据时使用,因为它可以在合理的时间内生成可靠的模型。我们通过对每个节点进行加权采样来进行近似推理。JBayes 还提供了对精确推理的支持。精确推断最常用于较小的数据集或准确性非常重要的情况。JBayes 使用连接树算法执行精确推理。
为了开始我们的近似推理模型,我们将首先创建我们的节点。我们将使用前面描述影响准时到达的属性的图表来构建我们的网络。在下面的代码示例中,我们使用方法链接来创建节点。其中三个方法带有一个String参数。name方法是与每个节点相关联的名称。为了简洁起见,我们只使用首字母,所以 s 代表storms , t代表traffic,以此类推。value方法允许我们为节点设置值。在每种情况下,我们的节点只能有两个值:t表示真,或者f表示假:
Node storms = Node.newBuilder().name("s").value("t").value("f").build();
Node traffic = Node.newBuilder().name("t").value("t").value("f").build();
Node powerOut = Node.newBuilder().name("p").value("t").value("f").build();
Node alarm = Node.newBuilder().name("a").value("t").value("f").build();
Node overslept = Node.newBuilder().name("o").value("t").value("f").build();
Node lateToWork = Node.newBuilder().name("l").value("t").value("f").build();
接下来,我们为每个子节点分配父节点。请注意,storms是traffic和powerOut的父节点。lateToWork节点有两个父节点,traffic和overslept:
traffic.addParent(storms);
powerOut.addParent(storms);
lateToWork.addParent(traffic);
alarm.addParent(powerOut);
overslept.addParent(alarm);
lateToWork.addParent(overslept);
然后,我们为每个节点定义条件概率表 ( CPTs )。这些表基本上是表示每个节点的每个属性的概率的二维数组。如果我们有不止一个父节点,就像在lateToWork节点的情况下,我们需要为每个节点准备一行。在这个例子中,我们使用了任意的概率值,但是注意每一行的总和必须是1.0:
storms.setCpt(new double[][] {{0.7, 0.3}});
traffic.setCpt(new double[][] {{0.8, 0.2}});
powerOut.setCpt(new double[][] {{0.5, 0.5}});
alarm.setCpt(new double[][] {{0.7, 0.3}});
overslept.setCpt(new double[][] {{0.5, 0.5}});
lateToWork.setCpt(new double[][] {
{0.5, 0.5},
{0.5, 0.5}
});
最后,我们创建一个Graph对象,并将每个节点添加到我们的图结构中。然后,我们使用此图进行采样:
Graph bayesGraph = new Graph();
bayesGraph.addNode(storms);
bayesGraph.addNode(traffic);
bayesGraph.addNode(powerOut);
bayesGraph.addNode(alarm);
bayesGraph.addNode(overslept);
bayesGraph.addNode(lateToWork);
bayesGraph.sample(1000);
此时,我们可能对每个事件的概率感兴趣。我们可以使用prob方法来检查每个节点的True或False值的概率:
double[] stormProb = storms.probs();
double[] trafProb = traffic.probs();
double[] powerProb = powerOut.probs();
double[] alarmProb = alarm.probs();
double[] overProb = overslept.probs();
double[] lateProb = lateToWork.probs();
out.println("nStorm Probabilities");
out.println("True: " + stormProb[0] + " False: " + stormProb[1]);
out.println("nTraffic Probabilities");
out.println("True: " + trafProb[0] + " False: " + trafProb[1]);
out.println("nPower Outage Probabilities");
out.println("True: " + powerProb[0] + " False: " + powerProb[1]);
out.println("vAlarm Probabilities");
out.println("True: " + alarmProb[0] + " False: " + alarmProb[1]);
out.println("nOverslept Probabilities");
out.println("True: " + overProb[0] + " False: " + overProb[1]);
out.println("nLate to Work Probabilities");
out.println("True: " + lateProb[0] + " False: " + lateProb[1]);
我们的输出包含每个节点的每个值的概率。例如,风暴发生的概率是 71%,而不发生的概率是 29%:
Storm Probabilities
True: 0.71 False: 0.29
Traffic Probabilities
True: 0.726 False: 0.274
Power Outage Probabilities
True: 0.442 False: 0.558
Alarm Probabilities
True: 0.543 False: 0.457
Overslept Probabilities
True: 0.556 False: 0.444
Late to Work Probabilities
True: 0.469 False: 0.531
注意
请注意,在这个例子中,我们使用了产生上班迟到可能性非常高的数字,大约为 47%。这是因为我们已经将父节点的概率设置得相当高。如果风暴发生的几率较低,或者如果我们也改变了一些其他的子节点,这个数据会有很大的变化。
如果我们想保存有关样本的信息,可以使用以下代码将数据保存到 CSV 文件中:
try {
CsvUtil.saveSamples(bayesGraph, new FileWriter(
new File("C://JBayesInfo.csv")));
} catch (IOException e) {
// Handle exceptions
}
关于监督学习的讨论结束后,我们现在将转向无监督学习。
无监督机器学习
无监督机器学习不使用带注释的数据;也就是说,数据集确实包含预期的结果。虽然有几种无监督学习算法,但我们将展示关联规则学习的使用来说明这种学习方法。
关联规则学习
关联规则学习是一种识别数据项之间关系的技术。这是所谓的市场篮子分析的一部分。当购物者进行购买时,这些购买很可能由不止一个项目组成,并且当它这样做时,有某些项目倾向于一起购买。关联规则学习是识别这些相关项目的一种方法。当发现关联时,可以为其制定规则。
例如,如果顾客购买尿布和乳液,他们也可能购买婴儿湿巾。分析可以发现这些关联,并且可以形成陈述观察结果的规则。该规则将被表达为*{尿布、洗液} =>{湿巾}* 。能够识别这些购买模式允许商店提供特殊优惠券,安排他们的产品更容易得到,或者实现任何数量的其他市场相关活动。
这种技术的一个问题是存在大量可能的关联。一种常用的有效方法是先验算法。该算法处理由一组项目定义的事务集合。这些项目可以被认为是购买,而交易可以被认为是一起购买的一组项目。该集合通常被称为数据库。
考虑下面的一组交易,其中, 1 表示该物品是作为交易的一部分购买的,而 0 表示该物品没有被购买:
| 交易 ID | 尿布 | 乳液 | 湿巾 | 公式 | | one | one | one | one | Zero | | Two | one | one | one | one | | three | Zero | one | one | Zero | | four | one | Zero | Zero | Zero | | five | Zero | one | one | one |
先验模型使用了几个分析术语:
- Support :这是数据库中包含项目子集的项目的比例。在之前的数据库中,{尿不湿,乳液} 项出现 2/5 次或者 20% 。
- 置信度:这是对规则为真的频率的度量。其计算方式为conf(X->Y)= sup(X∪Y)/sup(X)。
- Lift :衡量项目相互依赖的程度。定义为 lift(X->Y)=sup(X∪Y)/(sup(X) sup(Y))*。
- 杠杆率:杠杆率是指在 X 和 Y 相互独立的情况下, X 和 Y 所涵盖的交易数量。高于 0 的值是一个好的指示器。计算方法为 lev(X- > Y) = sup(X,Y) - sup(X) * sup(Y) 。
- 信念:衡量规则做出错误决定的频率。定义为conv(X->Y)= 1-sup(Y)/(1-conf(X->Y))。
这些定义和样本值可以在en.wikipedia.org/wiki/Associ…找到。
利用关联规则学习发现购买关系
我们将使用Apriori Weka 类来演示 Java 对使用两个数据集的算法的支持。第一个是之前讨论的数据,第二个是关于一个人在徒步旅行中可能携带的物品。
以下是婴儿信息的数据文件babies.arff:
@relation TEST_ITEM_TRANS
@attribute Diapers {1, 0}
@attribute Lotion {1, 0}
@attribute Wipes {1, 0}
@attribute Formula {1, 0}
@data
1,1,1,0
1,1,1,1
0,1,1,0
1,0,0,0
0,1,1,1
我们从使用一个BufferedReader实例读入文件开始。这个对象被用作Instances类的参数,它将保存数据:
try {
BufferedReader br;
br = new BufferedReader(new FileReader("babies.arff"));
Instances data = new Instances(br);
br.close();
...
} catch (Exception ex) {
// Handle exceptions
}
接下来,创建一个Apriori实例。我们设置要生成的规则数量和规则的最小置信度:
Apriori apriori = new Apriori();
apriori.setNumRules(100);
apriori.setMinMetric(0.5);
buildAssociations方法使用Instances变量生成关联。然后显示关联:
apriori.buildAssociations(data);
System.out.println(apriori);
将显示 100 条规则。以下是简短的输出。每个规则后面都有该规则的各种度量:
注意
请注意,规则 8 和 100 反映了前面的例子。
Apriori
=======
Minimum support: 0.3 (1 instances)
Minimum metric <confidence>: 0.5
Number of cycles performed: 14
Generated sets of large itemsets:
Size of set of large itemsets L(1): 8
Size of set of large itemsets L(2): 18
Size of set of large itemsets L(3): 16
Size of set of large itemsets L(4): 5
Best rules found:
1\. Wipes=1 4 ==> Lotion=1 4 <conf:(1)> lift:(1.25) lev:(0.16) [0] conv:(0.8)
2\. Lotion=1 4 ==> Wipes=1 4 <conf:(1)> lift:(1.25) lev:(0.16) [0] conv:(0.8)
3\. Diapers=0 2 ==> Lotion=1 2 <conf:(1)> lift:(1.25) lev:(0.08) [0] conv:(0.4)
4\. Diapers=0 2 ==> Wipes=1 2 <conf:(1)> lift:(1.25) lev:(0.08) [0] conv:(0.4)
5\. Formula=1 2 ==> Lotion=1 2 <conf:(1)> lift:(1.25) lev:(0.08) [0] conv:(0.4)
6\. Formula=1 2 ==> Wipes=1 2 <conf:(1)> lift:(1.25) lev:(0.08) [0] conv:(0.4)
7\. Diapers=1 Wipes=1 2 ==> Lotion=1 2 <conf:(1)> lift:(1.25) lev:(0.08) [0] conv:(0.4)
8\. Diapers=1 Lotion=1 2 ==> Wipes=1 2 <conf:(1)> lift:(1.25) lev:(0.08) [0] conv:(0.4)
...
62\. Diapers=0 Lotion=1 Formula=1 1 ==> Wipes=1 1 <conf:(1)> lift:(1.25) lev:(0.04) [0] conv:(0.2)
...
99\. Lotion=1 Formula=1 2 ==> Diapers=1 1 <conf:(0.5)> lift:(0.83) lev:(-0.04) [0] conv:(0.4)
100\. Diapers=1 Lotion=1 2 ==> Formula=1 1 <conf:(0.5)> lift:(1.25) lev:(0.04) [0] conv:(0.6)
这为我们提供了一个关系列表,我们可以用它来识别购买行为等活动中的模式。
强化学习
强化学习是当前神经网络和机器学习研究前沿的一种学习类型。与无监督和有监督的学习不同,强化学习基于动作的结果做出决策。这是一个以目标为导向的学习过程,类似于世界各地许多家长和教师使用的方法。我们教孩子们学习并在考试中表现出色,这样他们就能得到高分作为奖励。同样,强化学习可以用来教机器做出能带来最高回报的选择。
强化学习有四个主要组成部分:行动者或代理人、状态或场景、选择的行动和奖励。参与者是在应用程序中做出决策的对象或工具。国家是行动者存在的世界。行动者做出的任何决定都发生在国家的参数范围内。动作只是演员在给定一组选项时做出的选择。回报是每一个行动的结果,并影响未来选择特定行动的可能性。
必须指出,行动和行动发生的国家不是独立的。事实上,正确的或回报最高的行为往往取决于行为发生的状态。如果演员试图决定如何穿过水体,如果水体平静且相当小,游泳可能是一个不错的选择。如果演员想横渡太平洋,游泳将是一个可怕的选择。
要处理这个问题,我们可以考虑 Q 函数。该功能是由特定状态到该状态中的动作的映射产生的。Q 函数会将游过太平洋的奖励比游过小河的奖励低。Q 函数不是说游泳是一种低回报的活动,而是允许游泳有时有低回报,而其他时候有更高的回报。
强化学习总是从一张白纸开始。当迭代第一次开始时,参与者不知道最佳路径或决策序列。然而,在通过给定问题的多次迭代之后,考虑每个特定状态-动作对选择的结果,算法改进并学习做出最高回报的选择。
用于实现强化学习的算法包括在一系列复杂的过程和选择中实现回报的最大化。虽然目前正在视频游戏和其他离散环境中进行测试,但最终目标是这些算法在不可预测的现实世界场景中取得成功。在强化学习的主题中,有三种主要风格或类型:时间差异学习、q `-学习和状态-动作-奖励-状态-动作 ( SARSA )。
时间差异学习考虑先前学习的信息,以通知未来的决策。这种类型的学习假设了过去和未来决策之间的相关性。在采取行动之前,会进行预测。在选择行动之前,将该预测与关于环境的其他已知信息和类似决策进行比较。这一过程被称为自举,被认为是创造更准确和有用的结果。
Q-learning 使用上面提到的 Q 函数,不仅选择给定状态下某一特定步骤的最佳动作,而且选择从该点向前将导致最高奖励的动作。这就是所谓的最优策略。Q-learning 提供的一个很大的优势是不需要完整的状态模型就能做出决策。这使得它在行动和奖励随机变化的状态下发挥作用。
SARSA 是另一种用于强化学习的算法。它的名字是不言自明的:Q 值取决于当前的状态,当前选择的动作,该动作的奖励,动作完成后代理将存在的状态,以及在新状态下采取的后续动作。该算法向前看一步,以做出最佳决策。
目前可用于使用 Java 执行强化学习的工具有限。一个流行的工具是用于实现 Q 学习实验的平台 ( Piqle )。这个 Java 框架旨在为快速设计和测试或强化学习实验提供工具。Piqle 可以从 piqle.sourceforge.net 的下载。另一个健壮的工具叫做布朗-UMBC 强化学习和规划 ( BURPLAP )。在 http://burlap.cs.brown.edu发现的这个库也是为强化学习的算法和领域的开发而设计的。这种特殊的资源以状态和动作的灵活性而自豪,并支持广泛的规划和学习算法。BURLAP 还包括用于可视化目的的分析工具。
总结
机器学习与开发技术有关,这些技术允许应用程序学习,而不必显式编程来解决问题。这种灵活性允许这种应用程序在几乎不做修改的情况下用于更多样的设置中。
我们看到了如何使用训练数据来创建模型。一旦训练了模型,就使用测试数据来评估该模型。训练数据和测试数据都来自问题域。一旦完成训练,该模型将与其他输入数据一起用于进行预测。
我们学习了如何使用 Weka Java API 来创建决策树。该树由代表问题不同属性的内部节点组成。树叶代表结果。因为有许多方法来构造一棵树,所以决策树的一部分工作就是创建最好的树。
支持向量机将数据集分成多个部分,从而对数据集中的元素进行分类。这种分类基于数据的属性,如年龄、头发颜色或体重。使用该模型,可以根据数据实例的属性预测结果。
贝叶斯网络用于根据节点之间的父子关系进行预测。一个事件的概率直接影响子事件的概率,我们可以使用这些信息来预测复杂现实环境的结果。
在关联规则学习部分,我们学习了如何识别数据集元素之间的关系。更重要的关系允许我们建立规则来解决各种问题。
在我们对强化学习的讨论中,我们讨论了主体、状态、行动和奖励的要素以及它们之间的关系。我们还讨论了强化学习的具体类型,并为进一步的研究提供了资源。
在介绍了机器学习的要素之后,我们现在准备探索神经网络,这将在下一章中找到。