精通 Java 数据科学(一)
零、前言
如今,数据科学已经成为组织的一个非常重要的工具:他们已经收集了大量的数据,为了能够很好地利用这些数据,他们需要数据科学——关于从数据中提取知识的方法的学科。每天都有越来越多的公司意识到他们可以从数据科学中受益,并更有效、更有利地利用他们生产的数据。
对于 It 公司来说尤其如此,他们已经有了生成和处理数据的系统和基础设施。这些系统通常是用 Java 编写的,Java 是世界上许多大公司和小公司的首选语言。这并不奇怪,Java 提供了一个非常坚实和成熟的库生态系统,经过时间的考验和可靠,所以许多人信任 Java 并使用它来创建他们的应用程序。
因此,它也是许多数据处理应用程序的自然选择。由于现有的系统已经在 Java 中,因此对数据科学使用相同的技术堆栈是有意义的,并将机器学习模型直接集成到应用程序的生产代码库中。
这本书将涵盖这一点。我们将首先了解如何利用 Java 的工具箱来处理小型和大型数据集,然后研究如何进行初步的勘探数据分析。接下来,我们将回顾实现分类、回归、聚类和降维问题的通用机器学习模型的 Java 库。然后我们将进入更高级的技术,并讨论信息检索和自然语言处理、XGBoost、深度学习和用于处理大数据集的大规模工具,如 Apache Hadoop 和 Apache Spark。最后,我们还将了解如何评估和部署生产的模型,以便其他服务可以使用它们。
我们希望你会喜欢这本书。快乐阅读!
这本书涵盖的内容
第 1 章,使用 Java 的数据科学,概述了 Java 中现有的可用工具,并介绍了处理数据科学项目的方法,CRISP-DM。在这一章中,我们还介绍了我们正在运行的例子,构建一个搜索引擎。
第 2 章,数据处理工具箱,回顾标准 Java 库:用于在内存中存储数据的集合 API,用于读写数据的 IO API,以及用于组织数据处理管道的便捷方式的流 API。我们将研究标准库的扩展,如 Apache Commons Lang、Apache Commons IO、Google Guava 和 AOL Cyclops React。然后,我们将介绍存储数据的最常见方式——文本和 CSV 文件、HTML、JSON 和 SQL 数据库,并讨论如何从这些数据源获取数据。在本章的最后,我们将讨论如何为正在运行的搜索引擎收集数据,以及如何为此准备数据。
第 3 章,探索性数据分析,用 Java 执行数据的初步分析:我们看看如何计算常见的统计数据,如最小值和最大值、平均值和标准偏差。我们还谈了一点交互式分析,看看有哪些工具可以让我们在构建模型之前直观地检查数据。对于本章中的插图,我们使用我们为搜索引擎收集的数据。
第四章,监督学习——分类和回归,从机器学习开始,然后看 Java 中执行监督学习的模型。其中,我们看看如何使用下面的库——Smile、JSAT、LIBSVM、LIBLINEAR 和 Encog,我们看看如何使用这些库来解决分类和回归问题。我们在这里使用两个例子,首先,我们使用搜索引擎数据来预测一个 URL 是否会出现在结果的第一页,我们使用它来说明分类问题。其次,我们预测在给定硬件特性的情况下,两个矩阵相乘需要多少时间,并通过这个例子说明回归问题。
第五章,无监督学习——聚类和降维,探讨 Java 中可用的降维方法,我们将学习如何应用 PCA 和随机投影来降低这些数据的维度。上一章的硬件性能数据集说明了这一点。我们还研究了不同的数据聚类方法,包括凝聚聚类、K-Means 和 DBSCAN,并使用客户投诉数据集作为示例。
第 6 章,处理文本——自然语言处理和信息检索,讲述如何在数据科学应用中使用文本,我们还将学习如何为我们的搜索引擎提取更多有用的特征。我们还研究了 Apache Lucene,一个用于全文索引和搜索的库,以及 Stanford CoreNLP,一个用于执行自然语言处理的库。接下来,我们看看如何将单词表示为向量,并学习如何从共现矩阵中构建这样的嵌入,以及如何使用现有的嵌入,如 GloVe。我们还看了如何将机器学习用于文本,我们用一个情感分析问题来说明它,其中我们应用 LIBLINEAR 来分类评论是正面还是负面。
第 7 章,极限梯度提升,讲述了如何在 Java 中使用 XGBoost,并尝试将其应用于我们之前遇到的两个问题,分类 URL 是否出现在第一页,预测两个矩阵相乘的时间。此外,我们看看如何用 XGBoost 解决学习排序问题,并再次使用我们的搜索引擎示例作为说明。
第八章,用 DeepLearning4j 进行深度学习,涵盖了深度神经网络和 DeepLearning4j,一个用 Java 构建和训练这些网络的库。特别是,我们谈论卷积神经网络,看看我们如何使用它们进行图像识别-预测它是一只狗还是一只猫的图片。此外,我们讨论了数据扩充——生成更多数据的方法,还提到了我们如何使用 GPU 来加速训练。我们通过描述如何在 Amazon AWS 上租用 GPU 服务器来结束这一章。
第 9 章、扩展数据科学,讲述 Java、Apache Hadoop 和 Apache Spark 中可用的大数据工具。我们通过查看如何处理常见的抓取(互联网的副本)并计算每个文档的 TF-IDF 来说明这一点。此外,我们查看了 Apache Spark 中可用的图形处理工具,并为科学家建立了一个推荐系统,我们为下一篇可能的论文推荐了一位合著者。
第 10 章,部署数据科学模型,着眼于我们如何以一种可用的方式向外界公开模型。在这里,我们涵盖了 Spring Boot,并讨论了如何使用我们开发的搜索引擎模型对普通抓取的文章进行排序。最后,我们讨论了在线环境中评估模型性能的方法,并讨论了 A/B 测试和多武装匪徒。
这本书你需要什么
你需要拥有至少 2GB 内存和 Windows 7 /Ubuntu 14.04/Mac OS X 操作系统的最新系统。此外,您需要安装 Java 1.8.0 或更高版本以及 Maven 3.0.0 或更高版本。
这本书是给谁的
本书面向那些熟悉 Java 应用程序开发并熟悉数据科学基本概念的软件工程师。此外,对于那些还不了解 Java,但希望或需要学习它的数据科学家来说,它也很有用。
约定
在这本书里,你会发现许多区分不同种类信息的文本样式。下面是这些风格的一些例子和它们的含义的解释。
文本中的码字、数据库表名、文件夹名、文件名、文件扩展名、路径名、伪 URL、用户输入和 Twitter 句柄如下所示:“这里,我们创建SummaryStatistics对象并添加所有正文内容长度。”
代码块设置如下:
SummaryStatistics statistics = new SummaryStatistics(); data.stream().mapToDouble(RankedPage::getBodyContentLength)
.forEach(statistics::addValue);
System.out.println(statistics.getSummary());
任何命令行输入或输出都按如下方式编写:
mvn dependency:copy-dependencies -DoutputDirectory=lib
mvn compile
新术语和重要词汇以粗体显示。你在屏幕上看到的单词,例如,在菜单或对话框中,出现在文本中,如下所示:“相反,如果我们的模型输出一些分数,使得分数值越高,项目越有可能是肯定的,那么二元分类器被称为排名分类器。”
警告或重要提示出现在这样的框中。
提示和技巧是这样出现的。
读者反馈
我们随时欢迎读者的反馈。让我们知道你对这本书的看法——你喜欢或不喜欢什么。读者的反馈对我们来说很重要,因为它有助于我们开发出真正让你受益匪浅的图书。
要给我们发送总体反馈,只需发送电子邮件feedback@packtpub.com,并在邮件主题中提及书名。
如果有一个你擅长的主题,并且你有兴趣写一本书或者为一本书投稿,请查看我们在www.packtpub.com/authors的作者指南。
客户支持
既然您已经是 Packt book 的骄傲拥有者,我们有许多东西可以帮助您从购买中获得最大收益。
下载示例代码
你可以从你在www.packtpub.com的账户下载本书的示例代码文件。如果你在其他地方购买了这本书,你可以访问 www.packtpub.com/support 的并注册,让文件直接通过电子邮件发送给你。
您可以按照以下步骤下载代码文件:
- 使用您的电子邮件地址和密码登录或注册我们的网站。
- 将鼠标指针悬停在顶部的支持选项卡上。
- 点击代码下载和勘误表。
- 在搜索框中输入图书的名称。
- 选择您要下载代码文件的书。
- 从下拉菜单中选择您购买这本书的地方。
- 点击代码下载。
下载文件后,请确保使用最新版本的解压缩或解压文件夹:
- WinRAR / 7-Zip for Windows
- 适用于 Mac 的 Zipeg / iZip / UnRarX
- 用于 Linux 的 7-Zip / PeaZip
该书的代码包也托管在 GitHub 的 https://GitHub . com/packt publishing/Mastering-Java-for-Data-Science 上。我们在 github.com/PacktPublis…](github.com/PacktPublis…)
下载这本书的彩色图片
我们还为您提供了一个 PDF 文件,其中包含本书中使用的截图/图表的彩色图像。彩色图像将帮助您更好地理解输出中的变化。你可以从https://www . packtpub . com/sites/default/files/downloads/MasteringJavaforDataScience _ color images . pdf下载这个文件。
正误表
尽管我们已尽一切努力确保内容的准确性,但错误还是会发生。如果您在我们的某本书中发现了一个错误——可能是文本或代码中的错误——如果您能向我们报告,我们将不胜感激。这样做,你可以让其他读者免受挫折,并帮助我们改进本书的后续版本。如果您发现任何勘误表,请通过访问www.packtpub.com/submit-erra…,选择您的图书,点击勘误表提交表格链接,并输入勘误表的详细信息来报告。一旦您的勘误表得到验证,您的提交将被接受,该勘误表将被上传到我们的网站或添加到该标题的勘误表部分下的任何现有勘误表列表中。
要查看之前提交的勘误表,请前往www.packtpub.com/books/conte…并在搜索栏中输入图书名称。所需信息将出现在勘误表部分。
海盗行为
互联网上版权材料的盗版是所有媒体都存在的问题。在 Packt,我们非常重视版权和许可证的保护。如果您在互联网上发现我们作品的任何形式的非法拷贝,请立即向我们提供地址或网站名称,以便我们采取补救措施。
请通过copyright@packtpub.com联系我们,并提供可疑盗版材料的链接。
我们感谢您帮助保护我们的作者,以及我们为您带来有价值内容的能力。
问题
如果您对本书的任何方面有问题,可以通过questions@packtpub.com联系我们,我们将尽最大努力解决问题。
一、使用 Java 的数据科学
这本书是关于使用 Java 语言构建数据科学应用程序的。在本书中,我们将涵盖实现项目的所有方面,从数据准备到模型部署。
假设本书的读者以前接触过 Java 和数据科学,本书将有助于将这些知识提升到一个新的水平。这意味着学习如何有效地解决特定的数据科学问题,并最大限度地利用可用数据。
这是一个介绍性的章节,我们将在这里为所有其他章节奠定基础。在这里,我们将讨论以下主题:
- 什么是机器学习和数据科学?
- 数据挖掘的跨行业标准流程 ( CRIPS-DM ),一种进行数据科学项目的方法论
- 面向大中型数据科学应用的 Java 机器学习库
在本章结束时,你将知道如何着手一个数据科学项目,以及使用什么样的 Java 库来做这件事。
数据科学
数据科学是从各种形式的数据中提取可操作知识的学科。数据科学这个名字最近才出现——它是由 DJ Patil 和 Jeff Hammerbacher 发明的,并在 2012 年的文章数据科学家:21 世纪最性感的工作中得到推广。但是这个学科本身已经存在了很长一段时间,之前以其他名字为人所知,如数据挖掘或预测分析。数据科学,像它的前辈一样,建立在统计和机器学习算法的基础上,用于知识提取和模型构建。
术语数据科学的科学部分并非巧合——如果我们查阅科学,它的定义可以概括为以可测试的解释和预测为术语的知识的系统组织。这正是数据科学家所做的,通过从可用数据中提取模式,他们可以对未来的未知数据进行预测,并确保预测事先得到验证。
如今,数据科学应用于许多领域,包括(但不限于):
- 银行业:风险管理(例如,信用评分)、欺诈检测、交易
- 保险:索赔管理(例如,加快索赔审批)、风险和损失评估,以及欺诈检测
- 保健:预测疾病(如中风、糖尿病、癌症)和复发
- 零售 和 电子商务:购物篮分析(识别搭配良好的产品)、推荐引擎、产品分类和个性化搜索
本书涵盖了以下实际使用案例:
- 预测 URL 是否可能出现在搜索引擎的第一页
- 在给定硬件规格的情况下,预测操作完成的速度
- 为搜索引擎排列文本文档
- 检查图片上是猫还是狗
- 在社交网络中推荐朋友
- 在计算机集群上处理大规模文本数据
在所有这些情况下,我们将使用数据科学从数据中学习,并使用学到的知识来解决特定的业务问题。
我们还将在整本书中使用一个运行示例,构建一个搜索引擎。我们将使用它来说明许多数据科学概念,如监督机器学习、降维、文本挖掘和学习排序模型。
机器学习
机器学习是计算机科学的一部分,是数据科学的核心。数据本身,尤其是大量的数据,几乎没有什么用处,但是数据里面隐藏着非常有价值的模式。在机器学习的帮助下,我们可以识别这些隐藏的模式,提取它们,然后将学习到的信息应用到新的看不见的项目上。
例如,给定一个动物的图像,机器学习算法可以说出图片是狗还是猫;或者,考虑到银行客户的历史,它会说客户违约的可能性有多大,即无法偿还债务。
通常,机器学习模型被视为黑盒,它接受数据点并输出对它的预测。在这本书里,我们将看看这些黑盒里有什么,看看如何以及何时最好地使用它们。
机器学习解决的典型问题可以分为以下几组:
- 监督学习:对于每个数据点,我们都有一个标签- 额外信息,描述我们想要学习的结果。在猫对狗的情况下,数据点是动物的图像;标签描述的是狗还是猫。
- 无监督学习:我们只有原始数据点,没有标签信息可用。例如,我们有一组电子邮件,我们希望根据它们的相似程度对它们进行分组。没有与电子邮件相关联的明确标签,这使得该问题无人监管。
- 半监督学习:只对一部分数据给出标签。
- 强化学习:我们没有标签,有奖励;模型通过与它运行的环境互动得到的东西。基于奖励,它可以适应并最大化它。比如,一个学习下棋的模型,每吃掉对手一个图形就获得一个正奖励,每输一个图形就获得一个负奖励;而报酬与数字的价值成正比。
监督学习
正如我们之前讨论的,对于监督学习,我们有一些信息附加到每个数据点,标签,我们可以训练一个模型来使用它并从中学习。例如,如果我们想建立一个模型,告诉我们一张图片上是狗还是猫,那么图片就是数据点,是狗还是猫的信息就是标签。再比如预测房子的价格——房子的描述就是数据点,价格就是标签。
我们可以根据这些信息的性质将监督学习的算法分为分类和回归算法。
在分类问题中,标签来自于某个固定的有限类集合,比如{猫、狗}、{默认、非默认},或者{办公室、美食、娱乐、家居}。根据类的数量不同,分类问题可以是二元(只有两个可能的类)或者多类(几个类)。
分类算法的例子有朴素贝叶斯、逻辑回归、感知器、支持向量机()等等。我们将在第四章第一部分、监督学习-分类和回归中更详细地讨论分类算法。
在回归**问题中,标号是实数。例如,一个人的年薪可以从 0 美元到几十亿美元不等。因此,预测工资是一个回归问题。
回归算法的例子有线性回归、LASSO、支持向量回归 ( SVR )等。这些算法将在第二部分第四章、监督学习-分类和回归中详细描述。
一些监督学习方法是通用的,可以应用于分类和回归问题。例如,决策树、随机森林和其他基于树的方法可以处理这两种类型。我们将在第七章、的中讨论一个这样的算法,梯度提升机器。
神经网络还可以同时处理分类和回归问题,我们会在第八章、用 DeepLearning4J 进行深度学习中讲到。** **
无监督学习
无监督学习涵盖了我们没有标签可用,但仍然希望找到隐藏在数据中的一些模式的情况。有几种类型的无监督学习,我们将研究聚类分析,或聚类和无监督降维。
使聚集
通常,当人们谈论无监督学习时,他们会谈论聚类分析或聚类。聚类分析算法获取一组数据点,并尝试将它们分类成组,使得相似的项目属于同一组,而不同的项目不属于同一组。有许多方法可以使用它,例如,在客户细分或文本分类。
客户细分是聚类的一个例子。给定客户的一些描述,我们尝试将他们分组,使得一个组中的客户具有相似的简档并以相似的方式行为。这些信息可以用来了解这些群体中的人们想要什么,并且这可以用来为他们提供更好的广告和其他促销信息。
再比如文本分类。给定一个文本集合,我们希望在这些文本中找到共同的主题,并根据这些主题排列文本。例如,给定一个电子商务商店中的一组投诉,我们可能希望将谈论类似事情的投诉放在一起,这应该有助于系统用户更容易地浏览投诉。
聚类分析算法的例子有层次聚类、k-means、带噪声的应用的基于密度的空间聚类 ( DBSCAN )等等。我们会在第五章第一部分、无监督学习——聚类与降维中详细讲聚类。
降维
另一组无监督学习算法是降维算法。这组算法压缩数据集,只保留最有用的信息。如果我们的数据集包含太多信息,机器学习算法很难同时使用所有这些信息。算法处理所有数据可能需要太长时间,我们希望压缩数据,因此处理时间会更短。
有多种算法可以降低数据的维数,包括主成分分析 ( PCA )、局部线性嵌入和 t-SNE。所有这些算法都是无监督降维技术的例子。
不是所有的降维算法都是无监督的;他们中的一些人可以使用标签来更好地降低维度。例如,许多特征选择算法依赖于标签来查看哪些特征是有用的,哪些是无用的。
我们将在第五章、无监督学习-聚类和降维中详细讨论这一点。
自然语言处理
处理自然语言文本是非常复杂的,它们的结构不是很好,需要大量的清理和规范化。然而,我们周围的文本信息量是巨大的:每分钟都产生大量的文本数据,很难从中检索有用的信息。使用数据科学和机器学习对文本问题也很有帮助;它们让我们找到正确的文本,处理它,并提取有价值的信息。
我们可以通过多种方式使用文本信息。一个例子是信息检索,或者简单地说,文本搜索——给定一个用户查询和一组文档,我们希望在语料库中找到与该查询最相关的文档,并将它们呈现给用户。其他应用包括情感分析——预测产品评论是积极的、中立的还是消极的,或者根据评论如何谈论产品来对评论进行分组。
我们将在第六章、中更多地讨论信息检索、自然语言处理 ( NLP )和文本处理——自然语言处理和信息检索中的文本处理。此外,我们将在第 9 章、缩放数据科学中了解如何处理大量文本数据。
我们可以用于机器学习和数据科学的方法非常重要。同样重要的是我们创造它们,然后将它们用于生产系统的方式。数据科学流程模型帮助我们使其更有组织性和系统性,这也是我们接下来将讨论它们的原因。
数据科学过程模型
应用数据科学不仅仅是选择合适的机器学习算法并将其用于数据。记住机器学习只是项目的一小部分总是好的;还有其他部分,如了解问题、收集数据、测试解决方案和部署到生产环境中。
当从事任何项目时,不仅仅是数据科学项目,将它分解成更小的可管理的部分并逐个完成它们是有益的。对于数据科学,有描述如何以最佳方式完成的最佳实践,它们被称为流程模型。有多个型号,包括 CRISP-DM 和 OSEMN。
在本章中,CRISP-DM 被解释为获取、擦洗、探索、建模和解释 ( OSEMN ),它更适合于数据分析任务,并在较小的程度上解决了许多重要步骤。
CRISP-DM
数据挖掘的跨行业标准过程 ( CRISP-DM )是一种开发数据挖掘应用的过程方法论。它是在术语数据科学变得流行之前创建的,它是可靠的,并且经过了几代分析的时间考验。这些实践现在仍然有用,并且很好地描述了任何分析项目的高级步骤。
图片来源:https://en . Wikipedia . org/wiki/File:CRISP-DM _ Process _ diagram . png
CRISP-DM 方法将项目分解为以下步骤:
- 商业理解
- 数据理解
- 数据准备
- 建模
- 估价
- 部署
方法本身定义的不仅仅是这些步骤,但是通常了解步骤是什么以及每个步骤发生了什么对于一个成功的数据科学项目来说已经足够了。让我们分别看一下这些步骤。
第一步业务理解。这一步旨在了解企业存在什么样的问题,以及他们希望通过解决这些问题来实现什么。为了取得成功,数据科学应用程序必须对业务有用。这一步的结果是我们想要解决的问题的公式化,以及项目期望的结果是什么。
第二步是数据理解。在这一步,我们试图找出哪些数据可以用来解决问题。我们还需要找出我们是否已经有了数据;如果没有,我们需要思考如何才能得到它。根据我们找到(或没有找到)的数据,我们可能想要改变最初的目标。
当数据被收集后,我们需要探索它。审查数据的过程通常被称为探索性数据分析,它是任何数据科学项目不可或缺的一部分。它有助于理解创建数据的过程,并且已经可以提出解决问题的方法。这一步的结果是了解解决问题需要哪些数据源。我们将在第三章、探索性数据分析中详细讲述这一步。
CRISP-DM 的第三步是数据准备。为了使数据集有用,需要对其进行清理并将其转换为表格形式。表格形式意味着每行恰好对应一个观察值。如果我们的数据不是这种形状,它就不能被大多数机器学习算法使用。因此,我们需要准备数据,以便最终可以将其转换为矩阵形式并提供给模型。
此外,可能存在包含所需信息的不同数据集,并且它们可能不是同质的。这意味着我们需要将这些数据集转换成某种通用格式,以便模型能够读取。
这一步还包括特征工程——创建最能提供问题信息并以最佳方式描述数据的特征的过程。
许多数据科学家表示,在构建数据科学应用程序时,他们将大部分时间花在这一步上。我们将在第二章、数据处理工具箱以及整本书中谈到这一步。
第四步建模。在这一步中,数据已经处于正确的形状,我们将它馈送给不同的机器学习算法。这一步还包括参数调整、特性选择和选择最佳模型。
从机器学习的角度评估模型的质量发生在这个步骤中。要检查的最重要的事情是归纳的能力,这通常是通过交叉验证来完成的。在这一步中,我们可能还想回到上一步,做额外的清理和功能工程。结果是一个模型,它可能对解决步骤 1 中定义的问题有用。
第五步评估。它包括从商业角度评估模型——而不是从机器学习的角度。这意味着我们需要对迄今为止的结果进行严格审查,并计划下一步行动。模型达到我们想要的了吗?此外,一些发现可能会导致重新考虑最初的问题。在这一步之后,我们可以转到部署步骤或重复流程。
最后的第六步是模型部署。在这个步骤中,生产的模型被添加到产品中,因此结果是模型被集成到活动系统中。我们将在第 10 章、部署数据科学模型中介绍这一步骤。
通常,评估是困难的,因为并不总是能够说出模型是否达到了预期的结果。在这些情况下,可以将评估和部署步骤合并为一个步骤,只对部分用户部署和应用模型,然后收集用于评估模型的数据。我们还将在本书的最后一章简要介绍他们的做法,如 A/B 测试和多臂土匪。
连续的例子
整本书会有很多实际的用例,有时每章会有几个。但是我们也会有一个运行的例子,构建一个搜索引擎。这个问题很有趣,原因有很多:
- 这很有趣
- 几乎任何领域的企业都可以从搜索引擎中受益
- 许多企业已经有了文本数据;通常它没有被有效地使用,并且它的使用可以被改进
- 处理文本需要很大的努力,学会有效地处理文本是很有用的
我们将尽量保持简单,但是,通过这个例子,我们将在整本书中触及数据科学过程的所有技术部分:
- 数据理解:哪些数据可以对问题有用?我们如何获得这些数据?
- 数据准备:数据一旦获得,我们该如何处理?如果是 HTML,我们如何从中提取文本?我们如何从文本中提取单个的句子和单词?
- 建模:根据文档与查询的相关性对文档进行排序是一个数据科学问题,我们将讨论如何实现这个问题。
- 评估:可以对搜索引擎进行测试,看它对解决业务问题是否有用。
- 部署:最后,引擎可以作为 REST 服务部署,或者直接集成到实时系统中。
我们将在第二章、数据处理工具箱中获取并准备数据,在第三章、探索性数据分析中理解数据,在第四章、监督机器学习-分类和回归中构建简单模型并进行评估,在第六章、中查看如何处理文本使用文本-自然语言处理和信息检索, 在第 9 章、扩展数据科学中了解如何将它应用于数百万个网页,最后,在第 10 章、部署数据科学模型中了解我们如何部署它。
Java 中的数据科学
在本书中,我们将使用 Java 进行数据科学项目。乍一看,Java 似乎不是数据科学的好选择,不像 Python 或 R,它的数据科学和机器学习库更少,更冗长,缺乏交互性。另一方面,它有很多优点,如下所示:
- Java 是一种静态类型的语言,这使得维护代码库更容易,更难犯愚蠢的错误——编译器可以检测到其中的一些错误。
- 数据处理的标准库非常丰富,甚至还有更丰富的外部库。
- Java 代码通常比通常用于数据科学的脚本语言(如 R 或 Python)的代码更快。
- Maven 是 Java 世界中依赖性管理的事实上的标准,它使得向项目中添加新的库和避免版本冲突变得非常容易。
- 大多数用于可扩展数据处理的大数据框架都是用 Java 或 JVM 语言编写的,如 Apache Hadoop、Apache Spark 或 Apache Flink。
- 生产系统通常是用 Java 编写的,用其他语言构建模型会增加不必要的复杂性。用 Java 创建模型使得将它们集成到产品中变得更加容易。
接下来,我们将看看 Java 中可用的数据科学库。
数据科学图书馆
虽然与 R 相比,Java 中的数据科学库不多,但也不少。此外,通常可以使用用其他 JVM 语言编写的机器学习和数据挖掘库,如 Scala、Groovy 或 Clojure。因为这些语言共享运行时环境,所以很容易导入用 Scala 编写的库,并直接在 Java 代码中使用它们。
我们可以将图书馆分为以下几类:
- 数据处理库
- 数学和统计库
- 机器学习和数据挖掘库
- 文本处理库
现在,我们将详细了解它们。
数据处理库
标准 Java 库非常丰富,提供了很多数据处理工具,比如集合、I/O 工具、数据流和并行任务执行的手段。
标准库有非常强大的扩展,例如:
- 谷歌番石榴(github.com/google/guav…)和阿帕奇共同收藏(commons.apache.org/collections…)更丰富的收藏
- 用于简化 I/O 的 Apache Commons IO(commons.apache.org/io/)
- AOL Cyclops-React(github.com/aol/cyclops…)实现更丰富的功能方式并行流
我们将在第 2 章、数据处理工具箱中介绍数据处理的标准 API 及其扩展。在本书中,我们将使用 Maven 来包含外部库,如 Google Guava 或 Apache Commons IO。它是一个依赖管理工具,允许用几行 XML 代码指定外部依赖。比如添加谷歌番石榴,在pom.xml中声明如下依赖关系就足够了:
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>19.0</version>
</dependency>
当我们这样做时,Maven 将转到 Maven 中央存储库并下载指定版本的依赖项。找到pom.xml(比如上一个)的依赖片段的最好方法是在mvnrepository.com或者你最喜欢的搜索引擎上搜索。
Java 提供了一种通过 Java 数据库连接 ( JDBC )访问数据库的简单方法——一种统一的数据库访问协议。JDBC 使得连接几乎任何支持 SQL 的关系数据库成为可能,比如 MySQL、MS SQL、Oracle、PostgreSQL 等等。这允许将数据操作从 Java 转移到数据库端。
当无法使用数据库处理表格数据时,我们可以使用 DataFrame 库直接在 Java 中处理。DataFrame 是一种最初来自 R 的数据结构,它允许在程序中轻松地操作文本数据,而无需借助外部数据库。
例如,使用数据帧,可以根据某些条件过滤行,对列的每个元素应用相同的操作,按某些条件分组或与另一个数据帧连接。此外,一些数据框库可以轻松地将表格数据转换为矩阵形式,以便机器学习算法可以使用这些数据。
Java 中有一些数据框库。其中一些如下:
- 细木工(cardillo.github.io/joinery/)
- 餐桌锯(github.com/lwhite1/tab…
- 鞍(saddle.github.io/)Scala 的数据框架库
- 阿帕奇火花数据帧(spark.apache.org/
我们还将在第 2 章、数据处理工具箱中介绍数据库和数据框架,我们将在整本书中使用数据框架。
还有更复杂的数据处理库比如 Spring Batch(projects.spring.io/spring-batc…)。它们允许创建复杂的数据管道(从提取-转换-加载称为 ETL)并管理它们的执行。
此外,还有用于分布式数据处理的库,例如:
- 阿帕奇 Hadoop(hadoop.apache.org/)
- 阿帕奇火花(spark.apache.org/)
- Apache 小程序(https://flip . Apache . org/
我们将在第 9 章、扩展数据科学中讨论分布式数据处理。
数学和统计库
标准 Java 库中的数学支持相当有限,只包括计算对数的log、计算指数的exp等基本方法。
有更丰富的数学支持的外部库。例如:
- Apache Commons Math(commons.apache.org/math/)用于统计、优化和线性代数
- Apache Mahout(mahout.apache.org/)用于线性代数,还包括一个分布式线性代数和机器学习模块
- JBlas(jblas.org/)优化和非常快速的线性代数包,使用 Blas 库
此外,许多机器学习库带有一些额外的数学功能,通常是线性代数、统计和优化。
机器学习和数据挖掘库
有相当多的机器学习和数据挖掘库可用于 Java 和其他 JVM 语言。其中一些如下:
- WEKA(www.cs.waikato.ac.nz/ml/weka/)可能是 Java 中最著名的数据挖掘库,包含大量算法,有很多扩展。
- JavaML(java-ml.sourceforge.net/)是一个相当老且可靠的 ML 库,但不幸的是不再更新了
- smile(haifengl.github.io/smile/)是一个很有前途的 ML 库,目前正在积极开发中,很多新的方法正在加入其中。
- github.com/EdwardRaff/… 的 JSAT 包含了一系列令人印象深刻的机器学习算法。
- H2O(www.h2o.ai/)是一个用 Java 编写的分布式 ML 框架,但可用于多种语言,包括 Scala、R 和 Python。
- Apache Mahout(mahout.apache.org/)用于核内(一台机器)和分布式机器学习。Mahout Samsara 框架允许以独立于框架的方式编写代码,然后在 Spark、Flink 或 H2O 上执行。
有几个专门研究神经网络的库:
- 安可(www.heatonresearch.com/encog/)
- 深度学习 4j(deeplearning4j.org/
我们将在整本书中介绍其中的一些库。
文本处理
可以只使用标准 Java 库进行简单的文本处理,该库包含诸如StringTokenizer、java.text包或正则表达式之类的类。
除此之外,还有各种各样的文本处理框架可用于 Java,如下所示:
- Apache Lucene(lucene.apache.org/)是一个用于信息检索的库
- 斯坦福·科伦普(stanfordnlp.github.io/CoreNLP/
- Apache OpenNLP(opennlp.apache.org/)
- 凌派(alias-i.com/lingpipe/)
- 大门(gate.ac.uk/)
- 木槌(mallet.cs.umass.edu/)
- smile(haifengl.github.io/smile/)也有一些 NLP 的算法
大多数 NLP 库具有非常相似的功能和算法覆盖范围,这就是为什么选择使用哪一个通常是一个习惯或品味的问题。它们通常都有标记化、解析、词性标注、命名实体识别和其他文本处理算法。其中有些(比如 StanfordNLP)支持多种语言,有些只支持英语。
我们将在第 6 章、中介绍这些库,使用文本自然语言处理和信息检索。
摘要
在本章中,我们简要讨论了数据科学以及机器学习在其中扮演的角色。然后我们谈到做一个数据科学项目,以及什么方法论对它有用。我们讨论了其中的一个,CRISP-DM,它定义的步骤,这些步骤是如何关联的,以及每个步骤的结果。
最后,我们谈到了为什么用 Java 做数据科学项目是一个好主意,它是静态编译的,它很快,而且通常现有的生产系统已经在 Java 中运行。我们还提到了使用 Java 语言成功完成数据科学项目的库和框架。
有了这个基础,我们现在将进入数据科学项目中最重要(也是最耗时)的步骤——数据准备。**
二、数据处理工具箱
在前一章中,我们讨论了处理数据科学问题的最佳实践。我们看了 CRISP-DM,它是处理数据挖掘项目的方法论,其中第一步是数据预处理。在这一章中,我们将仔细看看如何在 Java 中做到这一点。
具体来说,我们将涵盖以下主题:
- 标准 Java 库
- 标准库的扩展
- 从不同来源读取数据,比如文本、HTML、JSON 和数据库
- 用于操作表格数据的数据框架
最后,我们会把所有东西放在一起,为搜索引擎准备数据。
本章结束时,你将能够处理数据,使其可用于机器学习和进一步分析。
标准 Java 库
标准 Java 库非常丰富,提供了许多数据操作工具,包括:
- 用于在内存中组织数据的集合
- 用于读写数据的 I/O
- 简化数据转换的流式 API
在本章中,我们将详细了解所有这些工具。
收集
数据是数据科学最重要的部分。当处理数据时,它需要被有效地存储和处理,为此我们使用数据结构。数据结构描述了有效存储数据以解决特定问题的方法,Java 集合 API 是数据结构的标准 Java API。这个 API 提供了在实际数据科学应用中有用的各种各样的实现。
我们不会详细描述集合 API,而是集中在最有用和最重要的部分——列表、集合和映射接口。
列表是集合,其中每个元素都可以通过其索引来访问。List接口的 g0-to 实现是ArrayList,它应该在 99%的情况下使用,可以如下使用:
List<String> list = new ArrayList<>();
list.add("alpha");
list.add("beta");
list.add("beta");
list.add("gamma");
System.out.println(list);
还有其他的List接口、LinkedList或CopyOnWriteArrayList的实现,但是它们很少被用到。
Set 是集合 API 中的另一个接口,它描述了一个不允许重复的集合。如果我们插入元素的顺序不重要,那么首选实现是HashSet,如果顺序重要,那么首选实现是LinkedHashSet。我们可以如下使用它:
Set<String> set = new HashSet<>();
set.add("alpha");
set.add("beta");
set.add("beta");
set.add("gamma");
System.out.println(set);
List和Set都实现了Iterable接口,这使得它们可以使用for-each循环:
for (String el : set) {
System.out.println(el);
}
Map接口允许将键映射到值,在其他语言中有时被称为字典或关联数组。g0-to 实现是HashMap:
Map<String, String> map = new HashMap<>();
map.put("alpha", "α");
map.put("beta", "β");
map.put("gamma", "γ");
System.out.println(map);
如果需要保持插入顺序,可以使用LinkedHashMap;如果你知道map接口将被多线程访问,使用ConcurrentHashMap。
Collections类提供了几个助手方法来处理集合,比如排序,或者提取max或者min元素:
String min = Collections.min(list);
String max = Collections.max(list);
System.out.println("min: " + min + ", max: " + max);
Collections.sort(list);
Collections.shuffle(list);
还有其他集合,比如Queue、Deque、Stack、线程安全集合等等。它们不太常用,对数据科学也不是很重要。
输入/输出
数据科学家经常使用文件和其他数据源。I/O 是从数据源读取和写回结果所必需的。Java I/O API 为此提供了两种主要的抽象类型:
InputStream、OutputStream为二进制数据Reader、Writer为文本数据
典型的数据科学应用处理文本而不是原始的二进制数据——数据通常以 TXT、CSV、JSON 和其他类似的文本格式存储。这就是为什么我们将集中讨论第二部分。
读取输入数据
能够读取数据是数据科学家最重要的技能,这些数据通常是文本格式,可以是 TXT、CSV 或任何其他格式。在 Java I/O API 中,Reader类的子类处理读取文本文件。
假设我们有一个包含一些句子的text.txt文件(这些句子可能有意义,也可能没有意义):
- 我的狗也喜欢吃香肠
- 除了盈余外,马达还能接受
- 每一个有能力的斜线成功与世界范围的指责
- 持续的任务在内疚的吻周围咳嗽
如果需要将整个文件作为字符串列表读取,通常的 Java I/O 方式是使用BufferedReader:
List<String> lines = new ArrayList<>();
try (InputStream is = new FileInputStream("data/text.txt")) {
try (InputStreamReader isReader = new InputStreamReader(is,
StandardCharsets.UTF_8)) {
try (BufferedReader reader = new BufferedReader(isReader)) {
while (true) {
String line = reader.readLine();
if (line == null) {
break;
}
lines.add(line);
}
isReader.close();
}
}
}
提供字符编码很重要——这样,Reader就知道如何将字节序列转换成合适的String对象。除了 UTF 8,还有 UTF-16,ISO-8859(这是基于 ASCII 的英语文本编码)和许多其他标准。
直接获取文件的BufferedReader有一个快捷方式:
Path path = Paths.get("data/text.txt");
try (BufferedReader reader = Files.newBufferedReader(path,
StandardCharsets.UTF_8)) {
// read line-by-line
}
即使使用这种快捷方式,您也可以看到,对于从文件中读取一系列行这样简单的任务来说,这是非常冗长的。您可以将它封装在一个助手函数中,或者使用 Java NIO API,它提供了一些助手方法来简化这项任务:
Path path = Paths.get("data/text.txt");
List<String> lines = Files.readAllLines(path, StandardCharsets.UTF_8);
System.out.println(lines);
Java NIO 快捷方式仅适用于文件。稍后,我们将讨论适用于任何 InputStream 对象的快捷方式,而不仅仅是文件。
写入输出数据
在数据被读取和处理后,我们通常希望将它放回磁盘。对于文本,这通常使用Writer对象来完成。
假设我们从text.txt中读取句子,我们需要将每一行转换成大写,并将它们写回一个新文件output.txt;编写文本数据最方便的方式是通过PrintWriter类:
try (PrintWriter writer = new PrintWriter("output.txt", "UTF-8")) {
for (String line : lines) {
String upperCase = line.toUpperCase(Locale.US);
writer.println(upperCase);
}
}
在 Java NIO API 中,它看起来像这样:
Path output = Paths.get("output.txt");
try (BufferedWriter writer = Files.newBufferedWriter(output,
StandardCharsets.UTF_8)) {
for (String line : lines) {
String upperCase = line.toUpperCase(Locale.US);
writer.write(upperCase);
writer.newLine();
}
}
两种方法都是正确的,你应该选择你喜欢的那一种。然而,记住总是包括编码是很重要的;否则,它可能会使用一些依赖于平台的默认值,有时是任意的。
流式 API
Java 8 是 Java 语言历史上的一大进步。在其他特性中,有两个重要的东西——流和 Lambda 表达式。
在 Java 中,流是一个对象序列,Streams API 提供了函数式的操作来转换这些序列,比如 map、filter 和 reduce。流的源可以是包含元素的任何东西,例如数组、集合或文件。
例如,让我们创建一个简单的Word类,它包含一个令牌及其词性:
public class Word {
private final String token;
private final String pos;
// constructor and getters are omitted
}
为了简洁起见,我们将总是省略这种数据类的构造函数和 getters,但是用注释指出这一点。
现在,我们来考虑一句话我的狗也喜欢吃香肠。使用这个类,我们可以将其表示如下:
Word[] array = { new Word("My", "RPR"), new Word("dog", "NN"),
new Word("also", "RB"), new Word("likes", "VB"),
new Word("eating", "VB"), new Word("sausage", "NN"),
new Word(".", ".") };
这里,我们使用 Penn Treebank POS 符号,其中NN表示名词,或者VB表示动词。
现在,我们可以使用Arrays.stream实用程序方法将这个数组转换成一个流:
Stream<Word> stream = Arrays.stream(array);
可以使用stream方法从集合中创建流:
List<Word> list = Arrays.asList(array);
Stream<Word> stream = list.stream();
流上的操作被链接在一起,形成了美观易读的数据处理管道。对流最常见的操作是映射和过滤操作:
- Map 对每个元素应用相同的变换函数
- 给定一个谓词函数,Filter 过滤掉不满足它的元素
在管道的末端,您使用收集器收集结果。Collectors类提供了几个实现,比如toList、toSet、toMap等等。
假设我们只想保留名词标记。使用 Streams API,我们可以如下操作:
List<String> nouns = list.stream()
.filter(w -> "NN".equals(w.getPos()))
.map(Word::getToken)
.collect(Collectors.toList());
System.out.println(nouns);
或者,我们可能想要检查流中有多少唯一的 POS 标签。为此,我们可以使用toSet收集器:
Set<String> pos = list.stream()
.map(Word::getPos)
.collect(Collectors.toSet());
System.out.println(pos);
在处理文本时,我们有时可能希望将一系列字符串连接在一起:
String rawSentence = list.stream()
.map(Word::getToken)
.collect(Collectors.joining(" "));
System.out.println(rawSentence);
或者,我们可以根据单词的POS标签对它们进行分组:
Map<String, List<Word>> groupByPos = list.stream()
.collect(Collectors.groupingBy(Word::getPos));
System.out.println(groupByPos.get("VB"));
System.out.println(groupByPos.get("NN"));
此外,还有一个有用的toMap收集器,它可以使用一些字段对集合进行索引。例如,如果我们想获得从令牌到Word对象的映射,可以使用下面的代码来实现:
Map<String, Word> tokenToWord = list.stream()
.collect(Collectors.toMap(Word::getToken, Function.identity()));
System.out.println(tokenToWord.get("sausage"));
除了对象流,Streams API 还提供了原语流——整型、双精度型和其他原语流。这些流有用于统计计算的有用方法,例如sum、max、min或average。可以使用mapToInt或mapToDouble等函数将普通流转换为原始流。
例如,这就是我们如何找到句子中所有单词的最大长度:
int maxTokenLength = list.stream()
.mapToInt(w -> w.getToken().length())
.max().getAsInt();
System.out.println(maxTokenLength);
流操作易于并行化;它们分别应用于每一项,因此多线程可以做到这一点,而不会相互干扰。因此,通过将工作分散到多个处理器上并并行执行所有任务,可以大大加快这些操作的速度。
Java 利用了这一点,并提供了一种简单而富于表现力的方法来创建并行代码;对于集合,您只需要调用parallelStream方法:
int[] firstLengths = list.parallelStream()
.filter(w -> w.getToken().length() % 2 == 0)
.map(Word::getToken)
.mapToInt(String::length)
.sequential()
.sorted()
.limit(2)
.toArray();
System.out.println(Arrays.toString(firstLengths));
在这个例子中,过滤和映射是并行完成的,但是随后流被转换成顺序流,被排序,并且最上面的两个元素被提取到一个数组中。虽然这个例子不是很有意义,但是它展示了使用流可以做多少事情。
最后,标准 Java I/O 库提供了一些方便的方法。例如,可以使用Files.lines方法将文本文件表示为一串行:
Path path = Paths.get("text.txt");
try (Stream<String> lines = Files.lines(path, StandardCharsets.UTF_8)) {
double average = lines
.flatMap(line -> Arrays.stream(line.split(" ")))
.map(String::toLowerCase)
.mapToInt(String::length)
.average().getAsDouble();
System.out.println("average token length: " + average);
}
流是处理数据的一种表达性强、功能强大的方式,掌握这个 API 对于用 Java 做数据科学非常有帮助。稍后,我们将经常使用 Stream API,因此您将看到更多如何使用它的示例。
标准库的扩展
标准 Java 库非常强大,但是有些东西需要花很长时间来编写,或者根本就没有。标准库有许多扩展,最突出的库是 Apache Commons(一组库)和 Google Guava。它们使得使用标准 API 或扩展它变得更加容易,例如,通过添加新的集合或实现。
首先,我们将简要介绍这些库的最相关部分,稍后我们将看到它们在实践中是如何使用的。
Apache common(Apache 公共)
Apache Commons 是 Java 开源库的集合,目标是创建可重用的 Java 组件。有很多,包括 Apache Commons Lang、Apache Commons IO、Apache Commons Collections 等等。
公共语言
Apache Commons Lang 是一组扩展了java.util包的实用程序类,它们通过提供许多解决常见问题并节省大量时间的小方法,使 Java 开发人员的生活变得更加轻松。
为了在 Java 中包含外部库,我们通常使用 Maven,这使得管理依赖关系变得非常容易。有了 Maven,Apache Commons Lang 库可以使用下面的dependency片段来包含:
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
该库包含了许多对通用 Java 编程有用的方法,比如使实现对象的equals和hashCode方法、序列化助手和其他方法变得更加容易。总的来说,它们并不特别适用于数据科学,但是有一些辅助函数非常有用。举个例子,
- 用于生成数据的
RandomUtils和RandomStringUtils StringEscapeUtils和LookupTranslator用于转义和不转义字符串EqualsBuilder和HashCodeBuilder用于快速执行equals和hashCode方法StringUtils和WordUtils了解有用的字符串操作方法Pair类
如需了解更多信息,您可以阅读位于 commons.apache.org/lang 的文档。
查看可用内容的最佳方式是下载该包并查看其中的可用代码。每个 Java 开发者都会发现很多有用的东西。
公共 IO
像 Apache Commons Lang 扩展了java.util标准包,Apache Commons IO 扩展了java.io;这是一个 Java 实用程序库,用于协助 Java 中的 I/O,正如我们之前了解到的,它可能非常冗长。
要在您的项目中包含该库,请将dependency片段添加到pom.xml文件中:
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.5</version>
</dependency>
我们已经从 Java NIO 中了解了Files.lines。虽然它很方便,但我们并不总是处理文件,有时需要从其他InputStream获取行,例如,一个网页或一个 web 套接字。
为此,Commons IO 提供了一个非常有用的实用程序类IOUtils。使用它,将整个输入流读入字符串或字符串列表非常容易:
try (InputStream is = new FileInputStream("data/text.txt")) {
String content = IOUtils.toString(is, StandardCharsets.UTF_8);
System.out.println(content);
}
try (InputStream is = new FileInputStream("data/text.txt")) {
List<String> lines = IOUtils.readLines(is, StandardCharsets.UTF_8);
System.out.println(lines);
}
虽然我们在这个例子中使用了FileInputStream对象,但是它可以是任何其他的流。第一种方法IOUtils.toString特别有用,我们稍后将使用它来抓取网页和处理来自 web 服务的响应。
在这个库中有很多更有用的 I/O 方法,为了得到一个好的概述,你可以参考在commons.apache.org/io获得的文档。
公共收藏
Java Collections API 非常强大,它为 Java 中的数据结构定义了一组很好的抽象。Commons 集合使用这些抽象,并用新的实现和新的集合来扩展标准集合 API。要包含该库,请使用以下代码片段:
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
<version>4.1</version>
</dependency>
该库中一些有用的收藏有:
Bag:这是可以多次保存同一元素的集合的接口BidiMap:这代表双向地图。它可以从键映射到值,也可以从值映射到键
它与 Google Guava 中的集合有一些重叠,这将在下一节课中解释,但是它还有一些没有实现的附加集合。举个例子,
LRUMap:用于实现缓存PatriciaTrie:用于快速字符串前缀查找
其他公共模块
Commons Lang、IO 和 Collections 是众多 Commons 库中的几个。还有其他对数据科学有用的公共模块:
- Commons compress 用于读取压缩文件,例如,
bzip2(用于读取维基百科转储)、gzip、7z等 - Commons CSV 用于读取和写入 CSV 文件(我们将在后面使用它)
- Commons math 用于统计计算和线性代数(我们稍后也会用到)
你可以参考commons.apache.org/的完整列表。
谷歌番石榴
谷歌番石榴很像阿帕奇 Commons 它是一组实用程序,扩展了标准的 Java API,使生活变得更加简单。但是与 Apache Commons 不同的是,Google Guava 是一个同时涵盖许多领域的库,包括集合和 I/O。
要将其包含在项目中,请使用dependency:
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>19.0</version>
</dependency>
我们将从 Guava I/O 模块开始。为了举例说明,我们将使用一些生成的数据。我们已经使用了word类,它包含一个令牌及其词性标签,这里我们将生成更多的单词。要做到这一点,我们可以使用数据生成工具,如 www.generatedata.com/和 T2。让我们定义下面的模式,如下面的屏幕截图所示:
之后,可以将生成的数据保存为 CSV 格式,将分隔符设置为制表符(t),并将其保存到words.txt。我们已经为您生成了一个文件;你可以在chapter2库中找到它。
Guava 定义了一些处理 I/O 的抽象。其中之一是CharSource,它是对任何基于字符的数据源的抽象,在某种意义上,它与标准的Reader类非常相似。此外,与 Commons IO 类似,还有一个用于处理文件的实用程序类。它被称为Files(不要与java.nio.file.Files混淆),并且包含帮助函数,使文件 I/O 更容易。使用这个类,可以读取文本文件的所有行,如下所示:
File file = new File("data/words.txt");
CharSource wordsSource = Files.asCharSource(file, StandardCharsets.UTF_8);
List<String> lines = wordsSource.readLines();
Google Guava Collections 遵循与 Commons Collections 相同的理念;它建立在标准集合 API 的基础上,提供了新的实现和抽象。有一些实用程序类,比如用于列表的Lists,用于集合的Sets,等等。
Lists中的一个方法是transform,它就像流中的map,它应用于列表中的每个元素。结果列表的元素被延迟评估;仅当需要元素时,才触发函数的计算。让我们用它将文本文件中的行转换成一列Word对象:
List<Word> words = Lists.transform(lines, line -> {
String[] split = line.split("t");
return new Word(split[0].toLowerCase(), split[1]);
});
这与 Streams API 中的 map 的主要区别在于,transform 会立即返回一个列表,因此不需要首先创建一个流,调用map函数,最后将结果收集到列表中。
与 Commons 集合类似,Java API 中也有新的集合不可用。对数据科学最有用的集合是Multiset、Multimap和Table。
Multisets 是同一元素可以多次存储的集合,通常用于计数。当我们想要计算每个术语出现的次数时,这个类对于文本处理特别有用。
让我们看看我们读到的单词,并计算每个pos标签出现的次数:
Multiset<String> pos = HashMultiset.create();
for (Word word : words) {
pos.add(word.getPos());
}
如果我们想输出按计数排序的结果,有一个特殊的实用函数:
Multiset<String> sortedPos = Multisets.copyHighestCountFirst(pos);
System.out.println(sortedPos);
Multimap 是一个映射,每个键可以有多个值。多地图有几种类型。两种最常见的地图如下:
ListMultimap:这将一个键与一个值列表相关联,类似于Map<Key, List<Value>>SetMultimap:这将一个键与一组值相关联,类似于Map<Key, Set<Value>>
这对于实现group by逻辑非常有用。让我们看看每个POS标签的平均长度:
ArrayListMultimap<String, String> wordsByPos = ArrayListMultimap.create();
for (Word word : words) {
wordsByPos.put(word.getPos(), word.getToken());
}
可以将多地图视为集合地图:
Map<String, Collection<String>> wordsByPosMap = wordsByPos.asMap();
wordsByPosMap.entrySet().forEach(System.out::println);
最后,Table集合可以看作是map接口的二维扩展;现在,不是一个键,每个条目由两个键索引,row键和column键。除此之外,还可以使用column键获得整列,或者使用row键获得一行。
例如,我们可以计算每个(单词、词性)对在数据集中出现的次数:
Table<String, String, Integer> table = HashBasedTable.create();
for (Word word : words) {
Integer cnt = table.get(word.getPos(), word.getToken());
if (cnt == null) {
cnt = 0;
}
table.put(word.getPos(), word.getToken(), cnt + 1);
}
将数据放入表中后,我们可以分别访问行和列:
Map<String, Integer> nouns = table.row("NN");
System.out.println(nouns);
String word = "eu";
Map<String, Integer> posTags = table.column(word);
System.out.println(posTags);
像在 Commons Lang 中一样,Guava 也包含用于处理原语的实用程序类,比如用Ints处理int原语,用Doubles处理double原语,等等。例如,它可用于将原始包装的集合转换为原始数组:
Collection<Integer> values = nouns.values();
int[] nounCounts = Ints.toArray(values);
int totalNounCount = Arrays.stream(nounCounts).sum();
System.out.println(totalNounCount);
最后,Guava 为排序数据提供了一个很好的抽象- Ordering,它扩展了标准的Comparator接口。它为创建比较器提供了清晰流畅的界面:
Ordering<Word> byTokenLength =
Ordering.natural().<Word> onResultOf(w -> w.getToken().length()).reverse();
List<Word> sortedByLength = byTokenLength.immutableSortedCopy(words);
System.out.println(sortedByLength);
由于Ordering实现了Comparator接口,它可以用于任何需要比较器的地方。例如,对于Collections.sort:
List<Word> sortedCopy = new ArrayList<>(words);
Collections.sort(sortedCopy, byTokenLength);
除此之外,它还提供了其他方法,如提取 top-k 或 bottom-k 元素:
List<Word> first10 = byTokenLength.leastOf(words, 10);
System.out.println(first10);
List<Word> last10 = byTokenLength.greatestOf(words, 10);
System.out.println(last10);
它与先排序,然后取第一个或最后一个 k 元素相同,但效率更高。
还有其他有用的类:
- 可定制的哈希实现,如杂音哈希和其他
Stopwatch用于测量时间
更多见解,可以参考github.com/google/guav…和github.com/google/guav…。
你可能已经注意到了,番石榴和阿帕奇共有地有很多共同点。选择使用哪一个是个人喜好的问题——两个库都经过了很好的测试,并在许多生产系统中得到积极的应用。但是番石榴的开发更加积极,新功能出现的频率也更高,所以如果你只想用其中的一种,那么番石榴可能是更好的选择。
美国在线独眼巨人反应
正如我们已经了解的,Java Streams API 是以函数方式处理数据的一种非常强大的方式。Cyclops React 库通过在流上添加新的操作来扩展这个 API,并允许对流执行进行更多的控制。要包含该库,将其添加到pom.xml文件:
<dependency>
<groupId>com.aol.simplereact</groupId>
<artifactId>cyclops-react</artifactId>
<version>1.0.0-RC4</version>
</dependency>
它添加的一些方法是zipWithIndex和 cast 以及便利收集器,如toList、toSet和toMap。更重要的是,它为并行执行提供了更多的控制,例如,可以提供一个自定义执行器,用于处理数据或以声明方式拦截异常。
同样,使用这个库,很容易从迭代器创建并行流——使用标准库很难做到这一点。
例如,我们以words.txt为例,从中提取所有的 POS 标签,然后创建一个映射,将每个标签与一个惟一的索引相关联。对于读取数据,我们将使用来自 Commons IO 的LineIterator,否则仅使用标准 Java APIs 很难并行化。此外,我们创建一个定制的执行器,它将用于并行执行流操作:
LineIterator it = FileUtils.lineIterator(new File("data/words.txt"), "UTF-8");
ExecutorService executor = Executors.newCachedThreadPool();
LazyFutureStream<String> stream =
LazyReact.parallelBuilder().withExecutor(executor).from(it);
Map<String, Integer> map = stream
.map(line -> line.split("t"))
.map(arr -> arr[1].toLowerCase())
.distinct()
.zipWithIndex()
.toMap(Tuple2::v1, t -> t.v2.intValue());
System.out.println(map);
executor.shutdown();
it.close();
这是一个非常简单的例子,并没有描述这个库中所有可用的功能。欲了解更多信息,请参考他们的文档,可在github.com/aol/cyclops…找到。我们还会在后面章节的其他例子中用到它。
访问数据
到目前为止,我们已经花了很多时间来描述如何读写数据。但是还有更多的内容:数据通常以不同的格式出现,比如 CSV、HTML 或 JSON,或者可以存储在数据库中。了解如何访问和处理这些数据对于数据科学非常重要,现在我们将详细描述如何对最常见的数据格式和数据源进行访问和处理。
文本数据和 CSV
我们已经详细讨论过读取文本数据,例如,可以使用 NIO API 中的Files helper 类或 Commons IO 中的IOUtils来完成。
CSV(逗号分隔值)是在纯文本文件中组织表格数据的常用方法。虽然手动解析 CSV 文件是可能的,但是有一些极限情况,这使得它有点麻烦。幸运的是,有很好的库可以实现这个目的,其中之一就是 Apache Commons CSV:
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>1.4</version>
</dependency>
为了说明如何使用这个库,让我们再次生成一些随机数据。这次我们也可以使用www.generatedata.com/并定义以下模式:
现在我们可以创建一个特殊的类来保存这些数据:
public static class Person {
private final String name;
private final String email;
private final String country;
private final int salary;
private final int experience;
// constructor and getters are omitted
}
然后,要读取 CSV 文件,您可以执行以下操作:
List<Person> result = new ArrayList<>();
Path csvFile = Paths.get("data/csv-example-generatedata_com.csv");
try (BufferedReader reader = Files.newBufferedReader(csvFile, StandardCharsets.UTF_8)) {
CSVFormat csv = CSVFormat.RFC4180.withHeader();
try (CSVParser parser = csv.parse(reader)) {
Iterator<CSVRecord> it = parser.iterator();
it.forEachRemaining(rec -> {
String name = rec.get("name");
String email = rec.get("email");
String country = rec.get("country");
int salary = Integer.parseInt(rec.get("salary").substring(1));
int experience = Integer.parseInt(rec.get("experience"));
Person person = new Person(name, email, country, salary, experience);
result.add(person);
});
}
}
前面的代码创建了一个CSVRecord对象的迭代器,我们从每个这样的对象中提取值,并将它们传递给一个数据对象。当 CSV 文件非常大,可能无法完全容纳在可用内存中时,创建迭代器非常有用。
如果文件不太大,也可以一次读取整个 CSV 并将结果放入一个列表中:
List<CSVRecord> records = parse.getRecords();
最后,制表符分隔的文件可以看作是 CSV 的一个特例,也可以使用这个库来读取。为此,您只需使用TDF格式进行解析:
CSVFormat tsv = CSVFormat.TDF.withHeader();
Web 和 HTML
现在互联网上有大量的数据,能够访问这些数据并将其转换成机器可读的东西对于数据科学家来说是一项非常重要的技能。
从互联网上获取数据有多种方式。幸运的是,标准 Java API 提供了一个特殊的类来做这件事,URL。使用URL,可以打开一个InputStream,它将包含响应体。对于网页,通常是它的 HTML 内容。使用 Commons IO 的IOUtils,这样做很简单:
try (InputStream is = new URL(url).openStream()) {
return IOUtils.toString(is, StandardCharsets.UTF_8);
}
这段代码相当有用,所以把它放入某个 helper 方法,比如UrlUtils.request,会很有帮助。
这里我们假设网页的内容总是 UTF-8。它可能在很多情况下都有效,尤其是对于英文页面,但偶尔也会失败。对于更复杂的爬虫,可以使用来自 Apache Tika(tika.apache.org/)的编码检测。
前面的方法返回原始的 HTML 数据,这本身是没有用的;大多数时候,我们感兴趣的是文本内容,而不是标记。有一些用于处理 HTML 的库,其中之一是 Jsoup:
<dependency>
<groupId>org.jsoup</groupId>
<artifactId>jsoup</artifactId>
<version>1.9.2</version>
</dependency>
让我们考虑一个例子。Kaggle.com 是一个举办数据科学竞赛的网站,每个竞赛都有一个排行榜,显示每个参与者的表现。假设你有兴趣从https://www . ka ggle . com/c/avito-duplicate-ads-detection/leader board中提取这些信息,如下图截图所示:
这些信息包含在一个表中,为了从这个表中提取数据,我们需要找到一个唯一指向这个表的锚。为此,你可以使用检查器来查看页面(在 Mozilla Firefox 或 Google Chrome 中按下 F12 将打开检查器窗口):
使用 Inspector,我们可以注意到表的 ID 是leaderboard-table,为了在 Jsoup 中获得这个表,我们可以使用下面的 CSS 选择器#leaderboard-table。因为我们实际上对表格的行感兴趣,所以我们将使用#leaderboard-table tr。
关于团队名称的信息包含在列表表格的第三列中。因此,要提取它,我们需要获取第三个<td>标签,然后查看它的<a>标签。同样,为了提取分数,我们获取第四个<td>标签的内容。
执行此操作的代码如下:
Map<String, Double> result = new HashMap<>();
String rawHtml = UrlUtils.request("https://www.kaggle.com/c/avito-duplicate-ads-detection/leaderboard");
Document document = Jsoup.parse(rawHtml);
Elements tableRows = document.select("#leaderboard-table tr");
for (Element tr : tableRows) {
Elements columns = tr.select("td");
if (columns.isEmpty()) {
continue;
}
String team = columns.get(2).select("a.team-link").text();
double score = Double.parseDouble(columns.get(3).text());
result.put(team, score);
}
Comparator<Map.Entry<String, Double>> byValue = Map.Entry.comparingByValue();
result.entrySet().stream()
.sorted(byValue.reversed())
.forEach(System.out::println);
这里我们重用UrlUtils.request函数来获取我们之前定义的 HTML,然后用 Jsoup 处理它。
Jsoup 利用 CSS 选择器来访问解析后的 HTML 文档中的条目。要了解更多,你可以阅读相关的文档,可以在jsoup.org/cookbook/ex…找到。
JSON
JSON 作为 web 服务之间的通信方式越来越受欢迎,逐渐取代了 XML 和其他格式。知道如何处理它可以让你从互联网上各种各样的数据源中提取数据。
有相当多的 JSON 库可用于 Java。Jackson 就是其中之一,它有一个简化版本,叫做jackson-jr,适用于大多数简单的情况,我们只需要从 JSON 中快速提取数据。要添加它,请使用以下命令:
<dependency>
<groupId>com.fasterxml.jackson.jr</groupId>
<artifactId>jackson-jr-all</artifactId>
<version>2.8.1</version>
</dependency>
为了说明这一点,让我们考虑一个返回 JSON 的简单 API。我们可以使用www.jsontest.com/,它提供了许多虚拟的 web 服务。其中之一是 md5.jsontest.com[的 MD5 服务;给定一个字符串,它返回它的 MD5 散列。](md5.jsontest.com/)
以下是它的输出示例:
{
"original": "mastering java for data science",
"md5": "f4c8637d7274f13b58940ff29f669e8a"
}
让我们使用它:
String text = "mastering java for data science";
String json = UrlUtils.request("http://md5.jsontest.com/?text=" + text.replace(' ', '+'));
Map<String, Object> map = JSON.std.mapFrom(json);
System.out.println(map.get("original"));
System.out.println(map.get("md5"));
在这个例子中,web 服务的 JSON 响应非常简单。然而,列表和嵌套对象有更复杂的情况。比如,www.github.com提供了很多 API,其中一个就是api.github.com/users/alexe…。对于给定的用户,它返回他们所有的存储库。它有一个对象列表,每个对象都有一个嵌套对象。
在具有动态类型的语言中,比如 Python,这很简单——这种语言并不强迫您拥有特定的类型,对于这种特殊情况来说,这很好。然而,在 Java 中,静态类型系统要求定义一个类型;每次需要提取东西的时候,都需要做铸造。
例如,如果我们想获得第一个对象的元素 ID,我们需要做这样的事情:
String username = "alexeygrigorev";
String json = UrlUtils.request("https://api.github.com/users/" + username + "/repos");
List<Map<String, ?>> list = (List<Map<String, ?>>) JSON.std.anyFrom(json);
String name = (String) list.get(0).get("name");
System.out.println(name);
如您所见,我们需要进行大量的类型转换,代码很快变得非常混乱。一个解决方案可能是使用一种类似于 Xpath 的查询语言,称为 JsonPath。可在github.com/jayway/Json…访问 Java 的实现。要使用它,请添加以下内容:
<dependency>
<groupId>com.jayway.jsonpath</groupId>
<artifactId>json-path</artifactId>
<version>2.2.0</version>
</dependency>
如果我们想要检索所有用 Java 编写的存储库,并且至少有一个 start,那么下面的查询就可以做到:
ReadContext ctx = JsonPath.parse(json);
String query = "$..[?(@.language=='Java' && @.stargazers_count > 0)]full_name";
List<String> javaProjects = ctx.read(query);
这肯定会为过滤等简单的数据操作节省一些时间,但不幸的是,对于更复杂的事情,您可能仍然需要进行大量的强制转换来进行手动转换。
对于更复杂的查询(例如,发送 POST 请求),最好使用特殊的库,比如 Apache http components(hc.apache.org/)。
数据库
在组织中,数据通常保存在关系数据库中。Java 将 Java 数据库连接 ( JDBC )定义为访问任何支持 SQL 的数据库的抽象。
在我们的例子中,我们将使用 MySQL,它可以从www.mysql.com/下载,但原则上它可以是任何其他数据库,如 PostgreSQL、Oracle、MS SQL 和许多其他数据库。要连接到 MySQL 服务器,我们可以使用 JDBC MySQL 驱动程序:
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.39</version>
</dependency>
如果你想使用不同的数据库,那么你可以使用你最喜欢的搜索引擎,找到合适的 JDBC 驱动程序。交互代码将保持不变,如果您使用标准 SQL,查询代码应该也不会改变。
例如,我们将使用为 CSV 示例生成的相同数据。首先,我们将它加载到数据库中,然后进行简单的选择。
让我们定义以下模式:
CREATE SCHEMA `people` DEFAULT CHARACTER SET utf8 ;
CREATE TABLE `people`.`people` (
`person_id` INT UNSIGNED NOT NULL AUTO_INCREMENT,
`name` VARCHAR(45) NULL,
`email` VARCHAR(100) NULL,
`country` VARCHAR(45) NULL,
`salary` INT NULL,
`experience` INT NULL,
PRIMARY KEY (`person_id`));
现在,为了连接到数据库,我们通常使用DataSource抽象。MySQL 驱动提供了一个实现:MysqlDataSource:
MysqlDataSource datasource = new MysqlDataSource();
datasource.setServerName("localhost");
datasource.setDatabaseName("people");
datasource.setUser("root");
datasource.setPassword("");
现在使用DataSource对象,我们可以加载数据。做这件事有两种方法:简单的方法是,当我们单独插入每个对象时,以及批处理模式,其中我们首先准备一个批处理,然后插入一个批处理的所有对象。批处理模式选项通常工作得更快。
我们先来看看通常的模式:
try (Connection connection = datasource.getConnection()) {
String sql = "INSERT INTO people (name, email, country, salary, experience) VALUES (?, ?, ?, ?, ?);";
try (PreparedStatement statement = connection.prepareStatement(sql)) {
for (Person person : people) {
statement.setString(1, person.getName());
statement.setString(2, person.getEmail());
statement.setString(3, person.getCountry());
statement.setInt(4, person.getSalary());
statement.setInt(5, person.getExperience());
statement.execute();
}
}
}
注意,在 JDBC 中,索引的枚举从 1 开始,而不是从 0 开始。
批次很相似。为了准备批处理,我们首先使用来自 Guava 的Lists.partition函数,并将所有数据放入 50 个对象的批处理中。然后使用addBatch函数将块中的每个对象添加到一个批处理中:
List<List<Person>> chunks = Lists.partition(people, 50);
try (Connection connection = datasource.getConnection()) {
String sql = "INSERT INTO people (name, email, country, salary, experience) VALUES (?, ?, ?, ?, ?);";
try (PreparedStatement statement = connection.prepareStatement(sql)) {
for (List<Person> chunk : chunks) {
for (Person person : chunk) {
statement.setString(1, person.getName());
statement.setString(2, person.getEmail());
statement.setString(3, person.getCountry());
statement.setInt(4, person.getSalary());
statement.setInt(5, person.getExperience());
statement.addBatch();
}
statement.executeBatch();
}
}
}
批处理模式比通常的数据处理方式更快,但需要更多的内存。如果您需要处理大量数据,并且关心速度,那么批处理模式是一个更好的选择,但是它会使代码变得更加复杂。因此,使用更简单的方法可能更好。
现在,当数据被加载时,我们可以查询数据库。例如,让我们选择一个国家的所有人:
String country = "Greenland";
try (Connection connection = datasource.getConnection()) {
String sql = "SELECT name, email, salary, experience FROM people WHERE country = ?;";
try (PreparedStatement statement = connection.prepareStatement(sql)) {
List<Person> result = new ArrayList<>();
statement.setString(1, country);
try (ResultSet rs = statement.executeQuery()) {
while (rs.next()) {
String name = rs.getString(1);
String email = rs.getString(2);
int salary = rs.getInt(3);
int experience = rs.getInt(4);
Person person = new Person(name, email, country, salary, experience);
result.add(person);
}
}
}
}
这样,我们就可以执行任何想要的 SQL 查询,并将结果放入 Java 对象中进行进一步处理。
你可能已经注意到在 JDBC 有很多样板代码。样板文件可以用 Spring JDBC 模板库来简化(见www.springframework.org)。
DataFrames
数据帧是在内存中表示表格数据的一种便捷方式。最初 DataFrames 来自 R 编程语言,但现在在其他语言中也很常见;例如,在 Python 中,pandas 库提供了一个类似于 R 的 DataFrame 数据结构。
Java 中存储数据的通常方式是列表、映射和其他对象集合。我们可以把这些集合想象成表,但是我们只能通过行来评估数据。然而,对于数据科学来说,操作列同样重要,这也是数据框架有用的地方。
例如,它们允许您对同一列的所有值应用相同的函数,或者查看值的分布。
在 Java 中,没有太多成熟的实现,但是有一些实现具有所有需要的功能。在我们的例子中,我们将使用joinery:
<dependency>
<groupId>joinery</groupId>
<artifactId>joinery-dataframe</artifactId>
<version>1.7</version>
</dependency>
遗憾的是,joinery在 Maven Central 上不可用;因此,要将它包含到 Maven 项目中,您需要指向另一个 Maven 存储库bintray。为此,将这个repository添加到pom文件的存储库部分:
<repository>
<id>bintray</id>
<url>http://jcenter.bintray.com</url>
</repository>
细木工取决于Apache POI,所以你也需要加上:
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi</artifactId>
<version>3.14</version>
</dependency>
使用细木工技术,读取数据非常容易:
DataFrame<Object> df = DataFrame.readCsv("data/csv-example-generatedata_com.csv");
一旦数据被读取,我们不仅可以访问数据帧的每一行,还可以访问每一列。给定一个列名,Joinery 返回一个存储在列中的值的List,我们可以用它对它进行各种转换。
例如,假设我们希望将我们拥有的每个国家与一个唯一的索引相关联。我们可以这样做:
List<Object> country = df.col("country");
Map<String, Long> map = LazyReact.sequentialBuilder()
.from(country)
.cast(String.class)
.distinct()
.zipWithIndex()
.toMap(Tuple2::v1, Tuple2::v2);
List<Object> indexes = country.stream().map(map::get).collect(Collectors.toList());
之后,我们可以删除带有country的旧列,并包含新的索引列:
df = df.drop("country");
df.add("country_index", indexes);
System.out.println(df);
Joinery 可以做更多的事情——分组、连接、旋转和为机器学习模型创建设计矩阵。我们将在以后的几乎所有章节中再次使用它。同时,你可以在 cardillo.github.io/joinery/阅读更…
搜索引擎-准备数据
在第一章中,我们介绍了运行示例,构建一个搜索引擎。搜索引擎是这样一种程序,给定用户的查询,它返回按照与该查询的相关性排序的结果。在本章中,我们将执行第一步-获取和处理数据。
假设我们在一个门户网站上工作,用户生成了很多内容,但是他们很难找到其他人所创建的内容。为了克服这个问题,我们建议建立一个搜索引擎,产品管理部门已经确定了用户会提出的典型查询。
例如,“中国食物”,“自制披萨”,“如何学习编程”是这个列表中的典型查询。
现在我们需要收集数据。幸运的是,互联网上已经有搜索引擎可以接受查询并返回他们认为相关的 URL 列表。我们可以用它们来获取数据。你可能已经知道这样的引擎——谷歌或必应,仅举两例。
因此,我们可以应用我们在这一章中学到的知识,使用 JSoup 解析来自 Google、Bing 或任何其他搜索引擎的数据。或者,也可以使用诸如flow-app.com/这样的服务来帮你提取,但是需要注册。
最后,我们需要的是一个查询和一个最相关的 URL 列表。提取相关的 URL 作为一个练习,但是我们已经准备了一些结果,如果你愿意的话可以使用:对于每个查询,从搜索结果的前三页有 30 个最相关的页面。此外,您可以在代码包中找到对爬行有用的代码。
现在,当我们有了 URL,我们感兴趣的是检索它们并保存它们的 HTML 代码。为此,我们需要一个比我们已经有的UrlUtils.request更智能的爬虫。
我们必须添加的一个特别的东西是超时:一些页面需要很长时间来加载,因为它们要么很大,要么服务器遇到一些问题,需要一段时间来响应。在这些情况下,当一个页面在 30 秒内无法下载时,放弃是有意义的。
在 Java 中,这可以用Executors来完成。首先,让我们创建一个Crawler类,并声明executor字段:
int numProc = Runtime.getRuntime().availableProcessors();
executor = Executors.newFixedThreadPool(numProc);
然后,我们可以如下使用这个执行程序:
try {
Future<String> future = executor.submit(() -> UrlUtils.request(url));
String result = future.get(30, TimeUnit.SECONDS);
return Optional.of(result);
} catch (TimeoutException e) {
LOGGER.warn("timeout exception: could not crawl {} in {} sec", url, timeout);
return Optional.empty();
}
这段代码将删除花费太长时间检索的页面。
我们需要将抓取的 HTML 页面存储在某个地方。有几个选择:文件系统上的一堆 HTML 文件、关系存储(如 MySQL)或键值存储。键值存储看起来是最好的选择,因为我们有一个键,URL,和值,HTML。为此,我们可以使用 MapDB,这是一个实现了Map接口的纯 Java 键值存储。本质上,它是一个由磁盘上的文件支持的Map。
因为它是纯 Java,所以使用它所需要做的就是包含它的依赖项:
<dependency>
<groupId>org.mapdb</groupId>
<artifactId>mapdb</artifactId>
<version>3.0.1</version>
</dependency>
现在我们可以使用它:
DB db = DBMaker.fileDB("urls.db").closeOnJvmShutdown().make();
HTreeMap<?, ?> htreeMap = db.hashMap("urls").createOrOpen();
Map<String, String> urls = (Map<String, String>) htreeMap;
因为它实现了Map接口,所以它可以被当作普通的Map来处理,我们可以把 HTML 放在那里。因此,让我们读取相关的 URL,下载它们的 HTML,并保存到地图:
Path path = Paths.get("data/search-results.txt");
List<String> lines = FileUtils.readLines(path.toFile(), StandardCharsets.UTF_8);
lines.parallelStream()
.map(line -> line.split("t"))
.map(split -> "http://" + split[2])
.distinct()
.forEach(url -> {
try {
Optional<String> html = crawler.crawl(url);
if (html.isPresent()) {
LOGGER.debug("successfully crawled {}", url);
urls.put(url, html.get());
}
} catch (Exception e) {
LOGGER.error("got exception when processing url {}", url, e);
}
});
这里我们在parallelStream中这样做是为了加快执行速度。超时将确保它在合理的时间内完成。
首先,让我们从页面中提取一些非常简单的内容,如下所示:
- URL 的长度
- 标题的长度
- 无论查询是否包含在标题中
- 正文中整个文本的长度
<h1>-<h6>标签的数量- 链接数量
为了保存这些信息,我们可以创建一个特殊的类,RankedPage。
public class RankedPage {
private String url;
private int position;
private int page;
private int titleLength;
private int bodyContentLength;
private boolean queryInTitle;
private int numberOfHeaders;
private int numberOfLinks;
// setters, getters are omitted
}
现在,让我们为每个页面创建一个这个类的对象。我们使用flatMap是因为有些 URL 没有 HTML 数据。
Stream<RankedPage> pages = lines.parallelStream().flatMap(line -> {
String[] split = line.split("t");
String query = split[0];
int position = Integer.parseInt(split[1]);
int searchPageNumber = 1 + (position - 1) / 10; // converts position to a page number
String url = "http://" + split[2];
if (!urls.containsKey(url)) { // no crawl available
return Stream.empty();
}
RankedPage page = new RankedPage(url, position, searchPageNumber);
String html = urls.get(url);
Document document = Jsoup.parse(html);
String title = document.title();
int titleLength = title.length();
page.setTitleLength(titleLength);
boolean queryInTitle = title.toLowerCase().contains(query.toLowerCase());
page.setQueryInTitle(queryInTitle);
if (document.body() == null) { // no body for the document
return Stream.empty();
}
int bodyContentLength = document.body().text().length();
page.setBodyContentLength(bodyContentLength);
int numberOfLinks = document.body().select("a").size();
page.setNumberOfLinks(numberOfLinks);
int numberOfHeaders = document.body().select("h1,h2,h3,h4,h5,h6").size();
page.setNumberOfHeaders(numberOfHeaders);
return Stream.of(page);
});
在这段代码中,我们为每个页面查找它的 HTML。如果没有被抓取,我们跳过这个页面;然后我们解析 HTML 并检索前面的基本特性。
这只是我们可以计算的可能页面特征的一小部分。稍后,我们将在此基础上添加更多功能。
在这个例子中,我们得到了一个页面流。我们可以对这个流做任何我们想做的事情,例如,将它保存到 JSON 或转换成 DataFrame。本书附带的代码包中有一些例子,展示了如何进行这些类型的转换。例如,从 Java 对象列表到 Joinery DataFrame的转换在BeanToJoinery实用程序类中可用。
摘要
处理任何数据科学问题都有几个步骤,而数据准备步骤是第一步。标准的 Java API 有大量的工具使这项任务成为可能,并且有许多库使它变得容易得多。
在这一章中,我们讨论了其中的许多,包括对 Java API 的扩展,如 Google Guava 我们讨论了从文本、HTML 和数据库等不同来源读取数据的方法;最后,我们讨论了 DataFrame,这是一种用于操作表格数据的有用结构。
在下一章中,我们将仔细查看本章中提取的数据,并执行探索性数据分析。
三、探索性数据分析
在前一章中,我们讨论了数据处理,这是将数据转换成可用于分析的形式的重要步骤。在本章中,我们在清理和查看数据之后进行下一个逻辑步骤。这一步被称为探索性数据分析 ( EDA ),它由汇总数据和创建可视化组成。
在本章中,我们将讨论以下主题:
- 使用 Apache Commons Math 和 Joinery 进行汇总统计
- Java 和 JVM 中 EDA 的交互式 shells
在本章结束时,你将知道如何计算汇总统计数据和用 Joinery 创建简单的图表。
Java 中的探索性数据分析
探索性数据分析是指获取数据集并从中提取最重要的信息,这样就有可能了解数据的样子。这包括两个主要部分:总结和可视化。
汇总步骤对于理解数据非常有帮助。对于数值变量,在这一步我们计算最重要的样本统计数据:
- 极值(最小值和最大值)
- 平均值或样本平均值
- 标准差,描述了数据的分布
我们通常会考虑其他统计数据,如中位数和四分位数(25%和 75%)。
正如我们在前一章已经看到的,Java 提供了一套很好的数据准备工具。相同的工具集可以用于 EDA,尤其是创建摘要。
搜索引擎数据集
在这一章中,我们将使用我们正在运行的例子——构建一个搜索引擎。在第二章、数据处理工具箱中,我们从搜索引擎返回的 HTML 页面中提取了一些数据。这个数据集包括一些数字特征,比如标题的长度和内容的长度。
为了存储这些功能,我们创建了以下类:
public class RankedPage {
private String url;
private int position;
private int page;
private int titleLength;
private int bodyContentLength;
private boolean queryInTitle;
private int numberOfHeaders;
private int numberOfLinks;
// setters, getters are omitted
}
看看这些信息对搜索引擎是否有用是很有趣的。例如,给定一个 URL,我们可能想知道它是否可能出现在引擎输出的第一页。通过 EDA 查看数据将有助于我们了解这是否可行。
此外,真实世界的数据很少是干净的。我们将使用 EDA 来尝试发现一些奇怪或有问题的观察结果。
让我们开始吧。我们将数据保存为 JSON 格式,现在我们可以使用 streams 和 Jackson 读回它:
Path path = Paths.get("./data/ranked-pages.json");
try (Stream<String> lines = Files.lines(path)) {
return lines.map(line -> parseJson(line)).collect(Collectors.toList());
}
这是返回一系列RankedPage对象的函数体。我们从ranked-page.json文件中读取它们。然后我们使用parseJson函数将 JSON 转换成 Java 类:
JSON.std.beanFrom(RankedPage.class, line);
看完数据,我们就可以分析了。通常,分析的第一步是查看汇总统计数据,我们可以使用 Apache Commons Math 来实现这一点。
阿帕奇公共数学
一旦我们读取了数据,我们就可以计算统计数据。正如我们前面提到的,我们通常对诸如最小值、最大值、平均值、标准偏差等汇总感兴趣。我们可以使用 Apache Commons 数学库。我们把它包含在pom.xml里吧:
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
有一个用于计算汇总的SummaryStatistics类。让我们用它来计算一些关于我们抓取的页面的正文内容长度分布的统计数据:
SummaryStatistics statistics = new SummaryStatistics(); data.stream().mapToDouble(RankedPage::getBodyContentLength)
.forEach(statistics::addValue);
System.out.println(statistics.getSummary());
这里,我们创建了SummaryStatistics对象,并添加了所有的正文内容长度。之后,我们可以调用一个getSummary方法来一次获得所有的汇总统计数据。这将打印以下内容:
StatisticalSummaryValues:
n: 4067
min: 0.0
max: 8675779.0
mean: 14332.239242685007
std dev: 144877.54551111493
variance: 2.0989503193325176E10
sum: 5.8289217E7
DescriptiveStatistics方法是这个库中另一个有用的类。它允许获得更多的值,如中位数和百分位数,以及百分位数;更好地展示数据的样子:
double[] dataArray = data.stream()
.mapToDouble(RankedPage::getBodyContentLength)
.toArray();
DescriptiveStatistics desc = new DescriptiveStatistics(dataArray);
System.out.printf("min: %9.1f%n", desc.getMin());
System.out.printf("p05: %9.1f%n", desc.getPercentile(5));
System.out.printf("p25: %9.1f%n", desc.getPercentile(25)); System.out.printf("p50: %9.1f%n", desc.getPercentile(50)); System.out.printf("p75: %9.1f%n", desc.getPercentile(75)); System.out.printf("p95: %9.1f%n", desc.getPercentile(95)); System.out.printf("max: %9.1f%n", desc.getMax());
这将产生以下输出:
min: 0.0
p05: 527.6
p25: 3381.0
p50: 6612.0
p75: 11996.0
p95: 31668.4
max: 8675779.0
从输出中,我们可以注意到最小长度为零,这很奇怪;最有可能的是,存在数据处理问题。此外,最大值非常高,这表明存在异常值。稍后,从我们的分析中排除这些值是有意义的。
内容长度为零的页面可能是爬行错误。我们来看看这几页的比例:
double proportion = data.stream()
.mapToInt(p -> p.getBodyContentLength() == 0 ? 1 : 0)
.average().getAsDouble();
System.*out*.printf("proportion of zero content length: %.5f%n", proportion);
我们看到没有多少页面是零长度的,所以删除它们是相当安全的。
稍后,在下一章中,我们将尝试预测一个 URL 是否来自第一个搜索页面结果。如果一些特征对于每个页面具有不同的值,那么机器学习模型将能够看到这种差异,并将其用于更准确的预测。让我们看看不同页面的内容长度值是否相似。
为此,我们可以按页面对 URL 进行分组,并计算平均内容长度。正如我们已经知道的,Java 流可以用来做到这一点:
Map<Integer, List<RankedPage>> byPage = data.stream()
.filter(p -> p.getBodyContentLength() != 0)
.collect(Collectors.groupingBy(RankedPage::getPage));
请注意,我们为空白页面添加了一个过滤器,因此它们不会出现在组中。现在,我们可以使用组来计算平均值:
List<DescriptiveStatistics> stats = byPage.entrySet().stream()
.sorted(Map.Entry.comparingByKey())
.map(e -> calculate(e.getValue(), RankedPage::getBodyContentLength))
.collect(Collectors.toList());
这里,calculate是一个函数,它接受一个集合,计算每个元素上提供的函数(在本例中使用getBodyContentLength,并从中创建一个DescriptiveStatistics对象:
private static DescriptiveStatistics calculate(List<RankedPage> data,
ToDoubleFunction<RankedPage> getter) {
double[] dataArray = data.stream().mapToDouble(getter).toArray();
return new DescriptiveStatistics(dataArray);
}
现在,在列表中,您将拥有每个组的描述性统计数据(在本例中为页面)。然后,我们可以用任何我们想要的方式展示它们。考虑下面的例子:
Map<String, Function<DescriptiveStatistics, Double>> functions = new LinkedHashMap<>();
functions.put("min", d -> d.getMin());
functions.put("p05", d -> d.getPercentile(5));
functions.put("p25", d -> d.getPercentile(25));
functions.put("p50", d -> d.getPercentile(50));
functions.put("p75", d -> d.getPercentile(75));
functions.put("p95", d -> d.getPercentile(95));
functions.put("max", d -> d.getMax());
System.out.print("page");
for (Integer page : byPage.keySet()) {
System.out.printf("%9d ", page);
}
System.out.println();
for (Entry<String, Function<DescriptiveStatistics, Double>> pair : functions.entrySet()) {
System.out.print(pair.getKey());
Function<DescriptiveStatistics, Double> function = pair.getValue();
System.out.print(" ");
for (DescriptiveStatistics ds : stats) {
System.out.printf("%9.1f ", function.apply(ds));
}
System.out.println();
}
这会产生以下输出:
page 0 1 2
min 5.0 1.0 5.0
p05 1046.8 900.6 713.8
p25 3706.0 3556.0 3363.0
p50 7457.0 6882.0 6383.0
p75 13117.0 12067.0 11309.8
p95 42420.6 30557.2 27397.0
max 390583.0 8675779.0 1998233.0
输出表明,内容长度的分布在来自搜索引擎结果的不同页面的 URL 之间是不同的。因此,在预测给定 URL 的搜索页码时,这可能很有用。
细木工制品
您可能会注意到,我们刚刚编写的代码非常冗长。当然,可以把它放在一个 helper 函数中,在需要时调用它,但是还有另一种更简洁的方法来计算这些统计数据——使用 joinery 及其数据框架。
在 Joinery 中,DataFrame对象有一个名为describe的方法,它创建一个包含汇总统计信息的新数据框架:
List<RankedPage> pages = Data.readRankedPages();
DataFrame<Object> df = BeanToJoinery.convert(pages, RankedPage.class);
df = df.retain("bodyContentLength", "titleLength", "numberOfHeaders"); DataFrame<Object> describe = df.describe();
System.out.println(describe.toString());
在前面的代码中,Data.readRankedPages是一个 helper 方法,它读取 JSON 数据并将其转换为一组 Java 对象,BeanToJoinery.convert是一个 helper 类,它将一组 Java 对象转换为一个DataFrame。
然后,我们只保留三列,删除其他所有内容。以下是输出:
body contentLength numberOfHeaders titleLength
count 4067.00000000 4067.00000000 4067.00000000
mean 14332.23924269 25.25325793 46.17334645
std 144877.5455111 32.13788062 27.72939822
var 20989503193.32 1032.84337047 768.91952552
max 8675779.000000 742.00000000 584.00000000
min 0.00000000 0.00000000 0.00000000
我们还可以查看不同组的平均值,例如,不同页面的平均值。为此,我们可以使用groupBy方法:
DataFrame<Object> meanPerPage = df.groupBy("page").mean()
.drop("position")
.sortBy("page")
.transpose();
System.out.println(meanPerPage);
除了在groupBy之后应用 mean 之外,我们还删除了一个列位置,因为我们已经知道位置对于不同页面是不同的。此外,我们在最后应用转置操作;这是一个技巧,当有许多列时,使输出适合一个屏幕。这会产生以下输出:
page 0 1 2
bodyContentLength 12577 18703 11286
numberOfHeaders 30 23 21
numberOfLinks 276 219 202
queryInTitle 0 0 0
titleLength 46 46 45
我们可以看到,一些变量的平均值差异很大。对于其他变量,如queryInTitle,似乎没有任何区别。但是,请记住这是一个布尔变量,因此平均值介于 0 和 1 之间。出于某种原因,Joinery 决定不在这里显示小数部分。
现在,我们知道如何在 Java 中计算一些简单的汇总统计数据,但是要做到这一点,我们首先需要编写一些代码,编译它,然后运行它。这不是最方便的过程,有更好的交互方式,即避免编译并立即得到结果。接下来,我们将看到如何在 Java 中实现它。
Java 中的交互式探索性数据分析
Java 是一种静态类型的编程语言,用 Java 编写的代码需要编译。虽然 Java 适合开发复杂的数据科学应用程序,但它使得交互式地探索数据变得更加困难;每次,我们都需要重新编译源代码,重新运行分析脚本来查看结果。这意味着,如果我们需要读取一些数据,我们将不得不一遍又一遍地这样做。如果数据集很大,程序需要更多的时间来启动。
因此很难与数据交互,这使得在 Java 中进行 EDA 比在其他语言中更困难。特别是读取-评估-打印循环 ( REPL )这个交互 shell,对于做 EDA 来说是相当重要的一个特性。
不幸的是,Java 8 没有 REPL,但是有几个替代方案:
- 其他交互式 JVM 语言,如 JavaScript、Groovy 或 Scala
- 带有 jshell 的 Java 9
- 完全不同的平台,如 Python 或 R
在这一章中,我们将着眼于前两个选项——JVM 语言和 Java 9 的 REPL。
JVM 语言
你大概知道,Java 平台不仅是 Java 编程语言,而且 Java 虚拟机 (JVM)可以运行其他 JVM 语言的代码。有很多运行在 JVM 上的语言都有 REPL,比如 JavaScript、Scala、Groovy、JRuby 和 Jython。还有很多。所有这些语言都可以访问任何用 Java 编写的代码,而且它们有交互式控制台。
例如,Groovy 与 Java 非常相似,在 Java 8 之前,几乎所有用 Java 编写的代码都可以在 Groovy 中运行。但是,对于 Java 8 来说,情况就不再是这样了。Groovy 不支持 lambda 表达式和函数接口的新 Java 语法,所以我们不能在那里运行本书的大部分代码。
Scala 是另一种流行的函数式 JVM 语言,但是它的语法与 Java 非常不同。对于数据处理来说,它是一种非常强大和富有表现力的语言,它有一个漂亮的交互式外壳,并且有许多用于进行数据分析和数据科学的库。
此外,有几个 JavaScript 实现可用于 JVM。其中一个是 Nashorn,它自带 Java 8 开箱即用;没有必要将它作为一个独立的依赖项包含进来。Joinery 还带有一个内置的交互式控制台,它利用了 JavaScript,在本章的后面,我们将看到如何使用它。虽然所有这些语言都很好,但它们超出了本书的范围。你可以从这些书中了解更多:
- 动作麻利,迪克·科尼格,曼宁
- Scala 数据分析食谱, 阿伦·马尼瓦南,帕克特出版社
交互式 Java
说 Java 是一种 100%非交互式语言是不公平的;有一些扩展直接为 Java 提供了 REPL 环境。
一个这样的环境是看起来完全像 Java 的脚本语言(BeanShell)。但是,它太旧了,并且不支持新的 Java 8 语法,所以对于进行交互式数据分析来说,它不是很有用。
更有趣的是 Java 9,它附带了一个名为 JShell 的集成 REPL,支持 tab 上的自动补全、Java 流和 lambda 表达式的 Java 8 语法。在撰写本文时,Java 9 只作为早期访问版本提供。你可以从 jdk9.java.net/download/.下…
启动 shell 很容易:
$ jshell
但是通常你想访问一些库,因此它们需要在类路径中。通常,我们使用 Maven 来管理依赖项,所以我们可以运行下面的代码将在pom文件中指定的所有jar库复制到我们选择的目录中:
mvn dependency:copy-dependencies -DoutputDirectory=lib
mvn compile
完成后,我们可以像这样运行 shell:
jshell -cp lib/*:target/classes
如果您在 Windows 上,请用分号替换冒号:
jshell -cp lib/*;target/classes
然而,我们的实验表明,不幸的是,JShell 还很原始,有时会崩溃。在撰写本文时,计划在 2017 年 3 月底发布。现在,我们不会更详细地讨论 JShell,但是本章前半部分的所有代码都应该可以在这个控制台上运行,不需要额外的配置。此外,我们应该能够立即看到输出。
到目前为止,我们已经使用 Joinery 几次了,它也支持执行简单的 EDA。接下来,我们将看看如何用细木工板壳进行分析。
细木工外壳
细木工已经多次被证明对数据处理和简单的 EDA 很有用。它有一个交互式的外壳,你可以立即得到答案。
如果数据已经是 CSV 格式,那么可以从系统控制台调用 Joinery shell:
$ java joinery.DataFrame shell
你可以在github.com/cardillo/jo…看到例子,所以如果你的数据已经在 CSV 中,你就可以开始了,只需按照那里的指示。
在本书中,当数据帧不是 CSV 格式时,我们将看一个更复杂的例子。在我们的例子中,数据是 JSON 格式的,而 Joinery shell 不支持这种格式,所以我们需要先做一些预处理。
我们能做的是在 Java 代码中创建一个 DataFrame 对象,然后创建交互式 shell 并将 DataFrame 传递到那里。让我们看看如何能做到这一点。
但是在我们这样做之前,我们需要添加一些依赖项来使之成为可能。第一,Joinery shell 使用 JavaScript,但不使用 JVM 附带的 Nashorn engin 相反,它使用的是 Mozilla 的名为 Rhino 的引擎。因此,我们需要将它包含到我们的pom:
<dependency>
<groupId>rhino</groupId>
<artifactId>js</artifactId>
<version>1.7R2</version>
</dependency>
第二,它依赖于一个特殊的自动补全库jline。我们也来补充一下:
<dependency>
<groupId>jline</groupId>
<artifactId>jline</artifactId>
<version>2.14.2</version>
</dependency>
使用 Maven 给了你很大的灵活性;它更简单,不需要您手动下载所有的库,并从源代码构建 Joinery 来执行 shell。所以,我们让 Maven 来处理。
现在我们可以使用它了:
List<RankedPage> pages = Data.readRankedPages();
DataFrame<Object> dataFrame = BeanToJoinery.convert(pages, 'RankedPage.class);
Shell.repl(Arrays.asList(dataFrame));
让我们将这段代码保存到一个chapter03.JoineryShell类中。之后,我们可以用下面的 Maven 命令运行它:
mvn exec:java -Dexec.mainClass="chapter03.JoineryShell"
这将把我们带到细木工外壳:
# DataFrames for Java -- null, 1.7-8e3c8cf
# Java HotSpot(TM) 64-Bit Server VM, Oracle Corporation, 1.8.0_91
# Rhino 1.7 release 2 2009 03 22
>
我们在 Java 中传递给 shell 对象的所有数据帧都可以在 Shell 中的 Frames 变量中找到。所以,为了得到DataFrame,我们可以这样做:
> var df = frames[0]
要查看DataFrame的内容,只需写下它的名字:
> df
你会看到前几排DataFrame。请注意,自动完成功能按预期工作:
> df.<tab>
您将看到选项列表。
我们可以使用这个 shell 调用数据帧上的相同方法,就像我们在普通的 Java 应用程序中使用的方法一样。例如,您可以按如下方式计算平均值:
> df.mean().transpose()
我们将看到以下输出:
bodyContentLength 14332.23924269
numberOfHeaders 25.25325793
numberOfLinks 231.16867470
page 1.03221047
position 18.76518318
queryInTitle 0.59822965
titleLength 46.17334645
或者。我们可以执行相同的groupBy示例:
> df.drop('position').groupBy('page').mean().sortBy('page').transpose()
这将产生以下输出:
page 0 1 2
bodyContentLength 12577 18703 11286
numberOfHeaders 30 23 21
numberOfLinks 276 219 202
queryInTitle 0 0 0
titleLength 46 46 45
最后,用细木工也可以创造一些简单的情节。为此,我们需要使用一个额外的库。对于绘图,细木工使用xchart。让我们把它包括进来:
<dependency>
<groupId>com.xeiam.xchart</groupId>
<artifactId>xchart</artifactId>
<version>2.5.1</version>
</dependency>
并再次运行控制台。现在我们可以使用plot函数:
> df.retain('titleLength').plot(PlotType.SCATTER)
我们会看到这个:
这里,我们看到一个标题长度超过 550 个字符的异常值。让我们把 200 以上的都去掉,再看一下图片。另外,记住有一些零长度的内容页面,所以我们也可以把它们过滤掉。
为了只保留那些满足某些条件的行,我们使用了select方法。它采用一个函数,应用于每一行;如果函数返回true,则保留该行。
我们可以这样使用它:
> df.retain('titleLength')
.select(function(list) { return list.get(0) <= 200; })
.select(function(list) { return list.get(0) > 0;})
.plot(PlotType.SCATTER)
前面代码中的换行符是为了提高可读性而添加的,但是它们在控制台中不起作用,所以不要使用它们。
现在,我们有了一个更清晰的画面:
遗憾的是,joinery 的绘图能力相当有限,使用xchart制作图形需要很大的努力。
正如我们已经知道的,在细木工中,很容易计算不同组之间的统计数据;我们只需要使用groupBy方法。然而,不可能容易地使用这种方法来绘制数据,以便容易地比较每组的分布。
还有其他工具也可以用于 EDA:
- 用 Java 编写的 Weka 是一个用于执行数据挖掘的库。它有一个用于执行 EDA 的 GUI 界面。
- 另一个 Java 库 Smile 有一个 Scala shell 和一个 smile-plot 库,用于创建可视化效果。
不幸的是,Java 通常不是执行 EDA 的理想选择,有其他更适合的动态语言。例如,R 和 Python 对于这个任务来说是理想的,但是介绍它们超出了本书的范围。你可以从以下书籍中了解更多信息:
- 通过 R,Gergely Daroczi 掌握数据分析
- Python 机器学习,塞巴斯蒂安·拉什卡
摘要
在这一章中,我们讨论了探索性数据分析,简称 EDA。我们讨论了如何用 Java 进行 EDA,包括创建摘要和简单的可视化。
在本章中,我们使用了我们的搜索引擎示例,并分析了我们之前收集的数据。我们的分析表明,对于来自搜索引擎结果的不同页面的 URL,一些变量的分布看起来是不同的。这表明,可以利用这些差异来建立一个模型,预测 URL 是否来自第一页。
在下一章,我们将看看如何做到这一点,并讨论监督机器学习算法,如分类和回归。
四、监督学习——分类和回归
在前几章中,我们学习了如何在 Java 中预处理数据,以及如何进行探索性的数据分析。现在,我们已经打下了基础,我们准备开始创建机器学习模型。
首先,我们从监督学习开始。在有监督的设置中,我们有一些信息附加在每个观察上,叫做标签,我们想从中学习,对没有标签的观察进行预测。
有两种类型的标签:第一种是离散和有限的,如真/假或买/卖,第二种是连续的,如工资或温度。这些类型对应于两种类型的监督学习:分类和回归。我们将在本章中讨论它们。
本章包括以下几点:
- 分类问题
- 回归问题
- 每种类型的评估指标
- Java 中可用实现的概述
到本章结束时,你将知道如何使用 Smile、LIBLINEAR 和其他 Java 库来构建有监督的机器学习模型。
分类
在机器学习中,分类问题处理具有有限组可能值的离散目标。这意味着有一组可能的结果,给定一些特征,我们想要预测结果。
二元分类是最常见的分类问题类型,因为target变量只能有两个可能的值,比如True / False、Relevant / Not Relevant、Duplicate / Not Duplicate、Cat / Dog等等。
有时目标变量可以有两个以上的结果,例如颜色、物品类别、汽车型号等等,我们称之为多类分类。通常,每个观察只能有一个标签,但在某些设置中,可以为一个观察分配多个值。多类分类可以转化为一组二分类问题,这就是为什么我们将主要集中于二分类。
二元分类模型
正如我们已经讨论过的,二元分类模型处理只有两种可能结果需要预测的情况。通常,在这些设置中,我们有正类项目(存在某种效果)和负类项目(不存在某种效果)。
例如,积极的标签可以是相关的、重复的、不能偿还债务的等等。正类的实例通常被赋予目标值 1。此外,我们还有负面的实例,如不相关、不重复、偿还债务,它们被赋予目标值0。
这种分为积极和消极两类的做法有些人为,在某些情况下并不真正有意义。例如,如果我们有猫和狗的图像,即使只有两个类,说Cat是正类而Dog是负类也有点牵强。但这对于模型来说并不重要,所以我们仍然可以这样分配标签:Cat是1,而Dog是0。
一旦我们训练了一个模型,我们通常不会对硬预测感兴趣,比如正面效应在那里,或者这是一只猫。更有趣的是积极或消极影响的程度,这通常是通过预测概率来实现的。例如,如果我们想建立一个模型来预测一个客户是否会无法偿还债务,那么说这个客户有 30%的违约比这个客户不会违约更有用。
有许多模型可以解决二元分类问题,不可能涵盖所有的模型。我们将简要介绍实践中最常用的方法。它们包括以下内容:
- 逻辑回归
- 支持向量机
- 决策树
- 神经网络
我们假设您已经熟悉这些方法,并且至少对它们的工作原理有所了解。不要求非常熟悉,但要了解更多信息,您可以查阅以下书籍:
- 统计学习入门, G .詹姆斯, D .威滕, T .哈斯蒂, R .蒂布拉尼,斯普林格
- Python 机器学习, S .拉什卡, Packt 出版
说到库,我们将涉及 Smile、JSAT、LIBSVM、LIBLINEAR 和 Encog。让我们从微笑开始。
微笑
统计机器智能和学习引擎 ( 微笑)是一个拥有大量分类和其他机器学习算法的库。对我们来说,最有趣的是逻辑回归、SVM 和随机森林,但你可以在 github.com/haifengl/sm… GitHub 官方页面上看到可用算法的完整列表。](github.com/haifengl/sm…)
该库可以在 Maven Central 上获得,在撰写本文时的最新版本是 1.1.0。若要将它包含到项目中,请添加以下依赖项:
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
<version>1.1.0</version>
</dependency>
正在积极开发中;新功能和错误修复经常被添加,但发布的频率并不高。我们建议使用 Smile 的最新可用版本,要获得它,您需要从源代码构建它。为此:
- 安装
sbt,这是一个用于构建 scala 项目的工具。可以按照http://www . Scala-SBT . org/release/docs/Manual-installation . html的说明进行操作 - 使用 git 从github.com/haifengl/sm…克隆项目
- 要构建库并将其发布到本地 Maven 存储库,请运行以下命令:
sbt core/publishM2
微笑库由几个子模块组成,如smile-core、smile-nlp、smile-plot等。出于本章的目的,我们只需要核心包,前面的命令将只构建核心包。在撰写本文时,GitHub 上的当前版本是 1.2.0。因此,在构建它之后,将下面的依赖项添加到 pom 中:
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
<version>1.2.0</version>
</dependency>
Smile 的模型期望数据是双精度的二维数组形式,标签信息是整数的一维数组形式。对于二进制模型,值应为0或1。Smile 中的一些模型可以处理多类分类问题,所以有可能有更多的标签,不仅仅是0、1,还有2、3等等。
在 Smile 中,模型是使用builder模式构建的;你创建一个特殊的类,设置一些参数,最后它返回它构建的对象。这个builder类通常被称为Trainer,所有的模型都应该有一个Trainer对象。
例如,考虑训练一个 RandomForest 模型:
double[] X = ... // training data
int[] y = ... // 0 and 1 labels
RandomForest model = new RandomForest.Trainer()
.setNumTrees(100)
.setNodeSize(4)
.setSamplingRates(0.7)
.setSplitRule(SplitRule.ENTROPY)
.setNumRandomFeatures(3)
.train(X, y);
RandomForest.Trainer类接受一组参数和训练数据,最终生成训练好的RandomForest模型。来自 Smile 的 RandomForest 的实现具有以下参数:
- 这是模型中要训练的树的数量
nodeSize:这是叶节点中的最小项目数samplingRate:这是用于生长每棵树的训练数据的比率splitRule:这是用于选择最佳分割的杂质测量numRandomFeatures:这是模型为选择最佳分割随机选择的特征数量
类似地,逻辑回归训练如下:
LogisticRegression lr = new LogisticRegression.Trainer()
.setRegularizationFactor(lambda)
.train(X, y);
一旦我们有了一个模型,我们就可以用它来预测以前看不见的商品的标签。为此,我们使用predict方法:
double[] row = // data
int prediction = model.predict(row);
这段代码输出给定项目最可能的类。但是,我们往往更感兴趣的不是标签本身,而是拥有标签的概率。如果一个模型实现了SoftClassifier接口,那么有可能得到这样的概率:
double[] probs = new double[2];
model.predict(row, probs);
运行这段代码后,probs数组将包含概率。
JSAT
Java 统计分析工具 ( JSAT )是另一个 Java 库,里面包含了很多常用机器学习算法的实现。您可以在 github.com/EdwardRaff/… 的查看已实施车型的完整列表。
要将 JSAT 添加到 Java 项目中,请将下面的代码片段添加到pom:
<dependency>
<groupId>com.edwardraff</groupId>
<artifactId>JSAT</artifactId>
<version>0.0.5</version>
</dependency>
与 Smile 模型不同,它只需要一个带有特征信息的 doubles 数组,JSAT 需要一个特殊的数据包装类。如果我们有一个数组,它被转换成 JSAT 表示,如下所示:
double[][] X = ... // data
int[] y = ... // labels
// change to more classes for more classes for multi-classification
CategoricalData binary = new CategoricalData(2);
List<DataPointPair<Integer>> data = new ArrayList<>(X.length);
for (int i = 0; i < X.length; i++) {
int target = y[i];
DataPoint row = new DataPoint(new DenseVector(X[i]));
data.add(new DataPointPair<Integer>(row, target));
}
ClassificationDataSet dataset = new ClassificationDataSet(data, binary);
一旦我们准备好数据集,我们就可以训练一个模型。让我们再次考虑随机森林分类器:
RandomForest model = new RandomForest();
model.setFeatureSamples(4);
model.setMaxForestSize(150);
model.trainC(dataset);
首先,我们为模型设置一些参数,然后,在最后,我们调用trainC方法(这意味着训练一个分类器)。
在 JSAT 实现中,与 Smile 相比,Random Forest 的调整选项较少,只有可供选择的要素数量和要生长的树的数量。
此外,JSAT 包含逻辑回归的几个实现。通常的逻辑回归模型没有任何参数,它是这样训练的:
LogisticRegression model = new LogisticRegression();
model.trainC(dataset);
如果我们想要一个正则化的模型,那么我们需要使用LogisticRegressionDCD类。对偶坐标下降 ( DCD )是用于训练逻辑回归的最优化方法。我们这样训练它:
LogisticRegressionDCD model = new LogisticRegressionDCD();
model.setMaxIterations(maxIterations);
model.setC(C);
model.trainC(fold.toJsatDataset());
在该代码中,C是正则化参数,C的值越小,对应的正则化效果越强。
最后,为了输出概率,我们可以做以下事情:
double[] row = // data
DenseVector vector = new DenseVector(row);
DataPoint point = new DataPoint(vector);
CategoricalResults out = model.classify(point);
double probability = out.getProb(1);
CategoricalResults类包含大量信息,包括每个类的概率和最可能的标签。
LIBSVM 和 LIBLINEAR
接下来,我们考虑两个类似的库,LIBSVM 和 LIBLINEAR。
- LIBSVM(www.csie.ntu.edu.tw/~cjlin/libs…)是一个实现支持向量机模型的库,包括支持向量分类器
- LIBLINEAR(www.csie.ntu.edu.tw/~cjlin/libl…)是一个快速线性分类算法库,如线性 SVM 和逻辑回归
这两个库来自同一个研究小组,并且具有非常相似的接口。我们将从 LIBSVM 开始。
LIBSVM 是一个库,它实现了许多不同的 SVM 算法。它是用 C++实现的,并且有官方支持的 Java 版本。它可以在 Maven Central 上获得:
<dependency>
<groupId>tw.edu.ntu.csie</groupId>
<artifactId>libsvm</artifactId>
<version>3.17</version>
</dependency>
注意,LIBSVM 的 Java 版本不如 C++版本更新频繁。尽管如此,前面的版本是稳定的,不应该包含错误,但它可能比它的 C++版本慢。
要使用 LIBSVM 中的 SVM 模型,首先需要指定参数。为此,您创建一个svm_parameter类。在内部,您可以指定许多参数,包括:
- 内核类型(
RBF、POLY或LINEAR) - 正则化参数
C probability你可以设置为1来得到概率svm_type应设置为C_SVC;这表明模型应该是一个分类器
回想一下,SVM 模型可以有不同的核,根据我们使用的核,我们有不同的模型和不同的参数。这里,我们将考虑最常用的内核;线性(或无核)、多项式和径向基函数 ( RBF ),又称高斯核)。
首先,我们从线性内核开始。首先,我们创建一个svm_paramter对象,其中我们将内核类型设置为LINEAR,并要求它输出概率:
svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.LINEAR;
param.probability = 1;
param.C = C;
// default parameters
param.cache_size = 100;
param.eps = 1e-3;
param.p = 0.1;
param.shrinking = 1;
接下来,我们有一个多项式核。回想一下,多项式内核由以下公式指定:
它有三个附加参数,即控制内核的gamma、coef0和degree,还有C -正则化参数。我们可以这样配置POLY SVM 的svm_parameter类:
svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.POLY;
param.C = C;
param.degree = degree;
param.gamma = 1;
param.coef0 = 1;
param.probability = 1;
// plus defaults from the above
最后,高斯核(或 RBF)具有以下公式:
因此有一个参数gamma,它控制高斯曲线的宽度。我们可以像这样指定带有RBF内核的模型:
svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.RBF;
param.C = C;
param.gamma = gamma;
param.probability = 1;
// plus defaults from the above
一旦我们创建了配置对象,我们需要将数据转换成正确的格式。该库希望数据以稀疏格式表示。对于单个数据行,到所需格式的转换如下:
double[] dataRow = // single row vector
svm_node[] svmRow = new svm_node[dataRow.length];
for (int j = 0; j < dataRow.length; j++) {
svm_node node = new svm_node();
node.index = j;
node.value = dataRow[j];
svmRow[j] = node;
}
因为我们通常有一个矩阵,而不仅仅是一行,所以我们将前面的代码应用于这个矩阵的每一行:
double[][] X = ... // data
int n = X.length;
svm_node[][] nodes = new svm_node[n][];
for (int i = 0; i < n; i++) {
nodes[i] = wrapAsSvmNode(X[i]);
}
这里,wrapAsSvmNode是一个函数,它将一个向量包装成一个由svm_node对象组成的数组。
现在,我们可以将数据和标签一起放入svm_problem对象:
double[] y = ... // labels
svm_problem prob = new svm_problem();
prob.l = n;
prob.x = nodes;
prob.y = y;
最后,我们可以使用参数和问题规范来训练 SVM 模型:
svm_model model = svm.svm_train(prob, param);
一旦模型被训练,我们就可以用它来分类看不见的数据。获取概率的方法如下:
double[][] X = // test data
int n = X.length;
double[] results = new double[n];
double[] probs = new double[2];
for (int i = 0; i < n; i++) {
svm_node[] row = wrapAsSvmNode(X[i]);
svm.svm_predict_probability(model, row, probs);
results[i] = probs[1];
}
由于我们使用了param.probability = 1,我们可以使用svm.svm_predict_probability方法来预测概率。与 Smile 一样,该方法接受一个 doubles 数组,并将输出写入其中。在这个操作之后,它将包含这个数组中的概率。
最后,在训练的时候,LIBSVM 在控制台上输出很多东西。如果我们对这个输出不感兴趣,我们可以用下面的代码片段禁用它:
svm.svm_set_print_string_function(s -> {});
只需将它添加到代码的开头,就再也看不到调试信息了。
下一个库是 LIBLINEAR,它提供了非常快速和高性能的线性分类器,如具有线性核的 SVM 和逻辑回归。它可以轻松扩展到数千万甚至数亿个数据点。它的界面与 LIBSVM 非常相似,我们需要指定参数和数据,然后训练一个模型。
与 LIBSVM 不同,LIBLINEAR 没有官方的 Java 版本,但是在 liblinear.bwaldvogel.de/有一个非官方的 Java 端口。要使用它,请包括以下内容:
<dependency>
<groupId>de.bwaldvogel</groupId>
<artifactId>liblinear</artifactId>
<version>1.95</version>
</dependency>
该接口与 LIBSVM 非常相似。首先,我们定义参数:
SolverType solverType = SolverType.L1R_LR;
double C = 0.001;
double eps = 0.0001;
Parameter param = new Parameter(solverType, C, eps);
在本例中,我们指定了三个参数:
solverType:定义将要使用的模型C:这是正则化的量,C 越小,正则化越强epsilon:这是停止训练过程的容忍度;合理的默认值是0.0001
对于分类问题,以下是我们可以使用的解决方案:
- 逻辑回归 :
L1R_LR或L2R_LR - SVM :
L1R_L2LOSS_SVC或L2R_L2LOSS_SVC
这里,我们有两个模型:逻辑回归和 SVM,以及两种正规化类型,L1 和 L2。我们如何决定选择哪种模型和使用哪种正则化?根据官方常见问题解答(可以在这里找到:www.csie.ntu.edu.tw/~cjlin/libl…),我们应该:
- 与逻辑回归相比,我更喜欢 SVM,因为它训练速度更快,而且通常精度更高
- 首先尝试 L2 正则化,除非你需要一个稀疏的解决方案,在这种情况下使用 L1
接下来,我们需要准备我们的数据。如前所述,我们需要将其包装成某种特殊的格式。首先,让我们看看如何包装单个数据行:
double[] row = // data
int m = row.length;
Feature[] result = new Feature[m];
for (int i = 0; i < m; i++) {
result[i] = new FeatureNode(i + 1, row[i]);
}
注意,我们将1添加到索引中。0是偏置项,所以实际特性要从1开始。
我们可以将这段代码放入一个wrapRow函数中,然后将整个数据集包装如下:
double[][] X = // data
int n = X.length;
Feature[][] matrix = new Feature[n][];
for (int i = 0; i < n; i++) {
matrix[i] = wrapRow(X[i]);
}
现在,我们可以用数据和标签创建Problem类:
double[] y = // labels
Problem problem = new Problem();
problem.x = wrapMatrix(X);
problem.y = y;
problem.n = X[0].length + 1;
problem.l = X.length;
注意,这里我们还需要提供数据的维度,也就是特征的数量加 1。我们需要增加一个,因为它包含了偏差项。
现在我们准备训练模型:
Model model = LibLinear.train(fold, param);
当模型被训练后,我们可以用它来分类看不见的数据。在下面的例子中,我们将输出概率:
double[] dataRow = // data
Feature[] row = wrapRow(dataRow);
Linear.predictProbability(model, row, probs);
double result = probs[1];
前面的代码适用于逻辑回归模型,但不适用于 SVM,SVM 无法输出概率,因此前面的代码将为L1R_L2LOSS_SVC等求解器抛出错误。我们可以做的是获得原始输出:
double[] values = new double[1];
Feature[] row = wrapRow(dataRow);
Linear.predictValues(model, row, values);
double result = values[0];
在这种情况下,结果将不会包含概率,而是一些真实值。当该值大于零时,模型预测该类为正。
如果我们想将这个值映射到[0, 1]范围,我们可以使用sigmoid函数:
public static double[] sigmoid(double[] scores) {
double[] result = new double[scores.length];
for (int i = 0; i < result.length; i++) {
result[i] = 1 / (1 + Math.exp(-scores[i]));
}
return result;
}
最后,和 LIBSVM 一样,LIBLINEAR 也将很多东西输出到标准输出。如果您不希望看到它,可以使用以下代码将其静音:
PrintStream devNull = new PrintStream(new NullOutputStream());
Linear.setDebugOutput(devNull);
这里,我们使用 Apache IO 中的NullOutputStream,它什么也不做,所以屏幕保持干净。
想知道什么时候用 LIBSVM,什么时候用 LIBLINEAR?对于大型数据集,通常不可能使用任何内核方法。在这种情况下,您应该更喜欢 LIBLINEAR。此外,LIBLINEAR 特别适合文本处理,比如文档分类。我们将在第六章、中更详细地介绍这些案例——自然语言处理和信息检索。
Encog
到目前为止,我们已经介绍了许多模型,即逻辑回归、SVM 和 RandomForest,并且我们已经查看了实现它们的多个库。但是我们还没有涉及神经网络。在 Java 中,有一个专门处理神经网络的特殊库——Encog。它可以在 Maven Central 上获得,并且可以通过以下代码片段进行添加:
<dependency>
<groupId>org.encog</groupId>
<artifactId>encog-core</artifactId>
<version>3.3.0</version>
</dependency>
在包括库之后,第一步是指定神经网络的架构。我们可以这样做:
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, noInputNeurons));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 30));
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 1));
network.getStructure().finalizeStructure();
network.reset();
这里,我们创建一个网络,它有一个输入层,一个内层有 30 个神经元,一个输出层有 1 个神经元。在每一层中,我们使用 sigmoid 作为激活函数,并添加偏置输入(true参数)。最后,reset方法随机初始化网络中的权重。
对于输入和输出,Encog 期望二维双数组。在二进制分类的情况下,我们通常有一个一维数组,所以我们需要转换它:
double[][] X = // data
double[] y = // labels
double[][] y2d = new double[y.length][];
for (int i = 0; i < y.length; i++) {
y2d[i] = new double[] { y[i] };
}
一旦数据被转换,我们就把它包装成一个特殊的包装类:
MLDataSet dataset = new BasicMLDataSet(X, y2d);
然后,该数据集可用于训练:
MLTrain trainer = new ResilientPropagation(network, dataset);
double lambda = 0.01;
trainer.addStrategy(new RegularizationStrategy(lambda));
int noEpochs = 101;
for (int i = 0; i < noEpochs; i++) {
trainer.iteration();
}
我们不会在这里详细介绍 Encog,但我们会在第 8 章、中回到神经网络,用 DeepLearning4j 进行深度学习,在那里我们会看到一个不同的库——Deep Learning 4j。
Java 中还有很多其他的机器学习库。例如威卡、H2O、贾瓦尔等。这是不可能涵盖所有的,但你也可以尝试一下,看看你是否喜欢他们比我们已经涵盖的。
接下来,我们将看到如何评估分类模型。
估价
我们已经介绍了许多机器学习库,其中许多实现了相同的算法,如随机森林或逻辑回归。此外,每个单独的模型可以具有许多不同的参数,逻辑回归具有正则化系数,SVM 通过设置核及其参数来配置。
我们如何从这么多可能的变体中选择最佳的单一模型?
为此,我们首先定义一些评估指标,然后选择根据该指标实现最佳性能的模型。对于二进制分类,我们可以使用许多指标进行比较,最常用的指标如下:
- 准确度和误差
- 精确度、召回率和 F1
- AUC(澳大利亚)
我们使用这些指标来观察模型对新的未知数据的概括能力。因此,当数据对模型来说是新的时,对这种情况建模是很重要的。这通常是通过将数据分成几个部分来完成的。因此,我们还将涵盖以下内容:
- 结果评估
- k 倍交叉验证
- 培训、验证和测试
让我们从最直观的评估指标——准确性开始。
准确(性)
准确性是评估分类器最直接的方式:我们进行预测,查看预测的标签,然后将其与实际值进行比较。如果价值观一致,那么模型是正确的。然后,我们可以对我们所有的数据都这样做,看看正确预测的例子的比率;这正是准确性所描述的。因此,准确性告诉我们有多少例子模型预测了正确的标签。计算它是微不足道的:
int n = actual.length;
double[] proba = // predictions;
double[] prediction = Arrays.stream(proba).map(p -> p > threshold ? 1.0 : 0.0).toArray();
int correct = 0;
for (int i = 0; i < n; i++) {
if (actual[i] == prediction[i]) {
correct++;
}
}
double accuracy = 1.0 * correct / n;
准确性是最简单的评估标准,很容易向任何人解释,甚至是非技术人员。
然而,有时,准确性并不是模型性能的最佳衡量标准。接下来我们来看看它的问题是什么,用什么来代替。
精确度、召回率和 F1
在某些情况下,精度值是有欺骗性的:它们表明分类器是好的,尽管它不是。例如,假设我们有一个不平衡的数据集:只有 1%的例子是正面的,其余的(99%)是负面的。然后,一个总是预测为负的模型在 99%的情况下是正确的,因此将具有 0.99 的准确度。但是这个模型并没有用。
除了准确性,还有其他方法可以解决这个问题。精确度和召回率都在这些指标中,因为它们都着眼于模型正确识别的积极项目的比例。所以,如果我们有大量的反面例子,我们仍然可以对模型进行一些有意义的评估。
可以使用混淆矩阵来计算精度和召回率,混淆矩阵是一个总结了二元分类器性能的表:
当我们使用二元分类模型来预测某个数据项的实际值时,有四种可能的结果:
- 真正 ( TP ):实际类为正,我们预测为正
- 真负 ( TN ):实际类为负,我们预测为负
- 假阳性 ( FP ):实际类是阴性,我们却说是阳性
- 假阴性 ( FN ):实际类是阳性,我们却说是阴性
前两种情况(TP 和 TN)是正确的预测,实际值和预测值是相同的。最后两种情况(FP 和 FN)是不正确的分类,因为我们无法预测正确的标签。
现在,假设我们有一个带有已知标签的数据集,并对其运行我们的模型。然后,设TP为真正例数,TN为真反例数,以此类推。
然后我们可以使用这些值来计算精度和召回率:
- 精度是模型预测为阳性的所有项目中正确预测为阳性的项目所占的比例。就混淆矩阵而言,精度是
TP / (TP + FP)。 - 回忆是正确预测的阳性项目在实际阳性项目中所占的比例。利用来自混淆矩阵的值,回忆是
TP / (TP + FN)。 - 通常很难决定是应该优化精确度还是召回率。但是还有另一个将精确度和召回率结合成一个数字的指标,它被称为 F1 分数。
为了计算精度和召回率,我们首先需要计算混淆矩阵单元的值:
int tp = 0, tn = 0, fp = 0, fn = 0;
for (int i = 0; i < actual.length; i++) {
if (actual[i] == 1.0 && proba[i] > threshold) {
tp++;
} else if (actual[i] == 0.0 && proba[i] <= threshold) {
tn++;
} else if (actual[i] == 0.0 && proba[i] > threshold) {
fp++;
} else if (actual[i] == 1.0 && proba[i] <= threshold) {
fn++;
}
}
然后,我们可以使用这些值来计算精度和召回率:
double precision = 1.0 * tp / (tp + fp);
double recall = 1.0 * tp / (tp + fn);
最后,f1可以用下面的公式计算:
double f1 = 2 * precision * recall / (precision + recall);
当数据集不平衡时,这些指标非常有用。
ROC 和 AU ROC (AUC)
前面的度量对于产生硬输出的二进制分类器是好的;它们只告诉类是否应该被分配一个积极的标签或消极的。相反,如果我们的模型输出一些分数,使得分数的值越高,项目越有可能是正面的,那么二元分类器被称为排序分类器。
大多数模型可以输出属于某一类的概率,我们可以用它来对例子进行排序,这样积极的东西可能会排在第一位。
ROC 曲线直观地告诉我们一个分级分类器从负面例子中分离正面例子有多好。ROC 曲线的构建方式如下:
- 根据分数对观察值进行排序,然后从原点开始
- 如果观察值为正,则向上;如果观察值为负,则向右。
这样,在理想情况下,我们首先总是向上,然后总是向右,这将产生最佳的 ROC 曲线。在这种情况下,我们可以说,正例与反例的分离是完美的。如果分离不完美,但仍然 OK ,曲线将上升为正例,但有时会在错误分类发生时右转。最后,一个糟糕的分类器将不能区分正例与反例,曲线将在向上和向右之间交替。
让我们看一些例子:
图上的对角线代表基线——随机分类器将达到的性能。曲线离基线越远越好。
不幸的是,在 Java 中没有 ROC 曲线的易用实现。我们自己实现代码并不难。在这里,我们将概述如何做到这一点,你会发现在代码库一章的实现。
所以绘制 ROC 曲线的算法如下:
- 设 POS 为阳性标记的数量,NEG 为阴性标记的数量
- 按分数降序排列数据
- 从(0,0)开始
- 对于排序顺序中的每个示例,
- o 如果示例为正,则在图中上移 1 / POS,
- o 否则,在图表中向右移动 1 /负。
这是一个简化的算法,并假设分数是不同的。如果分数不明显,并且同一个分数有不同的实际标签,就需要做一些调整。
它是在RocCurve类中实现的,您可以在源代码中找到。您可以按如下方式使用它:
RocCurve.plot(actual, prediction);
调用它将创建一个类似于这个的情节:
曲线下的面积表示正例与反例之间的分离程度。如果分离度很好,那么面积会接近 1。但如果分类器不能区分正反例,曲线会绕着随机基线曲线走,面积会接近 0.5 。
曲线下的面积通常缩写为 AUC,或者有时缩写为 AU ROC,以强调该曲线是 ROC 曲线。
AUC 有一个非常好的解释——AUC 的值对应于随机选择的阳性样本得分高于随机选择的阴性样本的概率。自然地,如果这个概率很高,我们的分类器在分离正面和负面例子方面做得很好。
这使得 AUC 成为许多情况下的一种评估指标,特别是当数据集不平衡时,因为一个类别的示例比另一个类别的多得多。
幸运的是,Java 中有 AUC 的实现。例如,它在 Smile 中实现。你可以这样使用它:
double[] predicted = ... //
int[] truth = ... //
double auc = AUC.measure(truth, predicted);
现在,当我们讨论可能的评估指标时,我们需要应用它们来测试我们的模型。我们需要小心处理。如果我们对用于训练的相同数据进行评估,那么评估结果将过于乐观。接下来,我们将看到什么是正确的做法。
结果验证
当从数据中学习时,总是有过度拟合的危险。当模型开始学习数据中的噪声而不是检测有用的模式时,就会发生过度拟合。检查模型是否过度拟合总是很重要的,否则当应用于看不见的数据时,它将是无用的。
检查模型是否过拟合的典型且最实用的方法是模拟看不见的数据,也就是说,取一部分可用的标记数据,不使用它进行训练。
这种技术被称为保留,我们保留一部分数据,仅用于评估。
我们还在分割前打乱原始数据集。在许多情况下,我们会做一个简化的假设,即数据的顺序并不重要,也就是说,一个观察值对另一个观察值没有影响。在这种情况下,在拆分之前打乱数据将会消除项目顺序可能产生的影响。另一方面,如果数据是时间序列数据,那么打乱它不是一个好主意,因为观察值之间存在一些相关性。
那么,让我们实现保持分离。我们假设我们拥有的数据已经用X表示了——一个具有特征的双精度二维数组和y——一个标签一维数组。
首先,我们创建一个助手类来保存数据:
public class Dataset {
private final double[][] X;
private final double[] y;
// constructor and getters are omitted
}
分割我们的数据集应该产生两个数据集,所以我们也为其创建一个类:
public class Split {
private final Dataset train;
private final Dataset test;
// constructor and getters are omitted
}
现在,假设我们想把数据分成两部分:训练和测试。我们还想指定训练集的大小,我们将使用一个testRatio参数:应该进入测试集的项目的百分比。
我们做的第一件事是生成一个带索引的数组,然后根据testRatio对其进行拆分:
int[] indexes = IntStream.range(0, dataset.length()).toArray();
int trainSize = (int) (indexes.length * (1 - testRatio));
int[] trainIndex = Arrays.copyOfRange(indexes, 0, trainSize);
int[] testIndex = Arrays.copyOfRange(indexes, trainSize, indexes.length);
如果需要,我们也可以打乱索引:
Random rnd = new Random(seed);
for (int i = indexes.length - 1; i > 0; i--) {
int index = rnd.nextInt(i + 1);
int tmp = indexes[index];
indexes[index] = indexes[i];
indexes[i] = tmp;
}
然后,我们可以为训练集选择实例,如下所示:
int trainSize = trainIndex.length;
double[][] trainX = new double[trainSize][];
double[] trainY = new double[trainSize];
for (int i = 0; i < trainSize; i++) {
int idx = trainIndex[i];
trainX[i] = X[idx];
trainY[i] = y[idx];
}
最后,将它包装到我们的Dataset类中:
Dataset train = new Dataset(trainX, trainY);
如果我们对测试集重复同样的操作,我们可以将训练集和测试集放入一个Split对象中:
Split split = new Split(train, test);
现在我们可以使用 train fold 进行训练,使用 test fold 测试模型。
如果我们把前面所有的代码放到Dataset类的一个函数中,例如trainTestSplit,我们可以如下使用它:
Split split = dataset.trainTestSplit(0.2);
Dataset train = split.getTrain();
// train the model using train.getX() and train.getY()
Dataset test = split.getTest();
// test the model using test.getX(); test.getY();
这里,我们在train数据集上训练一个模型,然后在test集上计算评估度量。
k 倍交叉验证
只提供一部分数据并不总是最好的选择。相反,我们可以做的是将它分成 K 个部分,然后只对第 1/K 个数据测试模型。
这叫做 k 倍交叉验证;它不仅给出了性能估计,而且给出了误差的可能传播。通常,我们感兴趣的是能提供良好和稳定性能的模型。K-fold 交叉验证有助于我们选择这样的模型。
接下来,我们准备用于 k 倍交叉验证的数据,如下所示:
- 首先,将数据分成 K 个部分
- 然后,对于这些零件中的每一个:
- 取一部分作为验证集
- 将剩余的 K-1 零件作为训练集
如果我们把它翻译成 Java,第一步会是这样的:
int[] indexes = IntStream.range(0, dataset.length()).toArray();
int[][] foldIndexes = new int[k][];
int step = indexes.length / k;
int beginIndex = 0;
for (int i = 0; i < k - 1; i++) {
foldIndexes[i] = Arrays.copyOfRange(indexes, beginIndex, beginIndex + step);
beginIndex = beginIndex + step;
}
foldIndexes[k - 1] = Arrays.copyOfRange(indexes, beginIndex, indexes.length);
这为每个 K 折叠创建了一个索引数组。我们也可以像前面一样打乱索引数组。
现在,我们可以从每个折叠创建拆分:
List<Split> result = new ArrayList<>();
for (int i = 0; i < k; i++) {
int[] testIdx = folds[i];
int[] trainIdx = combineTrainFolds(folds, indexes.length, i);
result.add(Split.fromIndexes(dataset, trainIdx, testIdx));
}
在前面的代码中,我们有两个额外的方法:
combineTrainFolds:这个函数接收带有索引的 K-1 个数组,并将它们组合成一个Split.fromIndexes:这将创建一个训练和测试索引的分割。
当我们创建一个简单的保持测试集时,我们已经讨论了第二个功能。
第一个函数combineTrainFolds是这样实现的:
private static int[] combineTrainFolds(int[][] folds, int totalSize, int excludeIndex) {
int size = totalSize - folds[excludeIndex].length;
int result[] = new int[size];
int start = 0;
for (int i = 0; i < folds.length; i++) {
if (i == excludeIndex) {
continue;
}
int[] fold = folds[i];
System.arraycopy(fold, 0, result, start, fold.length);
start = start + fold.length;
}
return result;
}
同样,我们可以将前面的代码放入Dataset类的函数中,并像下面这样调用它:
List<Split> folds = train.kfold(3);
现在,当我们有了一个Split对象的列表时,我们可以创建一个特殊的函数来执行交叉验证:
public static DescriptiveStatistics crossValidate(List<Split> folds,
Function<Dataset, Model> trainer) {
double[] aucs = folds.parallelStream().mapToDouble(fold -> {
Dataset foldTrain = fold.getTrain();
Dataset foldValidation = fold.getTest();
Model model = trainer.apply(foldTrain);
return auc(model, foldValidation);
}).toArray();
return new DescriptiveStatistics(aucs);
}
这个函数的作用是,获取一个折叠列表和一个回调函数,并创建一个模型。在模型被训练之后,我们计算它的 AUC。
此外,我们利用 Java 的并行循环能力,同时在每个折叠上训练模型。
最后,我们将在每个折叠上计算的 AUC 放入一个DescriptiveStatistics对象中,该对象稍后可用于返回 AUC 的平均值和标准差。您可能还记得,DescriptiveStatistics类来自 Apache Commons 数学库。
让我们考虑一个例子。假设我们想要使用来自LIBLINEAR的逻辑回归,并为正则化参数C选择最佳值。我们可以这样使用前面的函数:
double[] Cs = { 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0 };
for (double C : Cs) {
DescriptiveStatistics summary = crossValidate(folds, fold -> {
Parameter param = new Parameter(SolverType.L1R_LR, C, 0.0001);
return LibLinear.train(fold, param);
});
double mean = summary.getMean();
double std = summary.getStandardDeviation();
System.out.printf("L1 logreg C=%7.3f, auc=%.4f ± %.4f%n", C, mean, std);
}
这里,LibLinear.train是一个助手方法,它接受一个Dataset对象和一个Parameter对象,然后训练一个 LIBLINEAR 模型。这将打印所有提供的C值的 AUC,因此您可以看到哪一个是最好的,并选择具有最高平均 AUC 的一个。
培训、验证和测试
当进行交叉验证时,仍然存在过度拟合的危险。由于我们在同一个验证集上尝试了许多不同的实验,我们可能会意外地选择在验证集上表现良好的模型——但它可能稍后无法推广到看不见的数据。
这个问题的解决方案是在最开始的时候拿出一个测试集,在我们选择出我们认为最好的模型之前,不要碰它。我们只用它来评估最终的模型。
那么,我们如何选择最佳模型呢?我们能做的就是对剩下的训练数据做交叉验证。它可以被保持或 k-fold 交叉验证。一般来说,您应该更喜欢进行 k-fold 交叉验证,因为它还可以提供性能分布,您也可以在模型选择中使用它。
下图说明了该过程:
根据图表,典型的数据科学工作流应该如下所示:
- 0 :选择一些指标进行验证,例如,准确度或 AUC
- 1 :将所有数据分成训练集和测试集
- 2 :进一步拆分训练数据,保留一个验证数据集,或者拆分成 k 个折叠
- 3 :使用验证数据进行选型和参数优化
- 4 :根据验证集选择最佳模型,并对照坚持测试集进行评估
避免过于频繁地查看测试集是很重要的,它应该很少使用,并且只用于最终评估,以确保所选模型不会过度拟合。如果认证方案设置正确,认证分数应与最终测试分数一致。如果发生这种情况,我们可以肯定模型不会过度拟合,并且能够推广到看不见的数据。
使用我们之前创建的类和代码,它可以转换成下面的 Java 代码:
Dataset data = new Dataset(X, y);
Dataset train = split.getTrain();
List<Split> folds = train.kfold(3);
// now use crossValidate(folds, ...) to select the best model
Dataset test = split.getTest();
// do final evaluation of the best model on test
有了这些信息,我们准备做一个关于二元分类的项目。
案例研究-页面预测
现在我们将继续我们运行的例子,搜索引擎。这里我们想做的是尝试预测一个 URL 是否来自搜索引擎结果的第一页。所以,是时候使用我们在这一章中已经介绍过的材料了。
在第二章、数据处理工具箱中,我们创建了以下对象来存储关于页面的信息:
public class RankedPage {
private String url;
private int position;
private int page;
private int titleLength;
private int bodyContentLength;
private boolean queryInTitle;
private int numberOfHeaders;
private int numberOfLinks;
}
首先,我们可以从向该对象添加一些方法开始,如下所示:
isHttps:这将告诉我们该 URL 是否是 HTTPS,是否可以用url.startsWith("https://")实现isComDomain:这应该告诉我们 URL 是否属于 COM 域,以及我们是否可以用url.contains(".com")来实现它isOrgDomain、isNetDomain:与上一个相同,但分别针对 ORG 和 NETnumberOfSlashes:这是 URL 中斜杠字符的个数,可以用番石榴的CharMatcher:CharMatcher.*is*('/').countIn(url)实现
这些模型描述了我们得到的每个 URL,所以我们称之为特征方法,我们可以在我们的机器学习模型中使用这些方法的结果。
如前所述,我们有一个读取 JSON 数据并从中创建 Joinery 数据帧的方法:
List<RankedPage> pages = RankedPageData.readRankedPages();
DataFrame<Object> dataframe = BeanToJoinery.convert(pages, RankedPage.class);
有了数据后,第一步是提取目标变量的值:
List<Object> page = dataframe.col("page");
double[] target = page.stream()
.mapToInt(o -> (int) o)
.mapToDouble(p -> (p == 0) ? 1.0 : 0.0)
.toArray();
为了得到特征矩阵X,我们可以使用 Joinery 为我们创建一个二维数组。首先,我们需要删除一些变量,即目标变量、URL 以及位置,因为位置显然与页面相关。我们可以这样做:
dataframe = dataframe.drop("page", "url", "position");
double[][] X = dataframe.toModelMatrix(0.0);
接下来,我们可以使用我们在本章中创建的Dataset类,并将它分成训练和测试部分:
Dataset dataset = new Dataset(X, target);
Split split = dataset.trainTestSplit(0.2);
Dataset train = split.getTrain();
Dataset test = split.getTest();
此外,对于某些算法,对要素进行标准化很有帮助,这样它们的平均值和单位标准差为零。这样做的原因是为了帮助优化算法更快地收敛。
为此,我们计算矩阵中每一列的平均值和标准偏差,然后从每个值中减去平均值,再除以标准偏差。为了简洁起见,我们在这里省略了这个函数的代码,但是您可以在代码仓库一章中找到它。
下面的代码可以做到这一点:
preprocessor = StandardizationPreprocessor.train(train);
train = preprocessor.transform(train);
test = preprocessor.transform(test);
现在我们准备开始训练不同的模型。让我们先从 Smile 开始尝试逻辑回归实现。我们将使用 k-fold 交叉验证来选择其正则化参数λ的最佳值。
List<Fold> folds = train.kfold(3);
double[] lambdas = { 0, 0.5, 1.0, 5.0, 10.0, 100.0, 1000.0 };
for (double lambda : lambdas) {
DescriptiveStatistics summary = Smile.crossValidate(folds, fold -> {
return new LogisticRegression.Trainer()
.setRegularizationFactor(lambda)
.train(fold.getX(), fold.getYAsInt());
});
double mean = summary.getMean();
double std = summary.getStandardDeviation();
System.out.printf("logreg, λ=%8.3f, auc=%.4f ± %.4f%n", lambda, mean, std);
}
注意这里的Dataset类有一个新方法getYAsInt,它简单地返回表示为整数数组的目标变量。当我们运行它时,它会产生以下输出:
logreg, λ= 0.000, auc=0.5823 ± 0.0041
logreg, λ= 0.500, auc=0.5822 ± 0.0040
logreg, λ= 1.000, auc=0.5820 ± 0.0037
logreg, λ= 5.000, auc=0.5820 ± 0.0030
logreg, λ= 10.000, auc=0.5823 ± 0.0027
logreg, λ= 100.000, auc=0.5839 ± 0.0009
logreg, λ=1000.000, auc=0.5859 ± 0.0036
它显示了λ的值,我们得到的该值的 AUC,以及跨不同折叠的 AUC 的标准偏差。
我们看到我们得到的 AUC 相当低。这不应该是一个惊喜:仅使用我们现在拥有的信息显然不足以完全逆向工程搜索引擎的排名算法。在接下来的章节中,我们将学习如何从页面中提取更多的信息,这些技术将有助于大大增加 AUC。
我们可以注意到的另一件事是,不同λ值的 AUC 非常相似,但其中一个具有最低的标准偏差。在这种情况下,我们应该总是选择方差最小的模型。
我们还可以尝试更复杂的分类器,如 RandomForest:
DescriptiveStatistics rf = Smile.crossValidate(folds, fold -> {
return new RandomForest.Trainer()
.setNumTrees(100)
.setNodeSize(4)
.setSamplingRates(0.7)
.setSplitRule(SplitRule.ENTROPY)
.setNumRandomFeatures(3)
.train(fold.getX(), fold.getYAsInt());
});
System.out.printf("random forest auc=%.4f ± %.4f%n", rf.getMean(), rf.getStandardDeviation());
这将创建以下输出:
random forest auc=0.6093 ± 0.0209
这个分类器平均比逻辑回归分类器好 2%,但是我们也可以注意到标准偏差相当高。因为它高得多,我们可以怀疑,在测试数据上,该模型的表现可能比逻辑回归模型差得多。
接下来,我们也可以尝试训练其他模型。但是,让我们假设我们这样做了,最后我们得出结论,使用lambda=100的逻辑回归给出了最佳性能。然后,我们可以对整个训练数据集进行再训练,然后使用测试集进行最终评估:
LogisticRegression logregFinal = new LogisticRegression.Trainer()
.setRegularizationFactor(100.0)
.train(train.getX(), train.getYAsInt());
double auc = Smile.auc(logregFinal, test);
System.out.printf("final logreg auc=%.4f%n", auc);
该代码产生以下输出:
final logreg auc=0.5807
因此,事实上,我们可以看到,所选模型产生的 AUC 与我们交叉验证中的相同。这是一个很好的迹象,表明该模型可以很好地概括,不会过度填充。
出于好奇,我们还可以检查 RandomForest 模型在训练集上的表现。由于它具有较高的方差,因此它的表现可能比逻辑回归差,但也可能好得多。让我们在整个列车上重新训练它:
RandomForest rfFinal = new RandomForest.Trainer()
.setNumTrees(100)
.setNodeSize(4)
.setSamplingRates(0.7)
.setSplitRule(SplitRule.ENTROPY)
.setNumRandomFeatures(3)
.train(train.getX(), train.getYAsInt());
double auc = Smile.auc(rfFinal, test);
System.out.printf("final rf auc=%.4f%n", finalAuc);
它打印以下内容:
final rf auc=0.5778
因此,事实上,模型的高方差导致测试分数低于交叉验证分数。这不是一个好的迹象,这样的模型不应该是首选。
因此,对于这样的数据集,表现最好的模型是逻辑回归。
如果你想知道如何使用其他机器学习库来解决这个问题,可以查看本章的代码库。在那里,我们为 JSAT、JavaML、LIBSVM、LIBLINEAR 和 Encog 创建了一些例子。
至此,我们结束了本章关于分类的部分,接下来我们将研究另一个被称为回归的监督学习问题。
回归
在机器学习中,回归问题处理标签信息连续的情况。这可以是预测明天的气温、股票价格、一个人的工资或者电子商务网站上一件商品的评级。
有许多模型可以解决衰退问题:
- 普通最小二乘法 ( OLS )就是通常的线性回归
- 岭回归和套索是 OLS 的正则化变体
- 基于树的模型,如 RandomForest
- 神经网络
处理回归问题与处理分类问题非常相似,总体框架保持不变:
- 首先,您选择一个评估指标
- 然后,您将数据分为训练和测试
- 您在训练中训练模型,使用交叉验证调整参数,并使用保留的测试集进行最终验证。
用于回归的机器学习库
我们已经讨论了许多可以处理分类问题的机器学习库。通常,这些库也有回归模型。让我们简单回顾一下。
微笑
Smile 是一个通用的机器学习库,所以它也有回归模型。你可以看看模特名单,这里:github.com/haifengl/sm…。
例如,这是创建简单线性回归的方法:
OLS ols = new OLS(data.getX(), data.getY());
对于正则化回归,可以使用脊或套索:
double lambda = 0.01;
RidgeRegression ridge = new RidgeRegression(data.getX(), data.getY(), lambda);
LASSO lasso = new LASSO(data.getX(), data.getY(), lambda);
使用 RandomForest 与分类情况非常相似:
int nbtrees = 100;
RandomForest rf = new RandomForest.Trainer(nbtrees)
.setNumRandomFeatures(15)
.setMaxNodes(128)
.setNodeSize(10)
.setSamplingRates(0.6)
.train(data.getX(), data.getY());
预测也与分类情况相同。我们需要做的只是使用predict方法:
double result = model.predict(row);
JSAT
JSAT 也是一个通用库,包含许多解决回归问题的实现。
与分类一样,它需要一个用于数据的包装类和一个用于回归的特殊包装:
double[][] X = ... //
double[] y = ... //
List<DataPointPair<Double>> data = new ArrayList<>(X.length);
for (int i = 0; i < X.length; i++) {
DataPoint row = new DataPoint(new DenseVector(X[i]));
data.add(new DataPointPair<Double>(row, y[i]));
}
RegressionDataSet dataset = new RegressionDataSet(data);
一旦数据集被包装在正确的类中,我们就可以像这样训练模型:
MultipleLinearRegression linreg = new MultipleLinearRegression();
linreg.train(dataset);;
前面的代码训练通常的 OLS 线性回归。
与 Smile 不同,当矩阵病态时,OLS 不会产生稳定的解,也就是说,它有一些线性相关的解。在这种情况下,使用正则化模型。
可以使用以下代码来训练正则化线性回归:
RidgeRegression ridge = new RidgeRegression();
ridge.setLambda(lambda);
ridge.train(dataset);
然后,为了预测,我们还需要做一些转换:
double[] row = .. . //
DenseVector vector = new DenseVector(row);
DataPoint point = new DataPoint(vector);
double result = model.regress(point);
其他图书馆
我们之前提到的其他库也有解决回归问题的模型。
例如,在 LIBSVM 中,可以通过将svm_type参数设置为EPSILON_SVR或NU_SVR来进行回归,而代码的其余部分几乎与分类情况相同。同样,在 LIBLINEAR 中,回归问题通过选择L2R_L2LOSS_SVR或L2R_L2LOSS_SVR_DUAL模型来解决。
也可以用神经网络解决回归问题,例如在 Encog 中。您唯一需要更改的是损失函数:您应该使用回归损失函数,比如均方差,而不是最小化分类损失函数(比如logloss)。
因为大部分代码都非常相似,所以没有必要详细介绍。一如既往,我们在章节代码库中准备了一些代码示例,请随意查看。
估价
与分类一样,我们也需要评估模型的结果。有一些指标有助于做到这一点,并选择最佳模型。先来过两个最流行的:均方误差 ( MSE )和平均绝对误差 ( MAE )。
均方误差(mean square error)
均方误差 ( MSE )是实际值和预测值的平方差之和。用 Java 计算它很容易:
double[] actual, predicted;
int n = actual.length;
double sum = 0.0;
for (int i = 0; i < n; i++) {
diff = actual[i] - predicted[i];
sum = sum + diff * diff;
}
double mse = sum / n;
通常,MSE 的值很难解释,这就是为什么我们经常取 MSE 的平方根;这叫做均方根误差 ( RMSE )。它更容易解释,因为它与目标变量使用相同的单位。
double rmse = Math.sqrt(mse);
平均绝对误差
平均绝对误差 ( MAE ),是评估性能的替代指标。它不取误差的平方,而只取实际值和预测值之差的绝对值。我们可以这样计算:
double sum = 0.0;
for (int i = 0; i < n; i++) {
sum = sum + Math.abs(actual[i] - predicted[i]);
}
double mae = sum / n;
有时我们会在数据中发现异常值——非常不规则的值。如果我们有很多离群值,我们应该选择 MAE 而不是 RMSE,因为它对他们更稳健。如果我们没有很多离群值,那么 RMSE 应该是首选。
还有其他指标,如 MAPE 或 RMSE,但它们使用频率较低,因此我们不会涉及它们。
虽然我们只是简单地浏览了一下解决回归问题的库,但是有了从解决分类问题的概述中获得的基础,做一个回归项目就足够了。
案例研究-硬件性能
在这个项目中,我们将尝试预测在不同的计算机上将两个矩阵相乘需要多少时间。
这个项目的数据集最初来自西德涅夫和格尔格尔(2014)的论文自动选择最快的算法实现,并在 Mail.RU 组织的一次机器学习比赛上提供。您可以在 mlbootcamp.ru/championshi…](mlbootcamp.ru/championshi…)
内容是俄语的,所以如果你不会说俄语,最好使用有翻译支持的浏览器。
你会找到数据集的副本以及本章的代码。
该数据集包含以下数据:
m、k、n表示矩阵的维度,m*k为矩阵A的维度,k*n为矩阵B的维度- 硬件特征,如 CPU 速度、内核数量、是否启用超高速缓存以及 CPU 类型
- 操作系统
这个问题的解决方案对研究非常有用,当选择硬件来运行实验时。那样的话。您可以使用该模型来选择应该产生最佳性能的构建。
因此,目标是在给定大小和环境特征的情况下,预测两个矩阵相乘需要多少秒。虽然本文使用 MAPE 作为评估指标,但我们将使用 RMSE,因为它更易于实施和解释。
首先,我们需要读取数据。有两个文件,一个包含特征,一个包含标签。我们先来读一下目标:
DataFrame<Object> targetDf = DataFrame.readCsv("data/performance/y_train.csv");
List<Double> targetList = targetDf.cast(Double.class).col("time");
double[] target = Doubles.toArray(targetList);
接下来,我们来读一下特写:
DataFrame<Object> dataframe = DataFrame.readCsv("data/performance/x_train.csv");
如果我们查看数据,我们会注意到有时缺失的值被编码为一个字符串None。我们需要把它转换成真正的 Java null。为此,我们可以定义一个特殊的函数:
private static List<Object> noneToNull(List<Object> memfreq) {
return memfreq.stream()
.map(s -> isNone(s) ? null : Double.parseDouble(s.toString()))
.collect(Collectors.toList());
}
现在,使用它来处理原始列,然后删除它们,并添加转换后的列:
List<Object> memfreq = noneToNull(dataframe.col("memFreq"));
List<Object> memtRFC = noneToNull(dataframe.col("memtRFC"));
dataframe = dataframe.drop("memFreq", "memtRFC");
dataframe.add("memFreq", memfreq);
dataframe.add("memtRFC", memtRFC);
数据集中有一些分类变量。我们可以看看它们。首先,让我们创建一个数据帧,它包含原始帧的类型:
List<Object> types = dataframe.types().stream()
.map(c -> c.getSimpleName())
.collect(Collectors.toList());
List<Object> columns = new ArrayList<>(dataframe.columns());
DataFrame<Object> typesDf = new DataFrame<>();
typesDf.add("column", columns);
typesDf.add("type", types);
因为我们只对分类值感兴趣,所以我们需要选择类型为String的特性:
DataFrame<Object> stringTypes = typesDf.select(p -> p.get(1).equals("String"));
分类变量在机器学习问题中经常使用的方式被称为虚拟编码,或一种热编码。在这种编码方案中:
- 只要有可能的值,我们就创建尽可能多的列
- 对于每个观察值,我们为该列加上数字
1,它对应于分类变量的值,其余的列得到0
细木工可以为我们自动完成这种转换:
double[][] X = dataframe.toModelMatrix(0.0);
前面的代码将对所有分类变量应用一个热编码方案。
然而,对于我们现有的数据,分类变量的一些值只出现一次或几次。通常,我们对这种不常出现的值不感兴趣,所以我们可以用一些人工值(如OTHER)来替换它们。
在细木工领域,我们是这样做的:
- 从
DataFrame中删除所有分类列 - 对于每一列,我们计算这些值出现的次数,并用
OTHER替换不频繁
让我们把它翻译成 Java 代码。这样我们就得到分类变量:
Object[] columns = stringTypes.col("column").toArray();
DataFrame<Object> categorical = dataframe.retain(columns);
dataframe = dataframe.drop(stringTypes.col("column").toArray());
为了计数,我们可以使用来自番石榴的Multiset集合。然后,我们用OTHER替换不常用的,并将结果放回数据帧:
for (Object column : categorical.columns()) {
List<Object> data = categorical.col(column);
Multiset<Object> counts = HashMultiset.create(data);
List<Object> cleaned = data.stream()
.map(o -> counts.count(o) >= 50 ? o : "OTHER")
.collect(Collectors.toList());
dataframe.add(column, cleaned);
}
在此处理之后,我们可以将数据帧转换成矩阵,并将其放入我们的Dataset对象:
double[][] X = dataframe.toModelMatrix(0.0);
Dataset dataset = new Dataset(X, target);
现在我们准备开始训练模型。同样,我们将使用 Smile 来实现机器学习算法,其他库的代码可在章节代码库中找到。
我们已经决定使用 RMSE 作为评估指标。现在我们需要建立交叉验证方案,并为最终评估提供数据:
Split trainTestSplit = dataset.shuffleSplit(0.3);
Dataset train = trainTestSplit.getTrain();
Dataset test = trainTestSplit.getTest();
List<Split> folds = train.shuffleKFold(3);
我们可以重用我们为分类情况编写的函数,并稍微修改它以适应回归情况:
public static DescriptiveStatistics crossValidate(List<Split> folds,
Function<Dataset, Regression<double[]>> trainer) {
double[] aucs = folds.parallelStream().mapToDouble(fold -> {
Dataset train = fold.getTrain();
Dataset validation = fold.getTest();
Regression<double[]> model = trainer.apply(train);
return rmse(model, validation);
}).toArray();
return new DescriptiveStatistics(aucs);
}
在前面的代码中,我们首先训练一个回归模型,然后在验证数据集上评估它的 RMSE。
在开始建模之前,让我们先来看一个简单的基线解决方案。在回归的情况下,总是预测平均值可以是这样的基线:
private static Regression<double[]> mean(Dataset data) {
double meanTarget = Arrays.stream(data.getY()).average().getAsDouble();
return x -> meanTarget;
}
让我们将它用作基线计算的交叉验证函数:
DescriptiveStatistics baseline = crossValidate(folds, data -> mean(data));
System.out.printf("baseline: rmse=%.4f ± %.4f%n", baseline.getMean(), baseline.getStandardDeviation());
它将以下内容打印到控制台:
baseline: rmse=25.1487 ± 4.3445
我们的基线解平均误差为 25 秒,误差范围为 4.3 秒。
现在我们可以尝试训练一个简单的 OLS 回归:
DescriptiveStatistics ols = crossValidate(folds, data -> {
return new OLS(data.getX(), data.getY());
});
System.out.printf("ols: rmse=%.4f ± %.4f%n", ols.getMean(), ols.getStandardDeviation());
我们应该注意到,Smile 给了我们一个警告,即矩阵不是满秩的,它将使用奇异值分解 ( SVD )来解决 OLS 问题。我们可以忽略它,或者明确地告诉它使用 SVD:
new OLS(data.getX(), data.getY(), true);
在任一情况下,它都会将以下内容打印到控制台:
ols: rmse=15.8679 ± 3.4587
当我们使用正则化模型时,我们通常不担心相关列。让我们用不同的lambda值来尝试套索:
double[] lambdas = { 0.1, 1, 10, 100, 1000, 5000, 10000, 20000 };
for (double lambda : lambdas) {
DescriptiveStatistics summary = crossValidate(folds, data -> {
return new LASSO(data.getX(), data.getY(), lambda);
});
double mean = summary.getMean();
double std = summary.getStandardDeviation();
System.out.printf("lasso λ=%9.1f, rmse=%.4f ± %.4f%n", lambda, mean, std);
}
它产生以下输出:
lasso λ= 0.1, rmse=15.8679 ± 3.4587
lasso λ= 1.0, rmse=15.8678 ± 3.4588
lasso λ= 10.0, rmse=15.8650 ± 3.4615
lasso λ= 100.0, rmse=15.8533 ± 3.4794
lasso λ= 1000.0, rmse=15.8650 ± 3.5905
lasso λ= 5000.0, rmse=16.1321 ± 3.9813
lasso λ= 10000.0, rmse=16.6793 ± 4.3830
lasso λ= 20000.0, rmse=18.6088 ± 4.9315
请注意,Smile 版本 1.1.0 中的 LASSO 实现对该数据集有问题,因为存在线性相关的列。为了避免这种情况,您应该使用 1.2.0 版本,在编写本文时,Maven Central 还没有提供该版本,如果您想使用它,您需要自己构建它。我们已经讨论过如何做到这一点。
我们也可以尝试 RidgeRegression,但它的性能与 OLS 和拉索非常相似,所以我们在这里将省略它。
看起来 OLS 的结果与套索没有太大的不同,所以我们选择它作为最终模型并使用它,因为它是最简单的模型:
OLS ols = new OLS(train.getX(), train.getY(), true);
double testRmse = rmse(lasso, test);
System.out.printf("final rmse=%.4f%n", testRmse);
这为我们提供了以下输出:
final rmse=15.0722
因此,所选模型的性能与我们的交叉验证一致,这意味着该模型能够很好地推广到未知数据。
摘要
在这一章中,我们谈到了监督机器学习和两个常见的监督问题:分类和回归。我们还介绍了常用算法库,实现了它们,并学习了如何评估这些算法的性能。
还有另一类不需要标签信息的机器学习算法;这些方法被称为无监督学习——在下一章,我们将会谈到它们。