Spark-下一代机器学习教程-二-

187 阅读35分钟

Spark 下一代机器学习教程(二)

原文:Next-Generation Machine Learning with Spark

协议:CC BY-NC-SA 4.0

三、监督学习

最可靠的知识是你自己构建的。

—朱迪亚珍珠 i

监督学习是一种使用训练数据集进行预测的机器学习任务。监督学习可以分为分类或回归。回归用于预测价格、温度或距离等连续值,而分类用于预测是或否、垃圾邮件或非垃圾邮件、恶性或良性等类别。

分类

分类可能是最常见的监督机器学习任务。您很可能已经遇到过在没有意识到的情况下利用分类的应用程序。常见的使用案例包括医疗诊断、定向营销、垃圾邮件检测、信用风险预测和情感分析等。有三种类型的分类任务。

二元分类

如果只有两个类别,则任务是二元或二项式分类。例如,当使用二进制分类算法进行垃圾邮件检测时,输出变量可以有两个类别:垃圾邮件或非垃圾邮件。为了检测癌症,分类可以是恶性的或良性的。对于有针对性的营销,预测某人购买诸如牛奶等物品的可能性,分类可以简单地是是或否。

多类分类

多类或多项分类任务有三个或更多类别。例如,要预测天气状况,您可能有五个类别:下雨、多云、晴天、下雪和刮风。为了扩展我们的目标营销示例,多类别分类可用于预测客户是否更有可能购买全脂牛奶、低脂牛奶、低脂牛奶或脱脂牛奶。

多标签分类

在多标签分类中,可以为每个观察值分配多个类别。相比之下,在多类别分类中,只能将一个类别分配给一个观察。使用我们的目标营销示例,多标签分类不仅用于预测客户是否更有可能购买牛奶,还用于预测其他商品,如饼干、黄油、热狗或面包。

Spark MLlib 分类算法

Spark MLlib 包括几个分类算法。我将讨论最流行的算法,并提供基于我们在第二章中所介绍的易于理解的代码示例。在本章的后面,我将讨论更高级的下一代算法,比如 XGBoost 和 LightGBM。

逻辑回归

逻辑回归是预测概率的线性分类器。它使用逻辑(sigmoid)函数将其输出转换为可以映射到两个(二元)类的概率值。通过多项式逻辑回归(softmax)支持多类分类。 ii 我们将在本章后面的一个例子中使用逻辑回归。

支持向量机

支持向量机是一种流行的算法,它通过寻找使两个类之间的间隔最大化的最佳超平面来工作,通过尽可能宽的间隙将数据点分成单独的类。最接近分类边界的数据点称为支持向量(图 3-1 )。

img/488426_1_En_3_Fig1_HTML.png

图 3-1

寻找最大化两个类之间的间隔的最优超平面【iii】

奈伊夫拜厄斯

朴素贝叶斯是一种基于贝叶斯定理的简单多类线性分类算法。朴素贝叶斯之所以得名,是因为它天真地假设数据集中的要素是独立的,忽略了要素之间任何可能的相关性。我们将在本章后面的情感分析示例中使用朴素贝叶斯。

多层感知器

多层感知器是一个前馈人工网络,由几个完全连接的节点层组成。输入图层中的节点对应于输入数据集。中间层中的节点使用逻辑(sigmoid)函数,而最终输出层中的节点使用 softmax 函数来支持多类分类。输出层中的节点数量必须与类的数量相匹配。 iv 我在第七章讨论多人感知机。

决策树

决策树通过学习从输入变量推断出的决策规则来预测输出变量的值。

从视觉上看,决策树看起来就像一棵倒置的树,根节点在顶部。每个内部节点都代表对一个属性的测试。叶节点代表一个类标签,而单个分支代表一个测试的结果。图 3-2 显示了一个预测信用风险的决策树。

img/488426_1_En_3_Fig2_HTML.png

图 3-2

预测信用风险的决策树

决策树执行特征空间的递归二元分裂。为了最大化信息增益,从一组可能的分裂中选择产生最大杂质减少的分裂。通过从父节点的杂质中减去子节点杂质的加权和来计算信息增益。子节点的杂质越低,信息增益越大。分裂继续进行,直到达到最大树深度(由 maxDepth 参数设置),不再获得大于 minInfoGain 的信息增益,或者 minInstancesPerNode 等于每个子节点产生的训练实例。

有两个用于分类的杂质度量(基尼杂质和熵)和一个用于回归的杂质度量(方差)。对于分类,Spark MLlib 中杂质的默认度量是基尼杂质。基尼系数是一个量化节点纯度的指标。如果基尼系数等于零(节点是纯的),则在一个节点内存在单个数据类别。如果基尼系数大于零,则意味着该节点包含属于不同类别的数据。

决策树很容易解释。与像逻辑回归这样的线性模型相比,决策树不需要特征缩放。它能够处理缺失的特征,并处理连续和分类特征。 v 独热编码分类特征 vi 在使用决策树和基于树的集成时不是必需的,事实上是不鼓励的。独热编码创建了不平衡的树,并且要求树生长得非常深以实现良好的预测性能。对于高基数分类特征来说尤其如此。

不利的一面是,决策树对数据中的噪声很敏感,有过度拟合的倾向。由于这种限制,决策树本身很少在现实生产环境中使用。如今,决策树是更强大的集成算法的基础模型,如随机森林和梯度提升树。

随机森林

随机森林是一种集成算法,它使用一组决策树进行分类和回归。它使用一种叫做 bagging (或 bootstrap aggregation)的方法来减少方差,同时保持低偏差。Bagging 从训练数据的子集训练单独的树。除了装袋,兰登森林还采用了另一种叫做的方法装袋。与 bagging(使用观测值的子集)相反,特征 bagging 使用特征(列)的子集。特征装袋旨在减少决策树之间的相关性。如果没有特征打包,单个树将会非常相似,尤其是在只有几个主要特征的情况下。

对于分类,单个树的输出或模式的多数投票成为模型的最终预测。对于回归,单棵树输出的平均值成为最终输出(图 3-3 )。Spark 并行训练几棵树,因为每棵树都是在随机森林中独立训练的。我将在本章后面更详细地讨论随机森林。

img/488426_1_En_3_Fig3_HTML.png

图 3-3

用于分类的随机森林

梯度提升树

梯度推进树(GBT)是另一种类似于随机森林的基于树的集成算法。gbt 使用一种称为 boosting to 的技术从弱学习者(浅树)中创建强学习者。GBTs 按顺序训练一组决策树 vii ,每一棵后继的树减少前一棵树的误差。这是通过使用前一个模型的残差来拟合下一个模型来完成的。 viii 该残差校正过程 ix 被执行设定的迭代次数,迭代次数由交叉验证确定,直到残差被完全最小化。

img/488426_1_En_3_Fig4_HTML.png

图 3-4

GBTs 中的决策树集成

图 3-4 显示了决策树集成在 GBTs 中是如何工作的。以我们的信用风险为例,个人根据其信用度被分为不同的类别。决策树中的每一片叶子都被分配了一个分数。将多个树的得分相加,得到最终的预测得分。例如,图 3-4 显示了第一个决策树给了这个女人 3 分。第二棵树给了她 2 分。把两个分数加在一起,这个女人的最终分数是 5 分。请注意,决策树是相互补充的。这是 GBTs 的主要原则之一。将分数与每个叶子相关联为 GBTs 提供了一种集成的优化方法。?? x

随机森林与梯度提升树

由于梯度提升树是按顺序训练的,因此通常认为它比随机森林慢,扩展性差,随机森林能够并行训练多棵树。然而,与随机森林相比,gbt 通常使用更浅的树,这意味着 gbt 可以更快地训练。

增加 GBTs 中的树的数量会增加过度拟合的机会(GBTs 通过利用更多的树来减少偏差),而增加随机森林中的树的数量会减少过度拟合的机会(随机森林通过利用更多的树来减少方差)。一般来说,在随机森林中添加更多的树可以提高性能,而当树的数量开始变得太大时,GBTs 的性能就会开始下降。正因为如此,GBTs 可能比随机森林更难调。

如果参数调整正确,梯度提升树通常被认为比随机森林更强大。GBTs 添加了新的决策树,补充了以前构建的决策树,与随机森林相比,使用更少的树可以获得更好的预测准确性。XII

近年来开发的大多数用于分类和回归的新算法,如 XGBoost 和 LightGBM,都是 GBTs 的改进版本。它们没有传统 gbt 的局限性。

第三方分类和回归算法

无数开源贡献者为 Spark 开发第三方机器学习算法投入了时间和精力。虽然它们不是核心 Spark MLlib 库的一部分,但 Databricks (XGBoost)和微软(LightGBM)等公司已经支持这些项目,并在世界各地广泛使用。XGBoost 和 LightGBM 目前被认为是用于分类和回归的下一代机器学习算法。在精度和速度至关重要的情况下,它们是首选算法。我将在本章后面讨论这两个问题。现在,让我们动手做一些事情,深入一些例子。

用逻辑回归进行多类分类

逻辑回归是预测概率的线性分类器。它因易于使用和训练速度快而受欢迎,经常用于二分类和多类分类。如图 3-5 的第一个图表所示,当您的数据具有清晰的决策边界时,逻辑回归等线性分类器是合适的。在类不是线性可分的情况下(如第二个图表所示),应该考虑非线性分类器,如基于树的集成。

img/488426_1_En_3_Fig5_HTML.png

图 3-5

线性与非线性分类问题

例子

我们将使用流行的 Iris 数据集(参见清单 3-1 )解决第一个例子中的多类分类问题。该数据集包含三个类,每个类 50 个实例,其中每个类涉及一种鸢尾植物(鸢尾、杂色鸢尾和海滨鸢尾)。从图 3-6 中可以看出,刚毛鸢尾与杂色鸢尾和海滨鸢尾是线性分离的,但是杂色鸢尾和海滨鸢尾彼此不是线性分离的。逻辑回归在数据集分类方面仍然做得不错。

img/488426_1_En_3_Fig6_HTML.png

图 3-6

虹膜数据集的主成分分析投影

我们的目标是在给定一组特征的情况下预测鸢尾植物的类型。数据集包含四个数字特征:萼片长度、萼片宽度、花瓣长度和花瓣宽度(均以厘米为单位)。

// Create a schema for our data.
import org.apache.spark.sql.types._

var irisSchema = StructType(Array (
    StructField("sepal_length",   DoubleType, true),
    StructField("sepal_width",   DoubleType, true),
    StructField("petal_length",   DoubleType, true),
    StructField("petal_width",   DoubleType, true),
    StructField("class",  StringType, true)

    ))

// Read the CSV file. Use the schema that we just defined.

val dataDF = spark.read.format("csv")
             .option("header","false")
             .schema(irisSchema)
             .load("/files/iris.data")

// Check the schema.

dataDF.printSchema

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)

// Inspect the data to make sure they’re in the correct format.

dataDF.show

+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|      class|
+------------+-----------+------------+-----------+-----------+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|
|         5.4|        3.9|         1.7|        0.4|Iris-setosa|
|         4.6|        3.4|         1.4|        0.3|Iris-setosa|
|         5.0|        3.4|         1.5|        0.2|Iris-setosa|
|         4.4|        2.9|         1.4|        0.2|Iris-setosa|
|         4.9|        3.1|         1.5|        0.1|Iris-setosa|
|         5.4|        3.7|         1.5|        0.2|Iris-setosa|
|         4.8|        3.4|         1.6|        0.2|Iris-setosa|
|         4.8|        3.0|         1.4|        0.1|Iris-setosa|
|         4.3|        3.0|         1.1|        0.1|Iris-setosa|
|         5.8|        4.0|         1.2|        0.2|Iris-setosa|
|         5.7|        4.4|         1.5|        0.4|Iris-setosa|
|         5.4|        3.9|         1.3|        0.4|Iris-setosa|
|         5.1|        3.5|         1.4|        0.3|Iris-setosa|
|         5.7|        3.8|         1.7|        0.3|Iris-setosa|
|         5.1|        3.8|         1.5|        0.3|Iris-setosa|
+------------+-----------+------------+-----------+-----------+
only showing top 20 rows

// Calculate summary statistics for our data. This can
// be helpful in understanding the distribution of your data.

dataDF.describe().show(5,15)

+-------+---------------+---------------+---------------+---------------+
|summary|   sepal_length|    sepal_width|   petal_length|    petal_width|
+-------+---------------+---------------+---------------+---------------+
|  count|            150|            150|            150|            150|
|   mean|5.8433333333...|3.0540000000...|3.7586666666...|1.1986666666...|
| stddev|0.8280661279...|0.4335943113...|1.7644204199...|0.7631607417...|
|    min|            4.3|            2.0|            1.0|            0.1|
|    max|            7.9|            4.4|            6.9|            2.5|
+-------+---------------+---------------+---------------+---------------+

+--------------+
|         class|
+--------------+
|           150|
|          null|
|          null|
|   Iris-setosa|
|Iris-virginica|
+--------------+

// The input column class is currently a string. We'll use
// StringIndexer to encode it into a double. The new value
// will be stored in the new output column called label.

import org.apache.spark.ml.feature.StringIndexer

val labelIndexer = new StringIndexer()
                  .setInputCol("class")
                  .setOutputCol("label")

val dataDF2 = labelIndexer
             .fit(dataDF)
             .transform(dataDF)

// Check the schema of the new DataFrame.

dataDF2.printSchema

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)
 |-- label: double (nullable = false)

// Inspect the new column added to the DataFrame.

dataDF2.show

+------------+-----------+------------+-----------+-----------+-----+
|sepal_length|sepal_width|petal_length|petal_width|      class|label|
+------------+-----------+------------+-----------+-----------+-----+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|  0.0|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|  0.0|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|  0.0|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|  0.0|
|         5.4|        3.9|         1.7|        0.4|Iris-setosa|  0.0|
|         4.6|        3.4|         1.4|        0.3|Iris-setosa|  0.0|
|         5.0|        3.4|         1.5|        0.2|Iris-setosa|  0.0|
|         4.4|        2.9|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.1|         1.5|        0.1|Iris-setosa|  0.0|
|         5.4|        3.7|         1.5|        0.2|Iris-setosa|  0.0|
|         4.8|        3.4|         1.6|        0.2|Iris-setosa|  0.0|
|         4.8|        3.0|         1.4|        0.1|Iris-setosa|  0.0|
|         4.3|        3.0|         1.1|        0.1|Iris-setosa|  0.0|
|         5.8|        4.0|         1.2|        0.2|Iris-setosa|  0.0|
|         5.7|        4.4|         1.5|        0.4|Iris-setosa|  0.0|
|         5.4|        3.9|         1.3|        0.4|Iris-setosa|  0.0|
|         5.1|        3.5|         1.4|        0.3|Iris-setosa|  0.0|
|         5.7|        3.8|         1.7|        0.3|Iris-setosa|  0.0|
|         5.1|        3.8|         1.5|        0.3|Iris-setosa|  0.0|
+------------+-----------+------------+-----------+-----------+-----+
only showing top 20 rows

// Combine the features into a single vector
// column using the VectorAssembler transformer.

import org.apache.spark.ml.feature.VectorAssembler

val features = Array("sepal_length","sepal_width","petal_length","petal_width")

val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")

val dataDF3 = assembler.transform(dataDF2)

// Inspect the new column added to the DataFrame.

dataDF3.printSchema

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)

// Inspect the new column added to the DataFrame.

dataDF3.show

+------------+-----------+------------+-----------+-----------+-----+
|sepal_length|sepal_width|petal_length|petal_width|      class|label|
+------------+-----------+------------+-----------+-----------+-----+
|         5.1|        3.5|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.0|         1.4|        0.2|Iris-setosa|  0.0|
|         4.7|        3.2|         1.3|        0.2|Iris-setosa|  0.0|
|         4.6|        3.1|         1.5|        0.2|Iris-setosa|  0.0|
|         5.0|        3.6|         1.4|        0.2|Iris-setosa|  0.0|
|         5.4|        3.9|         1.7|        0.4|Iris-setosa|  0.0|
|         4.6|        3.4|         1.4|        0.3|Iris-setosa|  0.0|
|         5.0|        3.4|         1.5|        0.2|Iris-setosa|  0.0|
|         4.4|        2.9|         1.4|        0.2|Iris-setosa|  0.0|
|         4.9|        3.1|         1.5|        0.1|Iris-setosa|  0.0|
|         5.4|        3.7|         1.5|        0.2|Iris-setosa|  0.0|
|         4.8|        3.4|         1.6|        0.2|Iris-setosa|  0.0|
|         4.8|        3.0|         1.4|        0.1|Iris-setosa|  0.0|
|         4.3|        3.0|         1.1|        0.1|Iris-setosa|  0.0|
|         5.8|        4.0|         1.2|        0.2|Iris-setosa|  0.0|
|         5.7|        4.4|         1.5|        0.4|Iris-setosa|  0.0|
|         5.4|        3.9|         1.3|        0.4|Iris-setosa|  0.0|
|         5.1|        3.5|         1.4|        0.3|Iris-setosa|  0.0|
|         5.7|        3.8|         1.7|        0.3|Iris-setosa|  0.0|
|         5.1|        3.8|         1.5|        0.3|Iris-setosa|  0.0|
+------------+-----------+------------+-----------+-----------+-----+

+-----------------+
|         features|
+-----------------+
|[5.1,3.5,1.4,0.2]|
|[4.9,3.0,1.4,0.2]|
|[4.7,3.2,1.3,0.2]|
|[4.6,3.1,1.5,0.2]|
|[5.0,3.6,1.4,0.2]|
|[5.4,3.9,1.7,0.4]|
|[4.6,3.4,1.4,0.3]|
|[5.0,3.4,1.5,0.2]|
|[4.4,2.9,1.4,0.2]|
|[4.9,3.1,1.5,0.1]|
|[5.4,3.7,1.5,0.2]|
|[4.8,3.4,1.6,0.2]|
|[4.8,3.0,1.4,0.1]|
|[4.3,3.0,1.1,0.1]|
|[5.8,4.0,1.2,0.2]|
|[5.7,4.4,1.5,0.4]|
|[5.4,3.9,1.3,0.4]|
|[5.1,3.5,1.4,0.3]|
|[5.7,3.8,1.7,0.3]|
|[5.1,3.8,1.5,0.3]|
+-----------------+

only showing top 20 rows

// Let's measure the statistical dependence between
// the features and the class using Pearson correlation.

dataDF3.stat.corr("petal_length","label")
res48: Double = 0.9490425448523336

dataDF3.stat.corr("petal_width","label")
res49: Double = 0.9564638238016178

dataDF3.stat.corr("sepal_length","label")
res50: Double = 0.7825612318100821

dataDF3.stat.corr("sepal_width","label")
res51: Double = -0.41944620026002677

// The petal_length and petal_width have extremely high class correlation, // while sepal_length and sepal_width have low class correlation.
// As discussed in Chapter 2, correlation evaluates how strong the linear // relationship between two variables. You can use correlation to select
// relevant features (feature-class correlation) and identify redundant
// features (intra-feature correlation).

// Divide our dataset into training and test datasets.
val seed = 1234

val Array(trainingData, testData) = dataDF3.randomSplit(Array(0.8, 0.2), seed)

// We can now fit a model on the training dataset
// using logistic regression.

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()

// Train a model using our training dataset.

val model = lr.fit(trainingData)

// Predict on our test dataset.

val predictions = model.transform(testData)

// Note the new columns added to the DataFrame:
// rawPrediction, probability, prediction.

predictions.printSchema

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

// Inspect the predictions.

predictions.select("sepal_length","sepal_width",
"petal_length","petal_width","label","prediction").show

+------------+-----------+------------+-----------+-----+----------+
|sepal_length|sepal_width|petal_length|petal_width|label|prediction|
+------------+-----------+------------+-----------+-----+----------+
|         4.3|        3.0|         1.1|        0.1|  0.0|       0.0|
|         4.4|        2.9|         1.4|        0.2|  0.0|       0.0|
|         4.4|        3.0|         1.3|        0.2|  0.0|       0.0|
|         4.8|        3.1|         1.6|        0.2|  0.0|       0.0|
|         5.0|        3.3|         1.4|        0.2|  0.0|       0.0|
|         5.0|        3.4|         1.5|        0.2|  0.0|       0.0|
|         5.0|        3.6|         1.4|        0.2|  0.0|       0.0|
|         5.1|        3.4|         1.5|        0.2|  0.0|       0.0|
|         5.2|        2.7|         3.9|        1.4|  1.0|       1.0|
|         5.2|        4.1|         1.5|        0.1|  0.0|       0.0|
|         5.3|        3.7|         1.5|        0.2|  0.0|       0.0|
|         5.6|        2.9|         3.6|        1.3|  1.0|       1.0|
|         5.8|        2.8|         5.1|        2.4|  2.0|       2.0|
|         6.0|        2.2|         4.0|        1.0|  1.0|       1.0|
|         6.0|        2.9|         4.5|        1.5|  1.0|       1.0|
|         6.0|        3.4|         4.5|        1.6|  1.0|       1.0|
|         6.2|        2.8|         4.8|        1.8|  2.0|       2.0|
|         6.2|        2.9|         4.3|        1.3|  1.0|       1.0|
|         6.3|        2.8|         5.1|        1.5|  2.0|       1.0|
|         6.7|        3.1|         5.6|        2.4|  2.0|       2.0|
+------------+-----------+------------+-----------+-----+----------+
only showing top 20 rows

// Inspect the rawPrediction and probability columns.

predictions.select("rawPrediction","probability","prediction")
           .show(false)

+------------------------------------------------------------+
|rawPrediction                                               |
+------------------------------------------------------------+
|[-27765.164694901094,17727.78535517628,10037.379339724806]  |
|[-24491.649758932126,13931.526474094646,10560.123284837473] |
|[20141.806983153703,1877.784589255676,-22019.591572409383]  |
|[-46255.06332259462,20994.503038678085,25260.560283916537]  |
|[25095.115980666546,110.99834659454791,-25206.114327261093] |
|[-41011.14350152455,17036.32945903473,23974.814042489823]   |
|[20524.55747106708,1750.139974552606,-22274.697445619684]   |
|[29601.783587714817,-1697.1845083924927,-27904.599079322325]|
|[38919.06696252647,-5453.963471106039,-33465.10349142042]   |
|[-39965.27448934488,17725.41646382807,22239.85802551682]    |
|[-18994.667253235268,12074.709651218403,6919.957602016859]  |
|[-43236.84898013162,18023.80837865029,25213.040601481334]   |
|[-31543.179893646557,16452.928101990834,15090.251791655724] |
|[-21666.087284218,13802.846783092147,7863.24050112584]      |
|[-24107.97243292983,14585.93668397567,9522.035748954155]    |
|[25629.52586174148,-192.40731255107312,-25437.11854919041]  |
|[-14271.522512385294,11041.861803401871,3229.660708983418]  |
|[-16548.06114507441,10139.917257827732,6408.143887246673]   |
|[22598.60355651257,938.4220993796007,-23537.025655892172]   |
|[-40984.78286289556,18297.704445848023,22687.078417047538]  |
+------------------------------------------------------------+

+-------------+----------+
|probability  |prediction|
+-------------+----------+
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
|[1.0,0.0,0.0]|0.0       |
|[1.0,0.0,0.0]|0.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,1.0,0.0]|1.0       |
|[0.0,1.0,0.0]|1.0       |
|[1.0,0.0,0.0]|0.0       |
|[0.0,0.0,1.0]|2.0       |
+-------------+----------+
only showing top 20 rows

// Evaluate the model. Several evaluation metrics are available
// for multiclass classification: f1 (default), accuracy,
// weightedPrecision, and weightedRecall.
// I discuss evaluation metrics in more detail in Chapter 2.

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

val evaluator = new MulticlassClassificationEvaluator().setMetricName("f1")

val f1 = evaluator.evaluate(predictions)

f1: Double = 0.958119658119658

val wp = evaluator.setMetricName("weightedPrecision").evaluate(predictions)

wp: Double = 0.9635416666666667

val wr = evaluator.setMetricName("weightedRecall").evaluate(predictions)

wr: Double = 0.9583333333333335

val accuracy = evaluator.setMetricName("accuracy").evaluate(predictions)

accuracy: Double = 0.9583333333333334

Listing 3-1Classification Using Logistic Regression

逻辑回归是一种流行的分类算法,由于其速度和简单性,经常被用作第一基线算法。对于生产使用,更高级的基于树的集成通常是首选,因为它们具有更高的准确性和捕捉数据集中复杂非线性关系的能力。

基于随机森林的流失预测

随机森林是一种强大的集成学习算法,建立在多个决策树作为基础模型上,每个决策树都在不同的数据自举子集上并行训练。如前所述,决策树容易过度拟合。随机森林通过使用一种称为bagging(bootstrap aggregation)的技术用随机选择的数据子集训练每个决策树来解决过度拟合问题。装袋减少了模型的方差,有助于避免过度拟合。随机森林在不增加偏差的情况下减少了模型的方差。它还执行特征打包,为每个决策树随机选择特征。特征打包的目标是减少单个树之间的相关性。

对于分类,最终类别通过多数投票决定。由各个决策树产生的类的模式(最频繁出现的)成为最终的类。对于回归,单个决策树输出的平均值成为模型的最终输出。

因为随机森林使用决策树作为它的基础模型,它继承了它的大部分特性。它能够处理连续和分类特征,并且不需要特征缩放和一次性编码。随机森林在不平衡数据上表现也很好,因为它的分层性质迫使它们同时处理这两类数据。最后,随机森林可以捕捉因变量和自变量之间的非线性关系。

由于其可解释性、准确性和灵活性,随机森林是最流行的基于树的分类和回归集成算法之一。然而,训练随机森林模型可能是计算密集型的(这使得它非常适合在 Hadoop 或 Spark 等多核和分布式环境中进行并行化)。与逻辑回归或朴素贝叶斯等线性模型相比,它需要更多的内存和计算资源。此外,随机森林往往在文本或基因组数据等高维数据上表现不佳。

Note

CSIRO 生物信息学团队开发了一个高度可扩展的随机森林实现,最初是为高维基因组数据设计的,称为 VariantSpark RF。XIIIVariantSpark RF 可以处理数百万个功能,并在基准测试中显示出 xiv 比 MLlib 的随机森林实现具有更高的可扩展性。关于 VariantSpark RF 的更多信息可以在 CSIRO 的生物信息学网站上找到。ReForeSt 是另一个高度可扩展的随机森林实现,由意大利热那亚大学 DIBRIS 的 SmartLab 研究实验室开发。 xv ReForeSt 可以处理数百万个特征,并支持随机森林旋转,这是一种新的集成方法,扩展了经典的随机森林算法。XVI

因素

随机森林相对容易调优。适当设置几个重要参数 xvii 往往就足以成功使用随机森林。

  • max_depth: 指定树的最大深度。为 max_depth 设置一个较高的值可以使模型更具表现力,但将其设置得太高可能会增加过度拟合的可能性,并使模型更复杂。

  • num_trees: 指定适合的树的数量。增加树的数量会减少方差,通常会提高准确性。增加树的数量会减慢训练的速度。在某个点之外添加更多的树可能不会提高准确性。

  • FeatureSubsetStrategy: 指定用于在每个节点进行分割的要素部分。设置该参数可以提高训练速度。

  • subsamplingRate: 指定将被选择用于训练每棵树的数据部分。设置此参数可以提高训练速度,并有助于防止过度拟合。将其设置得太低可能会导致拟合不足。

我提供了一些通用的指南,但是像往常一样,强烈建议执行参数网格搜索来确定这些参数的最佳值。有关随机森林参数的完整列表,请参考 Spark MLlib 的在线文档。

例子

客户流失预测是银行、保险公司、电信公司、有线电视运营商以及网飞、Hulu、Spotify 和 Apple Music 等流媒体服务的一个重要分类用例。能够预测更有可能取消订阅其服务的客户的公司可以实现更有效的客户保留策略。留住客户是有价值的。根据一家领先的客户参与分析公司的研究,客户流失每年给美国企业造成约 1360 亿美元的损失。贝恩公司所做的研究显示,客户保持率仅提高 5%,利润就会增加 25%到 95%。Lee Resource Inc .提供的另一项统计表明,吸引新客户的成本是留住现有客户的五倍。 xx

对于我们的例子,我们将使用来自加州大学欧文分校机器学习知识库的一个流行的电信客户流失数据集(参见清单 3-2 )。这是一个流行的 Kaggle 数据集 xxi ,在网上被广泛使用。 xxii

对于本书中的大多数示例,我将分别执行转换器和估算器(而不是在管道中指定它们),这样您就可以看到新列被添加到结果数据帧中。这将有助于您在研究示例时了解“幕后”发生了什么。

// Load the CSV file into a DataFrame.
val dataDF = spark.read.format("csv")
             .option("header", "true")
             .option("inferSchema", "true")
             .load("churn_data.txt")

// Check the schema.

dataDF.printSchema
root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)

// Select a few columns

.

dataDF
.select("state","phone_number","international_plan","total_day_minutes","churned").show
+-----+------------+------------------+-----------------+-------+
|state|phone_number|international_plan|total_day_minutes|churned|
+-----+------------+------------------+-----------------+-------+
|   KS|    382-4657|                no|            265.1|  False|
|   OH|    371-7191|                no|            161.6|  False|
|   NJ|    358-1921|                no|            243.4|  False|
|   OH|    375-9999|               yes|            299.4|  False|
|   OK|    330-6626|               yes|            166.7|  False|
|   AL|    391-8027|               yes|            223.4|  False|
|   MA|    355-9993|                no|            218.2|  False|
|   MO|    329-9001|               yes|            157.0|  False|
|   LA|    335-4719|                no|            184.5|  False|
|   WV|    330-8173|               yes|            258.6|  False|
|   IN|    329-6603|                no|            129.1|   True|
|   RI|    344-9403|                no|            187.7|  False|
|   IA|    363-1107|                no|            128.8|  False|
|   MT|    394-8006|                no|            156.6|  False|
|   IA|    366-9238|                no|            120.7|  False|
|   NY|    351-7269|                no|            332.9|   True|
|   ID|    350-8884|                no|            196.4|  False|
|   VT|    386-2923|                no|            190.7|  False|
|   VA|    356-2992|                no|            189.7|  False|
|   TX|    373-2782|                no|            224.4|  False|
+-----+------------+------------------+-----------------+-------+
only showing top 20 rows

import org.apache.spark.ml.feature.StringIndexer

// Convert the string column "churned" ("True", "False") to double (1,0).

val labelIndexer = new StringIndexer()
                   .setInputCol("churned")
                   .setOutputCol("label")

// Convert the string column "international_plan" ("yes", "no")
// to double 1,0.

val intPlanIndexer = new StringIndexer()
                     .setInputCol("international_plan")
                     .setOutputCol("int_plan")

// Let's select our features. Domain knowledge is essential in feature
// selection. I would think total_day_minutes and total_day_calls have
// some influence on customer churn. A significant drop in these two
// metrics might indicate that the customer does not need the service
// any longer and may be on the verge of cancelling their phone plan.
// However, I don't think phone_number, area_code, and state have any
// predictive qualities at all. We discuss feature selection later in
// this chapter.

val features = Array("number_customer_service_calls","total_day_minutes","total_eve_minutes","account_length","number_vmail_messages","total_day_calls","total_day_charge","total_eve_calls","total_eve_charge","total_night_calls","total_intl_calls","total_intl_charge","int_plan")

// Combine a given list of columns into a single vector column
// including all the features needed to train ML models.

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")

// Add the label column to the DataFrame.

val dataDF2 = labelIndexer
              .fit(dataDF)
              .transform(dataDF)

dataDF2.printSchema

root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)

// "True" was converted to 1 and "False" was converted to 0.

dataDF2.select("churned","label").show

+-------+-----+
|churned|label|
+-------+-----+
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|   True|  1.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|   True|  1.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
|  False|  0.0|
+-------+-----+
only showing top 20 rows

// Add the int_plan column to the DataFrame.

val dataDF3 = intPlanIndexer.fit(dataDF2).transform(dataDF2)

dataDF3.printSchema

root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)

dataDF3.select("international_plan","int_plan").show

+------------------+--------+
|international_plan|int_plan|
+------------------+--------+
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|               yes|     1.0|
|               yes|     1.0|
|               yes|     1.0|
|                no|     0.0|
|               yes|     1.0|
|                no|     0.0|
|               yes|     1.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
|                no|     0.0|
+------------------+--------+
only showing top 20 rows

// Add the features vector column to the DataFrame.

val dataDF4 = assembler.transform(dataDF3)

dataDF4.printSchema

root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)
 |-- features: vector (nullable = true)

// The features have been vectorized.

dataDF4.select("features").show(false)

+----------------------------------------------------------------------+
|features                                                              |
+----------------------------------------------------------------------+
|[1.0,265.1,197.4,128.0,25.0,110.0,45.07,99.0,16.78,91.0,3.0,2.7,0.0]  |
|[1.0,161.6,195.5,107.0,26.0,123.0,27.47,103.0,16.62,103.0,3.0,3.7,0.0]|
|[0.0,243.4,121.2,137.0,0.0,114.0,41.38,110.0,10.3,104.0,5.0,3.29,0.0] |
|[2.0,299.4,61.9,84.0,0.0,71.0,50.9,88.0,5.26,89.0,7.0,1.78,1.0]       |
|[3.0,166.7,148.3,75.0,0.0,113.0,28.34,122.0,12.61,121.0,3.0,2.73,1.0] |
|[0.0,223.4,220.6,118.0,0.0,98.0,37.98,101.0,18.75,118.0,6.0,1.7,1.0]  |
|[3.0,218.2,348.5,121.0,24.0,88.0,37.09,108.0,29.62,118.0,7.0,2.03,0.0]|
|[0.0,157.0,103.1,147.0,0.0,79.0,26.69,94.0,8.76,96.0,6.0,1.92,1.0]    |
|[1.0,184.5,351.6,117.0,0.0,97.0,31.37,80.0,29.89,90.0,4.0,2.35,0.0]   |
|[0.0,258.6,222.0,141.0,37.0,84.0,43.96,111.0,18.87,97.0,5.0,3.02,1.0] |
|[4.0,129.1,228.5,65.0,0.0,137.0,21.95,83.0,19.42,111.0,6.0,3.43,0.0]  |
|[0.0,187.7,163.4,74.0,0.0,127.0,31.91,148.0,13.89,94.0,5.0,2.46,0.0]  |
|[1.0,128.8,104.9,168.0,0.0,96.0,21.9,71.0,8.92,128.0,2.0,3.02,0.0]    |
|[3.0,156.6,247.6,95.0,0.0,88.0,26.62,75.0,21.05,115.0,5.0,3.32,0.0]   |
|[4.0,120.7,307.2,62.0,0.0,70.0,20.52,76.0,26.11,99.0,6.0,3.54,0.0]    |
|[4.0,332.9,317.8,161.0,0.0,67.0,56.59,97.0,27.01,128.0,9.0,1.46,0.0]  |
|[1.0,196.4,280.9,85.0,27.0,139.0,33.39,90.0,23.88,75.0,4.0,3.73,0.0]  |
|[3.0,190.7,218.2,93.0,0.0,114.0,32.42,111.0,18.55,121.0,3.0,2.19,0.0] |
|[1.0,189.7,212.8,76.0,33.0,66.0,32.25,65.0,18.09,108.0,5.0,2.7,0.0]   |
|[1.0,224.4,159.5,73.0,0.0,90.0,38.15,88.0,13.56,74.0,2.0,3.51,0.0]    |
+----------------------------------------------------------------------+
only showing top 20 rows

// Split the data into training and test data.

val seed = 1234

val Array(trainingData, testData) = dataDF4.randomSplit(Array(0.8, 0.2), seed)

trainingData.count
res13: Long = 4009

testData.count
res14: Long = 991

// Create a Random Forest classifier.

import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier()
        .setFeatureSubsetStrategy("auto")
        .setSeed(seed)

// Create a binary classification evaluator, and set label column to
// be used for evaluation.

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val evaluator = new BinaryClassificationEvaluator().setLabelCol("label")

// Create a parameter grid.

import org.apache.spark.ml.tuning.ParamGridBuilder

val paramGrid = new ParamGridBuilder()
                .addGrid(rf.maxBins, Array(10, 20,30))
                .addGrid(rf.maxDepth, Array(5, 10, 15))
                .addGrid(rf.numTrees, Array(3, 5, 100))
                .addGrid(rf.impurity, Array("gini", "entropy"))
                .build()

// Create a pipeline.
import org.apache.spark.ml.Pipeline

val pipeline = new Pipeline().setStages(Array(rf))

// Create a cross-validator.

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
         .setEstimator(pipeline)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)

// We can now fit the model using the training dataset, choosing the
// best set of parameters for the model.

val model = cv.fit(trainingData)

// You can now make some predictions on our test data.

val predictions = model.transform(testData)

// Evaluate the model.

import org.apache.spark.ml.param.ParamMap

val pmap = ParamMap(evaluator.metricName -> "areaUnderROC")

val auc = evaluator.evaluate(predictions, pmap)

auc: Double = 0.9270599683335483

// Our Random Forest classifier has a high AUC score. The test
// data consists of 991 observations. 92 customers are predicted
// to leave the service.

predictions.count
res25: Long = 991

predictions.filter("prediction=1").count
res26: Long = 92

println(s"True Negative: ${predictions.select("*").where("prediction = 0 AND label = 0").count()}  True Positive: ${predictions.select("*").where("prediction = 1 AND label = 1").count()}")

True Negative: 837 True Positive: 81

// Our test predicted 81 customers leaving who actually did leave and also
// predicted 837 customers not leaving who actually did not leave.

println(s"False Negative: ${predictions.select("*").where("prediction = 0 AND label = 1").count()} False Positive: ${predictions.select("*").where("prediction = 1 AND label = 0").count()}")

False Negative: 62 False Positive: 11

// Our test predicted 11 customers leaving who actually did not leave and
// also predicted 62 customers not leaving who actually did leave.

// You can sort the output by RawPrediction or Probability to target
// highest-probability customers. RawPrediction and Probability
// provide a measure of confidence for each prediction. The larger
// the value, the more confident the model is in its prediction.

predictions.select("phone_number","RawPrediction","prediction")
           .orderBy($"RawPrediction".asc)
           .show(false)

+------------+--------------------------------------+----------+
|phone_number|RawPrediction                         |prediction|
+------------+--------------------------------------+----------+
| 366-1084   |[15.038138063913935,84.96186193608602]|1.0       |
| 334-6519   |[15.072688486480072,84.9273115135199] |1.0       |
| 359-5574   |[15.276260309388752,84.72373969061123]|1.0       |
| 399-7865   |[15.429722388653014,84.57027761134698]|1.0       |
| 335-2967   |[16.465107279664032,83.53489272033593]|1.0       |
| 345-9140   |[16.53288465159445,83.46711534840551] |1.0       |
| 342-6864   |[16.694165016887318,83.30583498311265]|1.0       |
| 419-1863   |[17.594670105674677,82.4053298943253] |1.0       |
| 384-7176   |[17.92764148018115,82.07235851981882] |1.0       |
| 357-1938   |[18.8550074623437,81.1449925376563]   |1.0       |
| 355-6837   |[19.556608109022648,80.44339189097732]|1.0       |
| 417-1488   |[20.13305147603522,79.86694852396475] |1.0       |
| 394-5489   |[21.05074084178182,78.94925915821818] |1.0       |
| 394-7447   |[21.376663858426735,78.62333614157326]|1.0       |
| 339-6477   |[21.549262081786424,78.45073791821355]|1.0       |
| 406-7844   |[21.92209788389343,78.07790211610656] |1.0       |
| 372-4073   |[22.098599119168263,77.90140088083176]|1.0       |
| 404-4809   |[22.515513847987147,77.48448615201283]|1.0       |
| 347-8659   |[22.66840460762997,77.33159539237005] |1.0       |
| 335-1874   |[23.336632598761128,76.66336740123884]|1.0       |
+------------+--------------------------------------+----------+
only showing top 20 rows

Listing 3-2Churn Prediction Using Random Forest

特征重要性

随机森林(和其他基于树的集合)具有内置的特征选择功能,可用于测量数据集中每个特征的重要性(参见清单 3-3 )。

随机森林将要素重要性计算为每次选择要素分割节点时聚合的每棵树上每个节点的节点杂质减少量之和除以森林中的树数。Spark MLlib 提供了一种方法,可以返回每个特性的重要性估计值。

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.PipelineModel

val bestModel = model.bestModel

val model = bestModel
            .asInstanceOf[PipelineModel]
            .stages
            .last
            .asInstanceOf[RandomForestClassificationModel]

model.featureImportances

feature_importances: org.apache.spark.ml.linalg.Vector =
(13,[0,1,2,3,4,5,6,7,8,9,10,11,12],
[
0.20827010117447803,
0.1667170878866465,
0.06099491253318444,
0.008184141410796346,
0.06664053647245761,
0.0072108752126555,
0.21097011684691344,
0.006902059667276019,
0.06831916361401609,
0.00644772968425685,
0.04105403721675372,
0.056954219262186724,
0.09133501901837866])

Listing 3-3Showing Feature Importance with Random Forest

我们得到一个向量,包含特征的数量(在我们的例子中是 13)、特征的数组索引和相应的权重。表 3-1 以可读性更强的格式显示了输出,实际特征以相应的权重显示。如您所见,total_day_charge、total_day_minutes 和 number_customer_service_calls 是最重要的功能。有道理。大量客户服务呼叫可能表明服务多次中断或大量客户投诉。低 total_day_minutes 和 total_day_charge 可能表明客户不经常使用他的电话计划,这可能意味着他准备很快取消他的计划。

表 3-1

电信客户流失预测示例的功能重要性

|

索引

|

特征

|

特征重要性

| | --- | --- | --- | | Zero | 数量 _ 客户 _ 服务 _ 呼叫 | 0.20827010117447803 | | one | 总计 _ 天 _ 分钟 | 0.1667170878866465 | | Two | 总计 _ eve _ 分钟 | 0.06099491253318444 | | three | 帐户 _ 长度 | 0.008184141410796346 | | four | 数字邮件消息 | 0.06664053647245761 | | five | 总计 _ 天 _ 次呼叫 | 0.0072108752126555 | | six | 总计 _ 天 _ 费用 | 0.21097011684691344 | | seven | 总通话次数 | 0.006902059667276019 | | eight | 总费用 | 0.06831916361401609 | | nine | 夜间通话总数 | 0.00644772968425685 | | Ten | 总呼叫次数 | 0.04105403721675372 | | Eleven | total_intl_charge | 0.056954219262186724 | | Twelve | int_plan | 0.09133501901837866 |

Note

Spark MLlib 在随机森林中实现的特征重要性也称为基于基尼的重要性或杂质平均减少(MDI)。随机森林的一些实现利用不同的方法来计算特征重要性,这种方法被称为基于精度的重要性或平均精度下降(MDA)。 xxiii 基于准确度的重要性是基于特征被随机置换时预测准确度的降低来计算的。虽然 Spark MLlib 的随机森林实现不直接支持这种方法,但通过评估模型同时一次置换每个特性的一列值,手动实现这种方法相当简单。

有时检查最佳模型使用的参数是有用的(参见清单 3-4 )。

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.PipelineModel

val bestModel = model
                .bestModel
                .asInstanceOf[PipelineModel]
                .stages
                .last
                .asInstanceOf[RandomForestClassificationModel]

 print(bestModel.extractParamMap)
{
        rfc_81c4d3786152-cacheNodeIds: false,
        rfc_81c4d3786152-checkpointInterval: 10,
        rfc_81c4d3786152-featureSubsetStrategy: auto,
        rfc_81c4d3786152-featuresCol: features,
        rfc_81c4d3786152-impurity: gini,
        rfc_81c4d3786152-labelCol: label,
        rfc_81c4d3786152-maxBins: 10,
        rfc_81c4d3786152-maxDepth: 15,
        rfc_81c4d3786152-maxMemoryInMB: 256,
        rfc_81c4d3786152-minInfoGain: 0.0,
        rfc_81c4d3786152-minInstancesPerNode: 1,
        rfc_81c4d3786152-numTrees: 100,
        rfc_81c4d3786152-predictionCol: prediction,
        rfc_81c4d3786152-probabilityCol: probability,
        rfc_81c4d3786152-rawPredictionCol: rawPrediction,
        rfc_81c4d3786152-seed: 1234,
        rfc_81c4d3786152-subsamplingRate: 1.0
}

Listing 3-4Extracting the Parameters of the Random Forest Model

XGBoost4J-Spark 的极限梯度升压

梯度推进算法是用于分类和回归的一些最强大的机器学习算法。目前有各种梯度提升算法的实现。流行的实现包括 AdaBoost 和 CatBoost(Yandex 最近开源的梯度提升库)。Spark MLlib 还包括自己的梯度提升树(GBT)实现。

XGBoost(极限梯度提升)是目前可用的最好的梯度提升树实现之一。XGBoost 于 2014 年 3 月 27 日由陈天琦发布,作为一个研究项目,它已经成为分类和回归的主流机器学习算法。为提高效率和可伸缩性而设计,其并行树提升能力使其比其他基于树的集成算法快得多。由于准确率高,XGBoost 通过赢得多个机器学习比赛而获得了知名度。2015 年,Kaggle 上的 29 个获奖解决方案中有 17 个使用了 XGBoost。2015 年 KDD 杯前 10 名的解决方案全部使用了 XGBoost。

XGBoost 是使用梯度推进的一般原则设计的,将弱学习者组合成强学习者。但是,虽然梯度提升树是按顺序构建的——慢慢地从数据中学习,以在后续迭代中改进其预测,但 XGBoost 是并行构建树的。XGBoost 通过其内置的正则化来控制模型复杂性和减少过拟合,从而产生更好的预测性能。当查找连续特征的最佳分割点时,它使用近似算法来查找分割点。XXV

近似分割方法使用离散箱来存储连续要素,从而显著加快模型训练速度。XGBoost 包括另一种使用基于直方图的算法的树生长方法,该方法提供了一种将连续要素分入离散箱的更有效的方法。但是,虽然近似方法每次迭代都创建一组新的面元,但是基于直方图的方法在多次迭代中重复使用面元。

这种方法允许使用近似方法无法实现的额外优化,例如缓存二进制文件以及父直方图和兄弟直方图相减的能力。 xxvi 为了优化排序操作,XGBoost 将排序后的数据存储在内存块单元中。排序块可以由并行 CPU 核心高效地分配和执行。XGBoost 可以通过其加权分位数草图算法有效地处理加权数据,可以有效地处理稀疏数据,支持缓存,并通过为大型数据集利用磁盘空间来支持核外计算,因此数据不必放在内存中。

XGBoost4J-Spark 项目于 2016 年末启动,将 XGBoost 移植到 Spark。XGBoost4J-Spark 利用了 Spark 高度可扩展的分布式处理引擎,并与 Spark MLlib 的数据帧/数据集抽象完全兼容。XGBoost4J-Spark 可以无缝嵌入到 Spark MLlib 管道中,并与 Spark MLlib 的变压器和估算器集成。

Note

XGBoost4J-Spark 需要 Apache Spark 2.4+。建议直接从 http://spark.apache.org 安装 Spark。XGBoost4J-Spark 不能保证与其他供应商的第三方 Spark 发行版(如 Cloudera、Hortonworks 或 MapR)一起使用。有关更多信息,请参考供应商的文档。 二十七

因素

XGBoost 比 Random Forest 有更多的参数,通常需要更多的调整。最初关注最重要的参数可以帮助您开始使用 XGBoost。随着你对算法越来越熟悉,你可以学习剩下的部分。

  • max_depth: 指定树的最大深度。为 max_depth 设置较高的值可能会增加过度拟合的可能性,并使模型更加复杂。

  • *n _ estimates:*指定要拟合的树的数量。一般来说,价值越大越好。将此参数设置得太高可能会影响训练速度。在某一点之外添加更多的树可能不会提高精度。默认值设置为 100。 二十八

  • sub_sample: 指定将为每棵树选择的数据部分。设置此参数可以提高训练速度,并有助于防止过度拟合。将其设置得太低可能会导致拟合不足。

  • colsample_bytree: 指定将为每棵树随机选择的列的分数。设置此参数可以提高训练速度,并有助于防止过度拟合。相关参数包括 colsample_bylevel 和 colsample_bynode。

  • 目标:指定学习任务和学习目标。为该参数设置正确的值很重要,以避免不可预测的结果或不良的准确性。XGBClassifier 默认为二进制:逻辑进行二进制分类,而 XGBRegressor 默认为 reg:squarederror 。其他值包括用于多类分类的 multi:softmaxmulti:soft probrank:pairwise,rank:ndcg,rank:map 进行排名;和生存:cox 使用 cox 比例风险模型进行生存回归,仅举几例。

  • *learning _ rate(eta):*learning _ rate 作为收缩因子,在每一个 boosting 步骤后减少特征权重,目的是减缓学习速率。该参数用于控制过度拟合。较低的值需要更多的树。

  • n_jobs: 指定 XGBoost 使用的并行线程的数量(如果 n_thread 被弃用,则使用该参数)。

这些只是关于如何使用参数的一般准则。强烈建议执行参数网格搜索来确定这些参数的最佳值。有关 XGBoost 参数的完整列表,请参考 XGBoost 的在线文档。

Note

为了与 Scala 的变量命名约定保持一致,XGBoost4J-Spark 既支持默认的参数集,也支持这些参数的 camel case 变体(例如,max_depth 和 maxDepth)。

例子

我们将重用相同的电信客户流失数据集和前面随机森林示例中的大部分代码(参见清单 3-5 )。这一次,我们将使用管道将转换器和估算器连接在一起。

// XGBoost4J-Spark is available as an external package.
// Start spark-shell. Specify the XGBoost4J-Spark package.

spark-shell --packages ml.dmlc:xgboost4j-spark:0.81

// Load the CSV file into a DataFrame.

val dataDF = spark.read.format("csv")
             .option("header", "true")
             .option("inferSchema", "true")
             .load("churn_data.txt")

// Check the schema.

dataDF.printSchema

root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)

// Select a few columns.

dataDF.select("state","phone_number","international_plan","churned").show

+-----+------------+------------------+-------+
|state|phone_number|international_plan|churned|
+-----+------------+------------------+-------+
|   KS|    382-4657|                no|  False|
|   OH|    371-7191|                no|  False|
|   NJ|    358-1921|                no|  False|
|   OH|    375-9999|               yes|  False|
|   OK|    330-6626|               yes|  False|
|   AL|    391-8027|               yes|  False|
|   MA|    355-9993|                no|  False|
|   MO|    329-9001|               yes|  False|
|   LA|    335-4719|                no|  False|
|   WV|    330-8173|               yes|  False|
|   IN|    329-6603|                no|   True|
|   RI|    344-9403|                no|  False|
|   IA|    363-1107|                no|  False|
|   MT|    394-8006|                no|  False|
|   IA|    366-9238|                no|  False|
|   NY|    351-7269|                no|   True|
|   ID|    350-8884|                no|  False|
|   VT|    386-2923|                no|  False|
|   VA|    356-2992|                no|  False|
|   TX|    373-2782|                no|  False|
+-----+------------+------------------+-------+
only showing top 20 rows

import org.apache.spark.ml.feature.StringIndexer

// Convert the String "churned" column ("True", "False") to double(1,0).

val labelIndexer = new StringIndexer()
                   .setInputCol("churned")
                   .setOutputCol("label")

// Convert the String "international_plan" ("no", "yes") column to double(1,0).

val intPlanIndexer = new StringIndexer()
                     .setInputCol("international_plan")
                     .setOutputCol("int_plan")

// Specify features to be selected for model fitting.

val features = Array("number_customer_service_calls","total_day_minutes","total_eve_minutes","account_length","number_vmail_messages","total_day_calls","total_day_charge","total_eve_calls","total_eve_charge","total_night_calls","total_intl_calls","total_intl_charge","int_plan")

// Combines the features into a single vector column.

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")

// Split the data into training and test data.

val seed = 1234

val Array(trainingData, testData) = dataDF.randomSplit(Array(0.8, 0.2), seed)

// Create an XGBoost classifier.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel

val xgb = new XGBoostClassifier()
          .setFeaturesCol("features")
          .setLabelCol("label")

// XGBClassifier's objective parameter defaults to binary:logistic which
// is the learning task and objective that we want for this example
// (binary classification). Depending on your task, remember to set the
// correct learning task and objective.

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

val evaluator = new MulticlassClassificationEvaluator().setLabelCol("label")

import org.apache.spark.ml.tuning.ParamGridBuilder

val paramGrid = new ParamGridBuilder()
                .addGrid(xgb.maxDepth, Array(3, 8))
                .addGrid(xgb.eta, Array(0.2, 0.6))
                .build()

// This time we'll specify all the steps in the pipeline.

import org.apache.spark.ml.{ Pipeline, PipelineStage }

val pipeline = new Pipeline()
               .setStages(Array(labelIndexer, intPlanIndexer, assembler, xgb))

// Create a cross-validator.

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
         .setEstimator(pipeline)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)

// We can now fit the model using the training data. This will run
// cross-validation, choosing the best set of parameters.

val model = cv.fit(trainingData)

// You can now make some predictions on our test data.

val predictions = model.transform(testData)

predictions.printSchema

root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

// Let's evaluate the model.

val auc = evaluator.evaluate(predictions)
auc: Double = 0.9328044307445879

// The AUC score produced by XGBoost4J-Spark is slightly better compared
// to our previous Random Forest example. XGBoost4J-Spark was also
// faster than Random Forest in training this dataset.

// Like Random Forest, XGBoost lets you extract the feature importance.

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel
import org.apache.spark.ml.PipelineModel

val bestModel = model.bestModel

val model = bestModel
            .asInstanceOf[PipelineModel]
            .stages
            .last
            .asInstanceOf[XGBoostClassificationModel]

// Execute the getFeatureScore method to extract the feature importance.

model.nativeBooster.getFeatureScore()

res9: scala.collection.mutable.Map[String,Integer] = Map(f7 -> 4, f9 -> 7, f10 -> 2, f12 -> 4, f11 -> 8, f0 -> 5, f1 -> 19, f2 -> 17, f3 -> 10, f4 -> 2, f5 -> 3)

// The method returns a map with the key mapping to the feature
// array index and the value corresponding to the feature importance score.

Listing 3-5Churn Prediction Using XGBoost4J-Spark

表 3-2

使用 XGBoost4J-Spark 的特性重要性

|

索引

|

特征

|

特征重要性

| | --- | --- | --- | | Zero | 数量 _ 客户 _ 服务 _ 呼叫 | Two | | one | 总计 _ 天 _ 分钟 | Fifteen | | Two | 总计 _ eve _ 分钟 | Ten | | three | 帐户 _ 长度 | three | | four | 数字邮件消息 | Two | | five | 总计 _ 天 _ 次呼叫 | three | | six | 总计 _ 天 _ 费用 | 省略 | | seven | 总通话次数 | Two | | eight | 总费用 | 省略 | | nine | 夜间通话总数 | Two | | Ten | 总呼叫次数 | Two | | Eleven | total_intl_charge | one | | Twelve | int_plan | five |

请注意,输出缺少了几列,特别是 total_day_charge (f6)和 total_eve_charge (f8)。这些是 XGBoost 被认为在提高模型预测精度方面无效的特征(见表 3-2 )。只有在至少一次分割中使用的要素才能进入 XGBoost 要素重要性输出。有几种可能的解释。这可能意味着丢弃的要素具有非常低的方差或零方差。这也可能意味着这两个特征与其他特征高度相关。

将 XGBoost 的特性重要性输出与我们之前的随机森林示例进行比较时,有一些有趣的事情需要注意。请注意,虽然我们之前的随机森林模型认为 number_customer_service_calls 是最重要的特性之一,但 XGBoost 将其列为最不重要的特性之一。类似地,前面的随机森林模型认为 total_day_charge 是最重要的特性,但是 XGBoost 由于其缺乏重要性而完全将其从输出中忽略(参见清单 3-6 )。

val bestModel = model
                .bestModel
                .asInstanceOf[PipelineModel]
                .stages
                .last
                .asInstanceOf[XGBoostClassificationModel]

print(bestModel.extractParamMap)
{
        xgbc_9b95e70ab140-alpha: 0.0,
        xgbc_9b95e70ab140-baseScore: 0.5,
        xgbc_9b95e70ab140-checkpointInterval: -1,
        xgbc_9b95e70ab140-checkpointPath: ,
        xgbc_9b95e70ab140-colsampleBylevel: 1.0,
        xgbc_9b95e70ab140-colsampleBytree: 1.0,
        xgbc_9b95e70ab140-customEval: null,
        xgbc_9b95e70ab140-customObj: null,
        xgbc_9b95e70ab140-eta: 0.2,
        xgbc_9b95e70ab140-evalMetric: error,
        xgbc_9b95e70ab140-featuresCol: features,
        xgbc_9b95e70ab140-gamma: 0.0,
        xgbc_9b95e70ab140-growPolicy: depthwise,
        xgbc_9b95e70ab140-labelCol: label,
        xgbc_9b95e70ab140-lambda: 1.0,
        xgbc_9b95e70ab140-lambdaBias: 0.0,
        xgbc_9b95e70ab140-maxBin: 16,
        xgbc_9b95e70ab140-maxDeltaStep: 0.0,
        xgbc_9b95e70ab140-maxDepth: 8,
        xgbc_9b95e70ab140-minChildWeight: 1.0,
        xgbc_9b95e70ab140-missing: NaN,
        xgbc_9b95e70ab140-normalizeType: tree,
        xgbc_9b95e70ab140-nthread: 1,
        xgbc_9b95e70ab140-numEarlyStoppingRounds: 0,
        xgbc_9b95e70ab140-numRound: 1,
        xgbc_9b95e70ab140-numWorkers: 1,
        xgbc_9b95e70ab140-objective: reg:linear,
        xgbc_9b95e70ab140-predictionCol: prediction,
        xgbc_9b95e70ab140-probabilityCol: probability,
        xgbc_9b95e70ab140-rateDrop: 0.0,
        xgbc_9b95e70ab140-rawPredictionCol: rawPrediction,
        xgbc_9b95e70ab140-sampleType: uniform,
        xgbc_9b95e70ab140-scalePosWeight: 1.0,
        xgbc_9b95e70ab140-seed: 0,
        xgbc_9b95e70ab140-silent: 0,
        xgbc_9b95e70ab140-sketchEps: 0.03,
        xgbc_9b95e70ab140-skipDrop: 0.0,
        xgbc_9b95e70ab140-subsample: 1.0,
        xgbc_9b95e70ab140-timeoutRequestWorkers: 1800000,
        xgbc_9b95e70ab140-trackerConf: TrackerConf(0,python),
        xgbc_9b95e70ab140-trainTestRatio: 1.0,
        xgbc_9b95e70ab140-treeLimit: 0,
        xgbc_9b95e70ab140-treeMethod: auto,
        xgbc_9b95e70ab140-useExternalMemory: false
}

Listing 3-6Extracting the Parameters of the XGBoost4J-Spark Model

LightGBM:微软的快速渐变提升

多年来,XGBoost 一直是每个人最喜欢的分类和回归算法。最近,LightGBM 成为了王位的新挑战者。它是一个相对较新的基于树的梯度提升变体,类似于 XGBoost。LightGBM 于 2016 年 10 月 17 日发布,是微软分布式机器学习工具包(DMTK)项目的一部分。它被设计成快速和分布式的,导致更快的训练速度和更低的内存使用。它支持 GPU 和并行学习以及处理大型数据集的能力。在几个基准测试和公共数据集上的实验中,LightGBM 显示出比 XGBoost 更快的速度和更好的准确性。

Note

LightGBM 作为微软 Apache Spark 机器学习(MMLSpark)生态系统的一部分被移植到 Spark。微软一直在积极开发与 Apache Spark 生态系统无缝集成的数据科学和深度学习工具,如微软认知工具包、OpenCV 和 LightGBM。MMLSpark 需要 Python 2.7 或 3.5+,Scala 2.11,Spark 2.3+。

与 XGBoost 相比,LightGBM 有几个优点。它利用直方图将连续特征分入离散的箱中。这为 LightGBM 提供了优于 XGBoost(默认情况下,XGBoost 使用基于预排序的算法进行树学习)的几个性能优势,例如减少了内存使用、减少了计算每次分割的增益的成本,以及减少了并行学习的通信成本。LightGBM 通过对其兄弟节点和父节点执行直方图减法来计算节点的直方图,从而实现了额外的性能提升。在线基准测试显示,在某些任务中,LightGBM 比 XGBoost(不含宁滨)快 11 到 15 倍。XXIX

LightGBM 通过逐叶生长树(最佳优先),在准确性方面通常优于 XGBoost。训练决策树有两种主要策略,层次式和叶式(如图 3-7 所示)。对于大多数基于树的集成(包括 XGBoost),逐层树生长是生长决策树的传统方式。LightGBM 引入了逐叶增长策略。与水平方向生长相比,叶方向生长通常收敛更快 xxx 并且损失更低。 xxxi

Note

逐叶增长往往会过度适应小数据集。建议在 LightGBM 中设置 max_depth 参数来限制树深度。请注意,即使设置了 max_depth,树仍然是逐叶生长的。 xxxii 我将在本章后面讨论 LightGBM 参数调优。

img/488426_1_En_3_Fig7_HTML.png

图 3-7

水平方向生长与叶方向生长

Note

此后,XGBoost 实现了许多由 LightGBM 首创的优化,包括逐叶树生长策略和使用直方图将连续特征存储到离散容器中。最新的性能指标评测显示,XGBoost 达到了与 LightGBM 相当的性能。XXXIII

因素

与其他算法(如随机森林)相比,调优 LightGBM 稍微复杂一些。LightGBM 使用逐叶(最佳优先)树生长算法,如果参数配置不正确,该算法很容易过度拟合。而且,LightGBM 有 100 多个参数。关注最重要的参数足以帮助您开始使用 LightGBM。随着你对算法越来越熟悉,你可以学习剩下的部分。

  • max_depth :设置该参数,防止树长得太深。浅树不太可能过度生长。如果数据集很小,设置此参数尤其重要。

  • num_leaves :控制树模型的复杂度。该值应小于 2^(max_depth)以防止过度拟合。将 num_leaves 设置为一个较大的值可以提高精度,但有较高的过度拟合风险。将 num_leaves 设置为较小的值有助于防止过度拟合。

  • min_data_in_leaf: 将该参数设置为较大的值可以防止树长得太深。这是另一个可以设置的参数,有助于控制过度拟合。将该值设置得太大会导致拟合不足。

  • max_bin: LightGBM 使用直方图将连续特征的值分组到离散桶中。设置 max_bin 以指定值将分组到的箱数。较小的值有助于控制过拟合并提高训练速度,而较大的值可提高精度。

  • feature_fraction: 该参数启用特征子采样。此参数指定在每次迭代中随机选择的要素比例。例如,将 feature_fraction 设置为 0.75 将在每次迭代中随机选择 75%的要素。设置此参数可以提高训练速度,并有助于防止过度拟合。

  • bagging_fraction: 指定在每次迭代中选择的数据的分数。例如,将 bagging_fraction 设置为 0.75 将在每次迭代中随机选择 75%的数据。设置此参数可以提高训练速度,并有助于防止过度拟合。

  • num_iteration: 设置增强迭代的次数。默认值为 100。对于多类分类,LightGBM 构建 num_class * num_iterations 树。设置该参数会影响训练速度。

  • *目标:和 XGBoost 一样,LightGBM 支持多个目标。默认目标设置为回归。*将该参数设置为 s 指定你的模型试图执行的任务类型。对于回归任务,选项有 regression_l2、regression_l1、poisson、quantile、mape、gamma、huber、fair 或 tweedie。对于分类任务,选项有二进制、多类或多类集。正确设置目标很重要,以避免不可预测的结果或不准确。

与往常一样,强烈建议执行参数网格搜索来确定这些参数的最佳值。有关 LightGBM 参数的详细列表,请参考 LightGBM 在线文档。

Note

在撰写本文时,LightGBM for Spark 还没有达到与 LightGBM for Python 同等的特性。虽然 LightGBM for Spark 包含了最重要的参数,但仍然缺少一些。您可以通过访问 https://bit.ly/2OqHl2M 获得 LightGBM for Spark 中所有可用参数的列表。可以和 https://bit.ly/30YGyaO 的 LightGBM 参数完整列表进行对比。

例子

我们将重用相同的电信客户流失数据集和前面的 Random Forest 和 XGBoost 示例中的大部分代码,如清单 3-7 所示。

spark-shell --packages Azure:mmlspark:0.15

// Load the CSV file into a DataFrame.

val dataDF = spark.read.format("csv")
             .option("header", "true")
             .option("inferSchema", "true")
             .load("churn_data.txt")

// Check the schema.

dataDF.printSchema

root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)

// Select a few columns.

dataDF.select("state","phone_number","international_plan","churned").show

+-----+------------+------------------+-------+
|state|phone_number|international_plan|churned|
+-----+------------+------------------+-------+
|   KS|    382-4657|                no|  False|
|   OH|    371-7191|                no|  False|
|   NJ|    358-1921|                no|  False|
|   OH|    375-9999|               yes|  False|
|   OK|    330-6626|               yes|  False|
|   AL|    391-8027|               yes|  False|
|   MA|    355-9993|                no|  False|
|   MO|    329-9001|               yes|  False|
|   LA|    335-4719|                no|  False|
|   WV|    330-8173|               yes|  False|
|   IN|    329-6603|                no|   True|
|   RI|    344-9403|                no|  False|
|   IA|    363-1107|                no|  False|
|   MT|    394-8006|                no|  False|
|   IA|    366-9238|                no|  False|
|   NY|    351-7269|                no|   True|
|   ID|    350-8884|                no|  False|
|   VT|    386-2923|                no|  False|
|   VA|    356-2992|                no|  False|
|   TX|    373-2782|                no|  False|
+-----+------------+------------------+-------+
only showing top 20 rows

import org.apache.spark.ml.feature.StringIndexer

val labelIndexer = new StringIndexer().setInputCol("churned").setOutputCol("label")

val intPlanIndexer = new StringIndexer().setInputCol("international_plan").setOutputCol("int_plan")

val features = Array("number_customer_service_calls","total_day_minutes","total_eve_minutes","account_length","number_vmail_messages","total_day_calls","total_day_charge","total_eve_calls","total_eve_charge","total_night_calls","total_intl_calls","total_intl_charge","int_plan")

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")

val seed = 1234

val Array(trainingData, testData) = dataDF.randomSplit(Array(0.9, 0.1), seed)

// Create a LightGBM classifier.

import com.microsoft.ml.spark.LightGBMClassifier

val lightgbm = new LightGBMClassifier()
               .setFeaturesCol("features")
               .setLabelCol("label")
               .setRawPredictionCol("rawPrediction")
               .setObjective("binary")

// Remember to set the correct objective using the setObjective method.
// Specifying the incorrect objective can affect accuracy or produce
// unpredictable results. In LightGBM the default objective is set to
// regression. In this example, we are performing binary classification, so
// we set the objective to binary.

Listing 3-7Churn Prediction with LightGBM

Note

Spark 从 2.4 版本开始支持屏障执行模式。从 0.18 版开始,LightGBM 通过 setUseBarrierExecutionMode 方法支持屏障执行模式。

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val evaluator = new BinaryClassificationEvaluator()
                .setLabelCol("label")
                .setMetricName("areaUnderROC")

import org.apache.spark.ml.tuning.ParamGridBuilder

val paramGrid = new ParamGridBuilder()
                .addGrid(lightgbm.maxDepth, Array(2, 3, 4))
                .addGrid(lightgbm.numLeaves, Array(4, 6, 8))
                .addGrid(lightgbm.numIterations, Array(600))
                .build()

import org.apache.spark.ml.{ Pipeline, PipelineStage }

val pipeline = new Pipeline()
               .setStages(Array(labelIndexer, intPlanIndexer, assembler, lightgbm))

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
         .setEstimator(pipeline)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)

val model = cv.fit(trainingData)

// You can now make some predictions on our test data.

val predictions = model.transform(testData)

predictions.printSchema

root
 |-- state: string (nullable = true)
 |-- account_length: double (nullable = true)
 |-- area_code: double (nullable = true)
 |-- phone_number: string (nullable = true)
 |-- international_plan: string (nullable = true)
 |-- voice_mail_plan: string (nullable = true)
 |-- number_vmail_messages: double (nullable = true)
 |-- total_day_minutes: double (nullable = true)
 |-- total_day_calls: double (nullable = true)
 |-- total_day_charge: double (nullable = true)
 |-- total_eve_minutes: double (nullable = true)
 |-- total_eve_calls: double (nullable = true)
 |-- total_eve_charge: double (nullable = true)
 |-- total_night_minutes: double (nullable = true)
 |-- total_night_calls: double (nullable = true)
 |-- total_night_charge: double (nullable = true)
 |-- total_intl_minutes: double (nullable = true)
 |-- total_intl_calls: double (nullable = true)
 |-- total_intl_charge: double (nullable = true)
 |-- number_customer_service_calls: double (nullable = true)
 |-- churned: string (nullable = true)
 |-- label: double (nullable = false)
 |-- int_plan: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

// Evaluate the model. The AUC score is higher than Random Forest
// and XGBoost from our previous examples.

val auc = evaluator.evaluate(predictions)

auc: Double = 0.940366124260358

//LightGBM also lets you extract the feature importance.

import com.microsoft.ml.spark.LightGBMClassificationModel
import org.apache.spark.ml.PipelineModel

val bestModel = model.bestModel

val model = bestModel.asInstanceOf[PipelineModel]
            .stages
            .last
            .asInstanceOf[LightGBMClassificationModel]

LightGBM 中有两种类型的特征重要性,“分裂”(分裂总数)和“增益”(总信息增益)。通常建议使用“增益”,这与 Random Forest 计算特征重要性的方法大致相似,但在我们的二进制分类示例中,LightGBM 使用交叉熵(对数损失)而不是基尼杂质(见表 3-3 和 3-4 )。要最小化的损失取决于指定的目标。XXXIV

表 3-3

使用信息增益的 LightGBM 的特征重要性

|

索引

|

特征

|

特征重要性

| | --- | --- | --- | | Zero | 数量 _ 客户 _ 服务 _ 呼叫 | 2648.0893859118223 | | one | 总计 _ 天 _ 分钟 | 5339.0795262902975 | | Two | 总计 _ eve _ 分钟 | 2191.832309693098 | | three | 帐户 _ 长度 | 564.6461282968521 | | four | 数字邮件消息 | 1180.4672759771347 | | five | 总计 _ 天 _ 次呼叫 | 656.8244850635529 | | six | 总计 _ 天 _ 费用 | Zero | | seven | 总通话次数 | 533.6638155579567 | | eight | 总费用 | 579.7435692846775 | | nine | 夜间通话总数 | 651.5408382415771 | | Ten | 总呼叫次数 | 1179.492751300335 | | Eleven | total_intl_charge | 2186.5995585918427 | | Twelve | int_plan | 1773.7864662855864 |

val gainFeatureImportances = model.getFeatureImportances("gain")

gainFeatureImportances: Array[Double] =
Array(2648.0893859118223, 5339.0795262902975, 2191.832309693098,564.6461282968521, 1180.4672759771347, 656.8244850635529, 0.0, 533.6638155579567, 579.7435692846775, 651.5408382415771, 1179.492751300335, 2186.5995585918427, 1773.7864662855864)

比较使用“split”时的输出。

val gainFeatureImportances = model.getFeatureImportances("split")
gainFeatureImportances: Array[Double] = Array(159.0, 583.0, 421.0, 259.0, 133.0, 264.0, 0.0, 214.0, 92.0, 279.0, 279.0, 366.0, 58.0)

表 3-4

使用拆分数量的 LightGBM 的特征重要性

|

索引

|

特征

|

特征重要性

| | --- | --- | --- | | Zero | 数量 _ 客户 _ 服务 _ 呼叫 | One hundred and fifty-nine | | one | 总计 _ 天 _ 分钟 | Five hundred and eighty-three | | Two | 总计 _ eve _ 分钟 | Four hundred and twenty-one | | three | 帐户 _ 长度 | Two hundred and fifty-nine | | four | 数字邮件消息 | One hundred and thirty-three | | five | 总计 _ 天 _ 次呼叫 | Two hundred and sixty-four | | six | 总计 _ 天 _ 费用 | Zero | | seven | 总通话次数 | Two hundred and fourteen | | eight | 总费用 | Ninety-two | | nine | 夜间通话总数 | Two hundred and seventy-nine | | Ten | 总呼叫次数 | Two hundred and seventy-nine | | Eleven | total_intl_charge | Three hundred and sixty-six | | Twelve | int_plan | Fifty-eight |

println(s"True Negative: ${predictions.select("*").where("prediction = 0 AND label = 0").count()}  True Positive: ${predictions.select("*").where("prediction = 1 AND label = 1").count()}")

True Negative: 407  True Positive: 58

println(s"False Negative: ${predictions.select("*").where("prediction = 0 AND label = 1").count()} False Positive: ${predictions.select("*").where("prediction = 1 AND label = 0").count()}")

False Negative: 20 False Positive: 9

基于朴素贝叶斯的情感分析

朴素贝叶斯是一种基于贝叶斯定理的简单多类线性分类算法。朴素贝叶斯之所以得名,是因为它天真地假设数据集中的要素是独立的,忽略了要素之间任何可能的相关性。现实世界的情况并非如此,朴素贝叶斯仍然表现良好,尤其是在小数据集或高维数据集上。像线性分类器一样,它在非线性分类问题上表现不佳。朴素贝叶斯是一种计算效率高且高度可伸缩的算法,只需要对数据集进行一次传递。对于使用大型数据集的分类任务,这是一个很好的基线模型。它的工作原理是在给定一组特征的情况下,找出一个点属于某个类的概率。贝叶斯定理方程可以表述为:

\boldsymbol{P}\left(\boldsymbol{A}|\boldsymbol{B}\right)=\frac{\boldsymbol{P}\left(\boldsymbol{B}|\boldsymbol{A}\right)\boldsymbol{P}\left(\boldsymbol{A}\right)}{\boldsymbol{P}\left(\boldsymbol{B}\right)}

P(A|B)是后验概率可以解释为:“给定事件 B,事件 A 发生的概率是多少?”B 代表特征向量。分子代表条件概率乘以先验概率。分母代表证据。这个等式可以更精确地写成:

\boldsymbol{P}\left(\boldsymbol{y}|{\boldsymbol{x}}_{\mathbf{1}},\dots, {\boldsymbol{x}}_{\boldsymbol{n}}\right)=\frac{\boldsymbol{P}\left({\boldsymbol{x}}_{\mathbf{1}},\dots, {\boldsymbol{x}}_{\boldsymbol{n}}|\boldsymbol{y}\right)}{\boldsymbol{P}\left({\boldsymbol{x}}_{\mathbf{1}},\dots, {\boldsymbol{x}}_{\boldsymbol{n}}\right)}

朴素贝叶斯常用于文本分类。文本分类的流行应用包括垃圾邮件检测和文档分类。另一个文本分类用例是情感分析。公司定期检查来自社交媒体的评论,以确定公众对产品或服务的意见是积极的还是消极的。对冲基金利用情绪分析来预测股市走势。

Spark MLlib 支持伯努利朴素贝叶斯和多项式朴素贝叶斯。伯努利朴素贝叶斯仅适用于布尔或二进制特征(例如,文档中存在或不存在单词),而多项式朴素贝叶斯是为离散特征(例如,单词计数)设计的。MLlib 的朴素贝叶斯实现的默认模型类型设置为多项式。可以为平滑设置另一个参数 lambda(默认值为 1.0)。

例子

让我们用一个例子来演示如何使用朴素贝叶斯进行情感分析。我们将使用来自加州大学欧文分校机器学习知识库的流行数据集。该数据集是为 Kotzias 等人的论文“使用深度特征从群体到个体标签”创建的。艾尔。,KDD 2015。数据集来自三家不同的公司:IMDB、亚马逊和 Yelp。每个公司有 500 个正面和 500 个负面评论。我们将使用来自亚马逊的数据集,根据亚马逊产品评论来确定特定产品的情绪是正面(1)还是负面(0)的概率。

我们需要将数据集中的每个句子转换成一个特征向量。Spark MLlib 为此提供了一个转换器。术语频率逆文档频率(TF IDF)通常用于从文本生成特征向量。TF IDF 用于通过计算单词在文档中出现的次数(TF)和单词在整个语料库中出现的频率(IDF)来确定单词与语料库中文档的相关性。在 Spark MLlib 中,TF 和 IDF 是分开实现的(HashingTF 和 IDF)。

在我们可以使用 TF IDF 将单词转换成特征向量之前,我们需要使用另一个转换器 tokenizer 将句子分割成单独的单词。这些步骤应该如图 3-8 所示,代码如清单 3-8 所示。

img/488426_1_En_3_Fig8_HTML.png

图 3-8

我们的情感分析示例的特征转换

// Start by creating a schema for our dataset.
import org.apache.spark.sql.types._

var reviewsSchema = StructType(Array (
    StructField("text",   StringType, true),
    StructField("label",  IntegerType, true)
    ))

// Create a DataFrame from the tab-delimited text file.
// Use the "csv" format regardless if its tab or comma delimited.
// The file does not have a header, so we’ll set the header
// option to false. We’ll set delimiter to tab and use the schema
// that we just built.

val reviewsDF = spark.read.format("csv")
                .option("header", "false")
                .option("delimiter","\t")
                .schema(reviewsSchema)
                .load("/files/amazon_cells_labelled.txt")

// Review the schema.

reviewsDF.printSchema

root
 |-- text: string (nullable = true)
 |-- label: integer (nullable = true)

// Check the data.

reviewsDF.show

+--------------------+-----+
|                text|label|
+--------------------+-----+
|So there is no wa...|    0|
|Good case, Excell...|    1|
|Great for the jaw...|    1|
|Tied to charger f...|    0|
|   The mic is great.|    1|
|I have to jiggle ...|    0|
|If you have sever...|    0|
|If you are Razr o...|    1|
|Needless to say, ...|    0|
|What a waste of m...|    0|
|And the sound qua...|    1|
|He was very impre...|    1|
|If the two were s...|    0|
|Very good quality...|    1|
|The design is ver...|    0|
|Highly recommend ...|    1|
|I advise EVERYONE...|    0|
|    So Far So Good!.|    1|
|       Works great!.|    1|
|It clicks into pl...|    0|
+--------------------+-----+
only showing top 20 rows

// Let's do some row counts.

reviewsDF.createOrReplaceTempView("reviews")

spark.sql("select label,count(*) from reviews group by label").show

+-----+--------+
|label|count(1)|
+-----+--------+
|    1|     500|
|    0|     500|
+-----+--------+

// Randomly divide the dataset into training and test datasets.

val seed = 1234

val Array(trainingData, testData) = reviewsDF.randomSplit(Array(0.8, 0.2), seed)

trainingData.count
res5: Long = 827

testData.count
res6: Long = 173

// Split the sentences into words.

import org.apache.spark.ml.feature.Tokenizer

val tokenizer = new Tokenizer().setInputCol("text")
                .setOutputCol("words")

// Check the tokenized data.

val tokenizedDF = tokenizer.transform(trainingData)

tokenizedDF.show

+--------------------+-----+--------------------+
|                text|label|               words|
+--------------------+-----+--------------------+
|         (It works!)|    1|      [(it, works!)]|
|)Setup couldn't h...|    1|[)setup, couldn't...|
|* Comes with a st...|    1|[*, comes, with, ...|
|.... Item arrived...|    1|[...., item, arri...|
|1\. long lasting b...|    0|[1., long, lastin...|
|2 thumbs up to th...|    1|[2, thumbs, up, t...|
|:-)Oh, the charge...|    1|[:-)oh,, the, cha...|
|   A Disappointment.|    0|[a, disappointment.]|
|A PIECE OF JUNK T...|    0|[a, piece, of, ju...|
|A good quality ba...|    1|[a, good, quality...|
|A must study for ...|    0|[a, must, study, ...|
|A pretty good pro...|    1|[a, pretty, good,...|
|A usable keyboard...|    1|[a, usable, keybo...|
|A week later afte...|    0|[a, week, later, ...|
|AFTER ARGUING WIT...|    0|[after, arguing, ...|
|AFter the first c...|    0|[after, the, firs...|
|       AMAZON SUCKS.|    0|    [amazon, sucks.]|
|     Absolutel junk.|    0|  [absolutel, junk.]|
|   Absolutely great.|    1|[absolutely, great.]|
|Adapter does not ...|    0|[adapter, does, n...|
+--------------------+-----+--------------------+
only showing top 20 rows

// Next, we'll use HashingTF to convert the tokenized words
// into fixed-length feature vector.

import org.apache.spark.ml.feature.HashingTF

val htf = new HashingTF().setNumFeatures(1000)
          .setInputCol("words")
.setOutputCol("features")

// Check the vectorized features.

val hashedDF = htf.transform(tokenizedDF)

hashedDF.show

+--------------------+-----+--------------------+--------------------+
|                text|label|               words|            features|
+--------------------+-----+--------------------+--------------------+
|         (It works!)|    1|      [(it, works!)]|(1000,[369,504],[...|
|)Setup couldn't h...|    1|[)setup, couldn't...|(1000,[299,520,53...|
|* Comes with a st...|    1|[*, comes, with, ...|(1000,[34,51,67,1...|
|.... Item arrived...|    1|[...., item, arri...|(1000,[98,133,245...|
|1\. long lasting b...|    0|[1., long, lastin...|(1000,[138,258,29...|
|2 thumbs up to th...|    1|[2, thumbs, up, t...|(1000,[92,128,373...|
|:-)Oh, the charge...|    1|[:-)oh,, the, cha...|(1000,[388,497,52...|
|   A Disappointment.|    0|[a, disappointment.]|(1000,[170,386],[...|
|A PIECE OF JUNK T...|    0|[a, piece, of, ju...|(1000,[34,36,47,7...|
|A good quality ba...|    1|[a, good, quality...|(1000,[77,82,168,...|
|A must study for ...|    0|[a, must, study, ...|(1000,[23,36,104,...|
|A pretty good pro...|    1|[a, pretty, good,...|(1000,[168,170,27...|
|A usable keyboard...|    1|[a, usable, keybo...|(1000,[2,116,170,...|
|A week later afte...|    0|[a, week, later, ...|(1000,[77,122,156...|
|AFTER ARGUING WIT...|    0|[after, arguing, ...|(1000,[77,166,202...|
|AFter the first c...|    0|[after, the, firs...|(1000,[63,77,183,...|
|       AMAZON SUCKS.|    0|    [amazon, sucks.]|(1000,[828,966],[...|
|     Absolutel junk.|    0|  [absolutel, junk.]|(1000,[607,888],[...|
|   Absolutely great.|    1|[absolutely, great.]|(1000,[589,903],...|
|Adapter does not ...|    0|[adapter, does, n...|(1000,[0,18,51,28...|
+--------------------+-----+--------------------+--------------------+
only showing top 20 rows

// We will use the naïve Bayes classifier provided by MLlib.

import org.apache.spark.ml.classification.NaiveBayes

val nb = new NaiveBayes()

// We now have all the parts that we need to assemble
// a machine learning pipeline.

import org.apache.spark.ml.Pipeline

val pipeline = new Pipeline().setStages(Array(tokenizer, htf, nb))

// Train our model using the training dataset.

val model = pipeline.fit(trainingData)

// Predict using the test dataset.

val predictions = model.transform(testData)

// Display the predictions for each review.

predictions.select("text","prediction").show

+--------------------+----------+
|                text|prediction|
+--------------------+----------+
|!I definitely reco...|       1.0|
|#1 It Works - #2 ...|       1.0|
| $50 Down the drain.|       0.0|
|A lot of websites...|       1.0|
|After charging ov...|       0.0|
|After my phone go...|       0.0|
|All in all I thin...|       1.0|
|All it took was o...|       0.0|
|Also, if your pho...|       0.0|
|And I just love t...|       1.0|
|And none of the t...|       1.0|
|         Bad Choice.|       0.0|
|Best headset ever...|       1.0|
|Big Disappointmen...|       0.0|
|Bluetooth range i...|       0.0|
|But despite these...|       0.0|
|Buyer--Be Very Ca...|       1.0|
|Can't store anyth...|       0.0|
|Chinese Forgeries...|       0.0|
|Do NOT buy if you...|       0.0|
+--------------------+----------+

only showing top 20 rows

// Evaluate our model using a binary classifier evaluator.

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val evaluator = new BinaryClassificationEvaluator()

import org.apache.spark.ml.param.ParamMap

val paramMap = ParamMap(evaluator.metricName -> "areaUnderROC")

val auc = evaluator.evaluate(predictions, paramMap)

auc: Double = 0.5407085561497325

// Test on a positive example.

val predictions = model
.transform(sc.parallelize(Seq("This product is good")).toDF("text"))

predictions.select("text","prediction").show

+--------------------+----------+
|                text|prediction|
+--------------------+----------+
|This product is good|       1.0|
+--------------------+----------+

// Test on a negative example.

val predictions = model
.transform(sc.parallelize(Seq("This product is bad")).toDF("text"))

predictions.select("text","prediction").show

+-------------------+----------+
|               text|prediction|
+-------------------+----------+
|This product is bad|       0.0|
+-------------------+----------+

Listing 3-8Sentiment Analysis Using Naïve Bayes

可以做几件事来改进我们的模型。在大多数自然语言处理(NLP)任务中,执行额外的文本预处理(如 n 元语法、词汇化和去除停用词)是很常见的。我在第四章中介绍了斯坦福 CoreNLP 和 Spark NLP。

回归

回归是一种用于预测连续数值的监督机器学习任务。举几个例子来说,流行的用例包括销售和需求预测、预测股票、房屋或商品价格以及天气预报。我在第 [1 章更详细地讨论了回归。

简单线性回归

线性回归用于检查一个或多个自变量和因变量之间的线性关系。对单个自变量和单个连续因变量之间关系的分析称为简单线性回归。

正如您在图 3-9 中所看到的,该图显示了线性攻击试图绘制一条直线,以最好地减少观察到的响应和预测值之间的残差平方和。XXXV

img/488426_1_En_3_Fig9_HTML.png

图 3-9

简单的线性回归图

例子

对于我们的例子,我们将使用简单的线性回归来显示房价(因变量)如何根据该地区的平均家庭收入(自变量)而变化。清单 3-9 详细列出了代码。

import org.apache.spark.ml.regression.LinearRegression
import spark.implicits._

val dataDF = Seq(
 (50000, 302200),
 (75200, 550000),
 (90000, 680000),
 (32800, 225000),
 (41000, 275000),
 (54000, 300500),
 (72000, 525000),
 (105000, 700000),
 (88500, 673100),
 (92000, 695000),
 (53000, 320900),
 (85200, 652800),
 (157000, 890000),
 (128000, 735000),
 (71500, 523000),
 (114000, 720300),
 (33400, 265900),
 (143000, 846000),
 (68700, 492000),
 (46100, 285000)
).toDF("avg_area_income","price")

dataDF.show
+---------------+------+
|avg_area_income| price|
+---------------+------+
|          50000|302200|
|          75200|550000|
|          90000|680000|
|          32800|225000|
|          41000|275000|
|          54000|300500|
|          72000|525000|
|         105000|700000|
|          88500|673100|
|          92000|695000|
|          53000|320900|
|          85200|652800|
|         157000|890000|
|         128000|735000|
|          71500|523000|
|         114000|720300|
|          33400|265900|
|         143000|846000|
|          68700|492000|
|          46100|285000|
+---------------+------+

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
                .setInputCols(Array("avg_area_income"))
                .setOutputCol("feature")

val dataDF2 = assembler.transform(dataDF)

dataDF2.show
+---------------+------+----------+
|avg_area_income| price|   feature|
+---------------+------+----------+
|          50000|302200| [50000.0]|
|          75200|550000| [75200.0]|
|          90000|680000| [90000.0]|
|          32800|225000| [32800.0]|
|          41000|275000| [41000.0]|
|          54000|300500| [54000.0]|
|          72000|525000| [72000.0]|
|         105000|700000|[105000.0]|
|          88500|673100| [88500.0]|
|          92000|695000| [92000.0]|
|          53000|320900| [53000.0]|
|          85200|652800| [85200.0]|
|         157000|890000|[157000.0]|
|         128000|735000|[128000.0]|
|          71500|523000| [71500.0]|
|         114000|720300|[114000.0]|
|          33400|265900| [33400.0]|
|         143000|846000|[143000.0]|
|          68700|492000| [68700.0]|
|          46100|285000| [46100.0]|
+---------------+------+----------+

val lr = new LinearRegression()
         .setMaxIter(10)
         .setFeaturesCol("feature")
         .setLabelCol("price")

val model = lr.fit(dataDF2)

import org.apache.spark.ml.linalg.Vectors

val testData = spark
         .createDataFrame(Seq(Vectors.dense(75000))
         .map(Tuple1.apply))
         .toDF("feature")

val predictions = model.transform(testData)

predictions.show
+---------+------------------+
|  feature|        prediction|
+---------+------------------+
|[75000.0]|504090.35842779215|
+---------+------------------+

Listing 3-9Linear Regression Example

使用 XGBoost4J-Spark 进行多元回归

多元回归用于有两个或更多自变量和一个连续因变量的更现实的情况。在现实世界的用例中,同时具有线性和非线性特性是很常见的。XGBoost 等基于树的集成算法能够处理线性和非线性特征,这使其成为大多数生产环境的理想选择。在大多数情况下,使用基于树的集成(如 XGBoost)进行多元回归应该会显著提高预测精度。XXXVI

我们在本章前面使用 XGBoost 解决了一个分类问题。因为 XGBoost 同时支持分类和回归,所以使用 XGBoost 进行回归与分类非常相似。

例子

对于我们的多元回归示例,我们将使用稍微复杂一点的数据集,如清单 3-10 所示。数据集可以从 Kaggle 下载。 xxxvii 我们的目标是根据数据集中提供的属性预测房价。数据集包含七列:Avg。地区收入,平均。面积房屋年龄,平均。房间面积,平均。卧室面积数量,面积人口,价格和地址。为了简单起见,我们将不使用地址字段(有用的信息可以从家庭地址中获得,例如附近学校的位置)。价格是我们的因变量。

spark-shell --packages ml.dmlc:xgboost4j-spark:0.81

import org.apache.spark.sql.types._

// Define a schema for our dataset.

var pricesSchema = StructType(Array (
    StructField("avg_area_income",   DoubleType, true),
    StructField("avg_area_house_age",   DoubleType, true),
    StructField("avg_area_num_rooms",   DoubleType, true),
    StructField("avg_area_num_bedrooms",   DoubleType, true),
    StructField("area_population",   DoubleType, true),
    StructField("price",   DoubleType, true)
    ))

val dataDF = spark.read.format("csv")
             .option("header","true")
             .schema(pricesSchema)
             .load("USA_Housing.csv").na.drop()

// Inspect the dataset.

dataDF.printSchema
root
 |-- avg_area_income: double (nullable = true)
 |-- avg_area_house_age: double (nullable = true)
 |-- avg_area_num_rooms: double (nullable = true)
 |-- avg_area_num_bedrooms: double (nullable = true)
 |-- area_population: double (nullable = true)
 |-- price: double (nullable = true)

dataDF.select("avg_area_income","avg_area_house_age","avg_area_num_rooms").show

+------------------+------------------+------------------+
|   avg_area_income|avg_area_house_age|avg_area_num_rooms|
+------------------+------------------+------------------+
| 79545.45857431678| 5.682861321615587| 7.009188142792237|
| 79248.64245482568|6.0028998082752425| 6.730821019094919|
|61287.067178656784| 5.865889840310001| 8.512727430375099|
| 63345.24004622798|7.1882360945186425| 5.586728664827653|
|59982.197225708034| 5.040554523106283| 7.839387785120487|
|  80175.7541594853|4.9884077575337145| 6.104512439428879|
| 64698.46342788773| 6.025335906887153| 8.147759585023431|
| 78394.33927753085|6.9897797477182815| 6.620477995185026|
| 59927.66081334963|  5.36212556960358|6.3931209805509015|
| 81885.92718409566| 4.423671789897876| 8.167688003472351|
| 80527.47208292288|  8.09351268063935| 5.042746799645982|
| 50593.69549704281| 4.496512793097035| 7.467627404008019|
|39033.809236982364| 7.671755372854428| 7.250029317273495|
|  73163.6634410467| 6.919534825456555|5.9931879009455695|
|  69391.3801843616| 5.344776176735725| 8.406417714534253|
| 73091.86674582321| 5.443156466535474| 8.517512711137975|
| 79706.96305765743| 5.067889591058972| 8.219771123286257|
| 61929.07701808926| 4.788550241805888|5.0970095543775615|
| 63508.19429942997| 5.947165139552473| 7.187773835329727|
| 62085.27640340488| 5.739410843630574|  7.09180810424997|
+------------------+------------------+------------------+
only showing top 20 rows

dataDF.select("avg_area_num_bedrooms","area_population","price").show

+---------------------+------------------+------------------+
|avg_area_num_bedrooms|   area_population|             price|
+---------------------+------------------+------------------+
|                 4.09|23086.800502686456|1059033.5578701235|
|                 3.09| 40173.07217364482|  1505890.91484695|
|                 5.13| 36882.15939970458|1058987.9878760849|
|                 3.26| 34310.24283090706|1260616.8066294468|
|                 4.23|26354.109472103148| 630943.4893385402|
|                 4.04|26748.428424689715|1068138.0743935304|
|                 3.41| 60828.24908540716|1502055.8173744078|
|                 2.42|36516.358972493836|1573936.5644777215|
|                  2.3| 29387.39600281585| 798869.5328331633|
|                  6.1| 40149.96574921337|1545154.8126419624|
|                  4.1| 47224.35984022191| 1707045.722158058|
|                 4.49|34343.991885578806| 663732.3968963273|
|                  3.1| 39220.36146737246|1042814.0978200927|
|                 2.27|32326.123139488096|1291331.5184858206|
|                 4.37|35521.294033173246|1402818.2101658515|
|                 4.01|23929.524053267953|1306674.6599511993|
|                 3.12| 39717.81357630952|1556786.6001947748|
|                  4.3| 24595.90149782299| 528485.2467305964|
|                 5.12|35719.653052030866|1019425.9367578316|
|                 5.49|44922.106702293066|1030591.4292116085|
+---------------------+------------------+------------------+
only showing top 20 rows

val features = Array("avg_area_income","avg_area_house_age",
"avg_area_num_rooms","avg_area_num_bedrooms","area_population")

// Combine our features into a single feature vector.

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")

val dataDF2 = assembler.transform(dataDF)

dataDF2.select("price","features").show(20,50)

+------------------+--------------------------------------------------+
|             price|                                          features|
+------------------+--------------------------------------------------+
|1059033.5578701235|[79545.45857431678,5.682861321615587,7.00918814...|
|  1505890.91484695|[79248.64245482568,6.0028998082752425,6.7308210...|
|1058987.9878760849|[61287.067178656784,5.865889840310001,8.5127274...|
|1260616.8066294468|[63345.24004622798,7.1882360945186425,5.5867286...|
| 630943.4893385402|[59982.197225708034,5.040554523106283,7.8393877...|
|1068138.0743935304|[80175.7541594853,4.9884077575337145,6.10451243...|
|1502055.8173744078|[64698.46342788773,6.025335906887153,8.14775958...|
|1573936.5644777215|[78394.33927753085,6.9897797477182815,6.6204779...|
| 798869.5328331633|[59927.66081334963,5.36212556960358,6.393120980...|
|1545154.8126419624|[81885.92718409566,4.423671789897876,8.16768800...|
| 1707045.722158058|[80527.47208292288,8.09351268063935,5.042746799...|
| 663732.3968963273|[50593.69549704281,4.496512793097035,7.46762740...|
|1042814.0978200927|[39033.809236982364,7.671755372854428,7.2500293...|
|1291331.5184858206|[73163.6634410467,6.919534825456555,5.993187900...|
|1402818.2101658515|[69391.3801843616,5.344776176735725,8.406417714...|
|1306674.6599511993|[73091.86674582321,5.443156466535474,8.51751271...|
|1556786.6001947748|[79706.96305765743,5.067889591058972,8.21977112...|
| 528485.2467305964|[61929.07701808926,4.788550241805888,5.09700955...|
|1019425.9367578316|[63508.19429942997,5.947165139552473,7.18777383...|
|1030591.4292116085|[62085.27640340488,5.739410843630574,7.09180810...|
+------------------+--------------------------------------------------+
only showing top 20 rows

// Divide our dataset into training and test data.

val seed = 1234

val Array(trainingData, testData) = dataDF2.randomSplit(Array(0.8, 0.2), seed)

// Use XGBoost for regression.

import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel,XGBoostRegressor}

val xgb = new XGBoostRegressor()
          .setFeaturesCol("features")
          .setLabelCol("price")

// Create a parameter grid.

import org.apache.spark.ml.tuning.ParamGridBuilder

val paramGrid = new ParamGridBuilder()
                .addGrid(xgb.maxDepth, Array(6, 9))
                .addGrid(xgb.eta, Array(0.3, 0.7)).build()

paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
Array({
      xgbr_bacf108db722-eta: 0.3,
      xgbr_bacf108db722-maxDepth: 6
}, {
      xgbr_bacf108db722-eta: 0.3,
      xgbr_bacf108db722-maxDepth: 9
}, {
      xgbr_bacf108db722-eta: 0.7,
      xgbr_bacf108db722-maxDepth: 6
}, {
      xgbr_bacf108db722-eta: 0.7,
      xgbr_bacf108db722-maxDepth: 9
})

// Create our evaluator.

import org.apache.spark.ml.evaluation.RegressionEvaluator

val evaluator = new RegressionEvaluator()
               .setLabelCol("price")
               .setPredictionCol("prediction")
               .setMetricName("rmse")

// Create our cross-validator.

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
         .setEstimator(xgb)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)

val model = cv.fit(trainingData)

val predictions = model.transform(testData)

predictions.select("features","price","prediction").show

+--------------------+------------------+------------+
|            features|             price|  prediction|
+--------------------+------------------+------------+
|17796.6311895433...|302355.83597895555| 591896.9375|
|[35454.7146594754...| 1077805.577726322|   440094.75|
|[35608.9862370775...| 449331.5835333807|   672114.75|
|[38868.2503114142...| 759044.6879907805|   672114.75|
|[40752.7142433209...| 560598.5384309639| 591896.9375|
|[41007.4586732745...| 494742.5435776913|421605.28125|
|[41533.0129597444...| 682200.3005599922|505685.96875|
|[42258.7745410484...| 852703.2636757497| 591896.9375|
|[42940.1389392421...| 680418.7240122693| 591896.9375|
|[43192.1144092488...|1054606.9845532854|505685.96875|
|[43241.9824225005...| 629657.6132544072|505685.96875|
|[44328.2562966742...| 601007.3511604669|141361.53125|
|[45347.1506816944...| 541953.9056802422|441908.40625|
|[45546.6434075757...|   923830.33486809| 591896.9375|
|[45610.9384142094...|  961354.287727855|   849175.75|
|[45685.2499205068...| 867714.3838490517|441908.40625|
|[45990.1237417814...|1043968.3994445396|   849175.75|
|[46062.7542664558...| 675919.6815570832|505685.96875|
|[46367.2058588838...|268050.81474351394|  379889.625|
|[47467.4239151893...| 762144.9261238109| 591896.9375|
+--------------------+------------------+------------+
only showing top 20 rows

Listing 3-10Multiple Regression Using XGBoost4J-Spark

让我们用均方根误差(RMSE)来评估这个模型。残差是数据点与回归线之间距离的度量。RMSE 是残差的标准差,用于测量预测误差。[xxxviii

val rmse = evaluator.evaluate(predictions)

rmse: Double = 438499.82356536255

// Extract the parameters.

model.bestModel.extractParamMap

res11: org.apache.spark.ml.param.ParamMap =
{
      xgbr_8da6032c61a9-alpha: 0.0,
      xgbr_8da6032c61a9-baseScore: 0.5,
      xgbr_8da6032c61a9-checkpointInterval: -1,
      xgbr_8da6032c61a9-checkpointPath: ,
      xgbr_8da6032c61a9-colsampleBylevel: 1.0,
      xgbr_8da6032c61a9-colsampleBytree: 1.0,
      xgbr_8da6032c61a9-customEval: null,
      xgbr_8da6032c61a9-customObj: null,
      xgbr_8da6032c61a9-eta: 0.7,
      xgbr_8da6032c61a9-evalMetric: rmse,
      xgbr_8da6032c61a9-featuresCol: features,
      xgbr_8da6032c61a9-gamma: 0.0,
      xgbr_8da6032c61a9-growPolicy: depthwise,

      xgbr_8da6032c61a9-labelCol: price,
      xgbr_8da6032c61a9-lambda: 1.0,
      xgbr_8da6032c61a9-lambdaBias: 0.0,
      xgbr_8da6032c61a9-maxBin: 16,
      xgbr_8da6032c61a9-maxDeltaStep: 0.0,
      xgbr_8da6032c61a9-maxDepth: 9,
      xgbr_8da6032c61a9-minChildWeight: 1.0,
      xgbr_8da6032c61a9-missing: NaN,
      xgbr_8da6032c61a9-normalizeType: tree,
      xgbr_8da6032c61a9-nthread: 1,
      xgbr_8da6032c61a9-numEarlyStoppingRounds: 0,
      xgbr_8da6032c61a9-numRound: 1,
      xgbr_8da6032c61a9-numWorkers: 1,
      xgbr_8da6032c61a9-objective: reg:linear,
      xgbr_8da6032c61a9-predictionCol: prediction,
      xgbr_8da6032c61a9-rateDrop: 0.0,
      xgbr_8da6032c61a9-sampleType: uniform,
      xgbr_8da6032c61a9-scalePosWeight: 1.0,
      xgbr_8da6032c61a9-seed: 0,
      xgbr_8da6032c61a9-silent: 0,
      xgbr_8da6032c61a9-sketchEps: 0.03,
      xgbr_8da6032c61a9-skipDrop: 0.0,
      xgbr_8da6032c61a9-subsample: 1.0,
      xgbr_8da6032c61a9-timeoutRequestWorkers: 1800000,
      xgbr_8da6032c61a9-trackerConf: TrackerConf(0,python),
      xgbr_8da6032c61a9-trainTestRatio: 1.0,
      xgbr_8da6032c61a9-treeLimit: 0,
      xgbr_8da6032c61a9-treeMethod: auto,
      xgbr_8da6032c61a9-useExternalMemory: false

}

LightGBM 多元回归

在清单 3-11 中,我们将使用 LightGBM。LightGBM 附带了专门用于回归任务的 LightGBMRegressor 类。我们将重用住房数据集和上一个 XGBoost 示例中的大部分代码。

spark-shell --packages Azure:mmlspark:0.15

var pricesSchema = StructType(Array (
    StructField("avg_area_income",   DoubleType, true),
    StructField("avg_area_house_age",   DoubleType, true),
    StructField("avg_area_num_rooms",   DoubleType, true),
    StructField("avg_area_num_bedrooms",   DoubleType, true),
    StructField("area_population",   DoubleType, true),
    StructField("price",   DoubleType, true)
    ))

val dataDF = spark.read.format("csv")
             .option("header","true")
             .schema(pricesSchema)
             .load("USA_Housing.csv")
             .na.drop()

dataDF.printSchema

root
 |-- avg_area_income: double (nullable = true)
 |-- avg_area_house_age: double (nullable = true)
 |-- avg_area_num_rooms: double (nullable = true)
 |-- avg_area_num_bedrooms: double (nullable = true)
 |-- area_population: double (nullable = true)
 |-- price: double (nullable = true)

dataDF.select("avg_area_income","avg_area_house_age",
"avg_area_num_rooms")
.show

+------------------+------------------+------------------+
|   avg_area_income|avg_area_house_age|avg_area_num_rooms|
+------------------+------------------+------------------+
| 79545.45857431678| 5.682861321615587| 7.009188142792237|
| 79248.64245482568|6.0028998082752425| 6.730821019094919|
|61287.067178656784| 5.865889840310001| 8.512727430375099|
| 63345.24004622798|7.1882360945186425| 5.586728664827653|
|59982.197225708034| 5.040554523106283| 7.839387785120487|
|  80175.7541594853|4.9884077575337145| 6.104512439428879|
| 64698.46342788773| 6.025335906887153| 8.147759585023431|
| 78394.33927753085|6.9897797477182815| 6.620477995185026|
| 59927.66081334963|  5.36212556960358|6.3931209805509015|
| 81885.92718409566| 4.423671789897876| 8.167688003472351|
| 80527.47208292288|  8.09351268063935| 5.042746799645982|
| 50593.69549704281| 4.496512793097035| 7.467627404008019|
|39033.809236982364| 7.671755372854428| 7.250029317273495|
|  73163.6634410467| 6.919534825456555|5.9931879009455695|
|  69391.3801843616| 5.344776176735725| 8.406417714534253|
| 73091.86674582321| 5.443156466535474| 8.517512711137975|
| 79706.96305765743| 5.067889591058972| 8.219771123286257|
| 61929.07701808926| 4.788550241805888|5.0970095543775615|
| 63508.19429942997| 5.947165139552473| 7.187773835329727|
| 62085.27640340488| 5.739410843630574|  7.09180810424997|
+------------------+------------------+------------------+

dataDF.select("avg_area_num_bedrooms","area_population","price").show

+---------------------+------------------+------------------+
|avg_area_num_bedrooms|   area_population|             price|
+---------------------+------------------+------------------+
|                 4.09|23086.800502686456|1059033.5578701235|
|                 3.09| 40173.07217364482|  1505890.91484695|
|                 5.13| 36882.15939970458|1058987.9878760849|
|                 3.26| 34310.24283090706|1260616.8066294468|
|                 4.23|26354.109472103148| 630943.4893385402|
|                 4.04|26748.428424689715|1068138.0743935304|
|                 3.41| 60828.24908540716|1502055.8173744078|
|                 2.42|36516.358972493836|1573936.5644777215|
|                  2.3| 29387.39600281585| 798869.5328331633|
|                  6.1| 40149.96574921337|1545154.8126419624|
|                  4.1| 47224.35984022191| 1707045.722158058|
|                 4.49|34343.991885578806| 663732.3968963273|
|                  3.1| 39220.36146737246|1042814.0978200927|
|                 2.27|32326.123139488096|1291331.5184858206|
|                 4.37|35521.294033173246|1402818.2101658515|
|                 4.01|23929.524053267953|1306674.6599511993|
|                 3.12| 39717.81357630952|1556786.6001947748|
|                  4.3| 24595.90149782299| 528485.2467305964|
|                 5.12|35719.653052030866|1019425.9367578316|
|                 5.49|44922.106702293066|1030591.4292116085|
+---------------------+------------------+------------------+
only showing top 20 rows

val features = Array("avg_area_income","avg_area_house_age",
"avg_area_num_rooms","avg_area_num_bedrooms","area_population")

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
                .setInputCols(features)
                .setOutputCol("features")

val dataDF2 = assembler.transform(dataDF)

dataDF2.select("price","features").show(20,50)

+------------------+--------------------------------------------------+
|             price|                                          features|
+------------------+--------------------------------------------------+
|1059033.5578701235|[79545.45857431678,5.682861321615587,7.00918814...|
|  1505890.91484695|[79248.64245482568,6.0028998082752425,6.7308210...|
|1058987.9878760849|[61287.067178656784,5.865889840310001,8.5127274...|
|1260616.8066294468|[63345.24004622798,7.1882360945186425,5.5867286...|
| 630943.4893385402|[59982.197225708034,5.040554523106283,7.8393877...|
|1068138.0743935304|[80175.7541594853,4.9884077575337145,6.10451243...|
|1502055.8173744078|[64698.46342788773,6.025335906887153,8.14775958...|
|1573936.5644777215|[78394.33927753085,6.9897797477182815,6.6204779...|
| 798869.5328331633|[59927.66081334963,5.36212556960358,6.393120980...|
|1545154.8126419624|[81885.92718409566,4.423671789897876,8.16768800...|
| 1707045.722158058|[80527.47208292288,8.09351268063935,5.042746799...|
| 663732.3968963273|[50593.69549704281,4.496512793097035,7.46762740...|
|1042814.0978200927|[39033.809236982364,7.671755372854428,7.2500293...|
|1291331.5184858206|[73163.6634410467,6.919534825456555,5.993187900...|
|1402818.2101658515|[69391.3801843616,5.344776176735725,8.406417714...|
|1306674.6599511993|[73091.86674582321,5.443156466535474,8.51751271...|
|1556786.6001947748|[79706.96305765743,5.067889591058972,8.21977112...|
| 528485.2467305964|[61929.07701808926,4.788550241805888,5.09700955...|
|1019425.9367578316|[63508.19429942997,5.947165139552473,7.18777383...|
|1030591.4292116085|[62085.27640340488,5.739410843630574,7.09180810...|
+------------------+--------------------------------------------------+
only showing top 20 rows

val seed = 1234

val Array(trainingData, testData) = dataDF2.randomSplit(Array(0.8, 0.2), seed)

import com.microsoft.ml.spark.{LightGBMRegressionModel,LightGBMRegressor}

val lightgbm = new LightGBMRegressor()
               .setFeaturesCol("features")
               .setLabelCol("price")
               .setObjective("regression")

import org.apache.spark.ml.tuning.ParamGridBuilder

val paramGrid = new ParamGridBuilder()
                .addGrid(lightgbm.numLeaves, Array(6, 9))
                .addGrid(lightgbm.numIterations, Array(10, 15))
                .addGrid(lightgbm.maxDepth, Array(2, 3, 4))
                .build()

paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
Array({
        LightGBMRegressor_f969f7c475b5-maxDepth: 2,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 6
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 3,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 6
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 4,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 6
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 2,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 9
}, {
        LightGBMRegressor_f969f7c475b5-maxDepth: 3,
        LightGBMRegressor_f969f7c475b5-numIterations: 10,
        LightGBMRegressor_f969f7c475b5-numLeaves: 9
}, {
        Lig...

import org.apache.spark.ml.evaluation.RegressionEvaluator

val evaluator = new RegressionEvaluator()
                .setLabelCol("price")
                .setPredictionCol("prediction")
                .setMetricName("rmse")

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
         .setEstimator(lightgbm)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(paramGrid)
         .setNumFolds(3)

val model = cv.fit(trainingData)

val predictions = model.transform(testData)

predictions.select("features","price","prediction").show

+--------------------+------------------+------------------+
|            features|             price|        prediction|
+--------------------+------------------+------------------+
|[17796.6311895433...|302355.83597895555| 965317.3181705693|
|[35454.7146594754...| 1077805.577726322|1093159.8506664087|
|[35608.9862370775...| 449331.5835333807|1061505.7131801855|
|[38868.2503114142...| 759044.6879907805|1061505.7131801855|
|[40752.7142433209...| 560598.5384309639| 974582.8481703462|
|[41007.4586732745...| 494742.5435776913| 881891.5646432829|
|[41533.0129597444...| 682200.3005599922| 966417.0064436384|
|[42258.7745410484...| 852703.2636757497|1070641.7611960804|
|[42940.1389392421...| 680418.7240122693|1028986.6314725328|
|[43192.1144092488...|1054606.9845532854|1087808.2361520242|
|[43241.9824225005...| 629657.6132544072| 889012.3734817103|
|[44328.2562966742...| 601007.3511604669| 828175.3829271109|
|[45347.1506816944...| 541953.9056802422| 860754.7467075661|
|[45546.6434075757...|   923830.33486809| 950407.7970842035|
|[45610.9384142094...|  961354.287727855|1175429.1179985087|
|[45685.2499205068...| 867714.3838490517|  828812.007346283|
|[45990.1237417814...|1043968.3994445396|1204501.1530193759|
|[46062.7542664558...| 675919.6815570832| 973273.6042265462|
|[46367.2058588838...|268050.81474351394| 761576.9192149616|
|[47467.4239151893...| 762144.9261238109| 951908.0117790927|
+--------------------+------------------+------------------+
only showing top 20 rows

val rmse = evaluator.evaluate(predictions)

rmse: Double = 198601.74726198777

Listing 3-11Multiple Regression Using LightGBM

让我们提取每个特性的特性重要性分数。

val model = lightgbm.fit(trainingData)

model.getFeatureImportances("gain")
res7: Array[Double] = Array(1.110789482705408E15, 5.69355224816896E14, 3.25231517467648E14, 1.16104381056E13, 4.84685311277056E14)

通过匹配列表中输出的顺序和特征向量中特征的顺序(avg_area_income,avg_area_house_age,avg_area_num_rooms,avg _ area _ num _ hydro ses,area_population),看起来 avg_area_income 是我们最重要的特征,其次是 avg_area_house_age,area_population 和 avg_area_num_rooms。最不重要的特征是 avg _ area _ num _ bedrooms。

摘要

我讨论了 Spark MLlib 中包含的一些最流行的监督学习算法,以及外部可用的新算法,如 XGBoost 和 LightGBM。虽然网上有大量关于 XGBoost 和 LightGBM for Python 的文档,但是关于 Spark 的信息和示例却很有限。本章旨在帮助弥合这一差距。

我建议您参考 https://xgboost.readthedocs.io/en/latest 来了解更多关于 XGBoost 的信息。对于 LightGBM, https://lightgbm.readthedocs.io/en/latest 有最新信息。有关 Spark MLlib 中包含的分类和回归算法背后的理论和数学的更深入的报道,我建议您参考 Gareth James、Daniela Witten、Trevor Hastie 和 Robert Tibshirani (Springer,2017 年)的统计学习介绍以及 Trevor Hastie、Robert Tibshirani 和 Jerome Friedman (Springer,2016 年)的统计学习要素。关于 Spark MLlib 的更多信息,请在线咨询 Apache Spark 的机器学习库(MLlib)指南*:https://spark.apache.org/docs/latest/ml-guide.html。*

参考

  1. 朱迪亚珍珠;“E PUR SI MUOVE(但它会移动),”2018,原因之书:因果的新科学

  2. 阿帕奇 Spark《多项逻辑回归》,spark.apache.org,2019, https://spark.apache.org/docs/latest/ml-classification-regression.html#multinomial-logistic-regression

  3. 圣乔治·德拉克斯;《支持向量机 vs 逻辑回归》,towardsdatascience.com,2018, https://towardsdatascience.com/support-vector-machine-vs-logistic-regression-94cc2975433f

  4. 阿帕奇 Spark《多层感知器分类器》,spark.apache.org,2019, https://spark.apache.org/docs/latest/ml-classification-regression.html#multilayer-perceptron-classifier

  5. 分析 Vidhya 内容团队;《从零开始的基于树的建模完整教程(用 R & Python 编写)》,AnalyticsVidhya.com,2016, www.analyticsvidhya.com/blog/2016/04/complete-tutorial-tree-based-modeling-scratch-in-python/#one

  6. LightGBM《分类特征的最优分割》,lightgbm.readthedocs.io,2019, https://lightgbm.readthedocs.io/en/latest/Features.html

  7. 约瑟夫·布拉德利和马尼什·阿姆德;《MLlib 中的随机森林与助推》,Databricks,2015, https://databricks.com/blog/2015/01/21/random-forests-and-boosting-in-mllib.html

  8. 分析 Vidhya 内容团队;《理解 XGBoost 背后数学的端到端指南》,analyticsvidhya.com,2018, www.analyticsvidhya.com/blog/2018/09/an-end-to-end-guide-to-understand-the-math-behind-xgboost/

  9. 本·戈尔曼;“一位 Kaggle 大师解释渐变增强,”Kaggle.com,2017, http://blog.kaggle.com/2017/01/23/a-kaggle-master-explains-gradient-boosting/

  10. XGBoost《助推树木概论》xgboost.readthedocs.io,2019, https://xgboost.readthedocs.io/en/latest/tutorials/model.html

  11. 阿帕奇 Spark“集合——基于 RDD 的 API”,spark.apache.org,2019, https://spark.apache.org/docs/latest/mllib-ensembles.html#gradient-boosted-trees-gbts

  12. 陈天琦;“什么时候可以在梯度提升机器(GBM)上使用随机森林?,“quora.com,2015, www.quora.com/When-would-one-use-Random-Forests-over-Gradient-Boosted-Machines-GBMs

  13. 艾登·奥布莱恩等人。艾尔。;“基因组变体的 VariantSpark 机器学习”,CSIRO,2018, https://bioinformatics.csiro.au/variantspark

  14. 丹尼斯·c·鲍尔等人。艾尔。;“利用广泛的随机森林打破基因组学中的维数灾难”,Databricks,2017, https://databricks.com/blog/2017/07/26/breaking-the-curse-of-dimensionality-in-genomics-using-wide-random-forests.html

  15. 亚历山大·露露利。艾尔!艾尔!;“重构”,github.com,2017 年,

  16. 重新造林;“如何用 ReForeSt 学习随机森林分类模型”,sites.google.com,2019, https://sites.google.com/view/reforest/example?authuser=0

  17. 阿帕奇 Sparkspark.apache.org,2019 年, https://spark.apache.org/docs/latest/mllib-ensembles.html#random-forests

  18. CallMiner“新的研究发现不重视客户导致 1360 亿美元的转换流行病,”CallMiner,2018, www.globenewswire.com/news-release/2018/09/27/1577343/0/en/New-research-finds-not-valuing-customers-leads-to-136-billion-switching-epidemic.html

  19. 红色 Reichheld《削减成本的药方》,贝恩公司,2016 年, www2.bain.cimg/BB_Prescription_cutting_costs.pdf

  20. 亚历克斯·劳伦斯;《企业家留住客户的五个秘诀》,福布斯,2012, www.forbes.com/sites/alexlawrence/2012/11/01/five-customer-retention-tips-for-entrepreneurs/

  21. 大卫·贝克汉姆;《电信数据集中的流失》,Kaggle,2017, www.kaggle.com/becksddf/churn-in-telecoms-dataset

  22. 杰弗里·什曼;“如何使用 Apache Spark MLlib 预测电信客户流失”,DZone,2016, https://dzone.com/articles/how-to-predict-telco-churn-with-apache-spark-mllib

  23. 杰克·霍尔;《随机森林的可变重要性是如何计算的》,DisplayR,2018, www.displayr.com/how-is-variable-importance-calculated-for-a-random-forest/

  24. 迪德里克·尼尔森;《用 XGBoost 助推树》,挪威科技大学,2016, https://brage.bibsys.no/xmlui/bitstream/handle/11250/2433761/16128_FULLTEXT.pdf

  25. 莉娜·肖;《XGBoost:简明技术概述》,KDNuggets,2017, www.kdnuggets.com/2017/10/xgboost-concise-technical-overview.html

  26. Philip Hyunsu Cho“快速直方图优化生长器,8 到 10 倍加速”,DMLC,2017, https://github.com/dmlc/xgboost/issues/1950

  27. XGBoost《用 XGBoost4J-Spark 构建一个 ML 应用》,xgboost.readthedocs.io,2019, https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html#pipeline-with-hyper-parameter-tunning

  28. 杰森·布朗利;《Python 中如何用 XGBoost 调优决策树的个数和大小》,machinelearningmastery.com,2016, https://machinelearningmastery.com/tune-number-size-decision-trees-xgboost-python/

  29. Laurae"基准测试 light GBM:light GBM 与 xgboost 相比有多快?",medium.com,2017 年

[`https://medium.com/implodinggradients/benchmarking-lightgbm-how-fast-is-lightgbm-vs-xgboost-15d224568031`](https://medium.com/implodinggradients/benchmarking-lightgbm-how-fast-is-lightgbm-vs-xgboost-15d224568031)
  1. LightGBM《在速度和内存使用上的优化》,lightgbm.readthedocs.io,2019, https://lightgbm.readthedocs.io/en/latest/Features.html

  2. 大卫·马克思;“决策树:逐叶(最佳优先)和逐级树遍历”,stackexchange.com,2018, https://datascience.stackexchange.com/questions/26699/decision-trees-leaf-wise-best-first-and-level-wise-tree-traverse

  3. LightGBM《LightGBM 特性》,lightgbm.readthedocs.io,2019, https://lightgbm.readthedocs.io/en/latest/Features.html

  4. 西拉德·帕夫卡;“各种开源 GBM 实现的性能”,github.com,2019, https://github.com/szilard/GBM-perf

  5. 胡利奥·安东尼奥·索托;"“增益”返回的特征重要性是什么?",github.com,2018, https://github.com/Microsoft/LightGBM/issues/1842

  6. sci kit-learn;《线性回归例题》,scikit-learn.org,2019, https://scikit-learn.org/stable/auto_examples/linear_model/plot_ols.html

  7. 李宏建等。艾尔。;“用随机森林代替多元线性回归提高评分函数的结合亲和力预测:Cyscore 案例研究”,nih.gov,2014, www.ncbi.nlm.nih.gov/pmc/articles/PMC4153907/

  8. 阿里扬·潘查尔;《USA Housing.csv》,Kaggle,2018, www.kaggle.com/aariyan101/usa-housingcsv

  9. 数据科学中心;《RMSE:均方根误差》,Datasciencecentral.com,2016, www.statisticshowto.datasciencecentral.com/rmse/