给你的应用增加一点智能——浅谈决策树及其在移动端的应用

4,297 阅读28分钟

1 应用智能化的意义

到2021年,人工智能已经成为了各大移动互联网产品核心的竞争点。内容平台通过推荐算法,向用户分发个性化的内容,提升使用时长;社交平台进行关系图谱分析,精确推荐熟人,提升用户粘性;直播平台通过图像处理,提供美颜和特效,提升直播效果。在手机屏幕的背后,有成千上万台机器在进行无休止的运算。无法计量的数据通过电磁波和网线穿梭在应用和计算中心之间,让一个个AI模型变得愈发丰满。

人工智能的潜力不止如此。除了这些“大场景”,移动端应用上还有很多“小场景”可以借助人工智能获得更好的体验,和其他产品产生区分。算法的执行也无需完全依赖云端的计算资源,端侧的计算框架也能完成很多有趣的功能。比如:

  1. 入口预测:在进入多Tab页面时,预测用户可能需要的功能,直接展示该分页。
  2. 弹窗管理:针对用户的使用习惯,控制弹窗出现的频率,展示更贴合用户需求的内容。
  3. 智能预加载:对可能进入的页面进行预加载,缩短页面加载时长,提升用户体验。

本文以决策树学习为例,通过介绍决策树的原理、训练方法和移动端部署方法,对移动端机器学习进行一点浅薄的讨论。抛砖引玉,一得之见。望各位老师多多提点。

关键词:端智能、机器学习、决策树、scikit-learn、Python

2 智能化实现方法

2.1 端智能实践一般流程

相较于普通的业务需求,端智能的业务在链路上更长一些。整条链路要从数据开始。一方面云端需要拉取离线数据,用于模型训练。另一方面客户端也需要在端上获取对等数据,作为端侧推理的输入。另外,为了让推理能力的迭代不受App迭代周期和覆盖速度的影响,端侧需要具有动态更新模型的能力。理想状态下,这应该是一个基础能力。在端侧处理特征和推理后,推理结果被用于指导业务形态。模型推理的相关数据(推理结果、真实结果、推理耗时等)应当和业务本身结果一样,通过埋点上报到云端用于评估。

在机器学习与其应用的流程中,每个环节都包含大量科学原理和工程实践,限于篇幅和笔者的能力,本文仅着重介绍决策树原理、训练和移动端部署的环节。

2.2 模型选择

人工智能有多种实现方式,包括机器学习(线性分割、决策树、SVM等)、深度学习(依赖机器学习,CNN、RNN等)和强化学习等。不同方案适用的场景不同,训练和应用的难度也相差甚远。这里提供scikit-learn提供的cheat-sheet,供读者参考。

在实际应用中,一般会针对一个场景(如分类场景)筛选出多个模型(如SVC/SGD),初步训练后选择表现最好的模型,最终再针对这种模型进行进一步的参数调整。

3 决策树

3.1 决策树的定义

本文聚焦在决策树这种简单强大的工具上。维基百科上对决策树的描述如下:

决策树学习是统计学数据挖掘机器学习中使用的一种预测建模方法。它使用决策树作为预测模型,从样本的观测数据(对应决策树的分支)推断出该样本的预测结果(对应决策树的叶结点)。按预测结果的差异,决策树学习可细分两类。(1)分类树,其预测结果仅限于一组离散数值。树的每个分支对应一组由逻辑与连接的分类特征,而该分支上的叶结点对应由上述特征可以预测出的分类标签。(2)回归树,其预测结果为连续值(例如实数)。

简而言之,决策树学习就是给定一组样本数据,包括特征取值及其对应的标签,训练出一个树形分类器来描述特征与其标签映射的逻辑。以久经考验的Iris分类数据集为例:特征(或称属性)包括花瓣长度、花瓣宽度、花萼长度和花萼宽度,标签为鸢尾花的种类(山鸢尾、变色鸢尾和维吉尼亚鸢尾)。下面是数据集的一个采样(左一列是样本序号,非数据)。

学习这个数据集获得的分类决策树可能有以下的形状(仅示意)。在决策树中,内部结点(非叶结点)代表一个用于分类的特征(如“花瓣长度”),结点下属的边代表判断条件(如“大于2.497”)。而叶结点代表一个预测分类。在推理一条未知分类的鸢尾花数据时,将其输入到决策树中:从根结点出发,根据每个结点的判断条件选择分支,最终抵达一个叶结点,即预测的分类。

3.2 构造决策树

上面一节简述了决策树的定义和功能,本节讨论决策树的构造过程。下图是周志华老师《机器学习》一书中对决策树构造过程的伪代码描述。

决策树的构造是一个递归过程:输入一组数据集,如果当前数据集满足了停止条件1,则生成一个叶结点并停止构造过程。否则,选择一个最优划分属性2来分割这个数据集,创建一个中间结点和对应的边,并且将分割后的产生的子集作为输入重复执行构造过程。

可以看到,构造方法有两个关键点:停止条件和最优划分属性。首先讨论停止条件。在不考虑过拟合(关于过拟合的定义,后文会有介绍)和计算成本的条件下,停止条件可以总结为:

  1. 数据集为空,或
  2. 数据集内样本全部属于同一分类,或
  3. 数据集里所有样本在所有特征上的取值都相同

在条件1和3下,数据集已经无法再进行分割;而在条件2下,继续分割已经没有意义了。

接下来讨论最优划分属性。分割数据集的目的是使得子集里的样本尽量属于同一类别,即抵达停止条件。我们先聚焦在“类别”这一列上。数据集所包含的类别数量可以视为它包含的信息量,类别成分越复杂则信息量越大。在信息学上,“信息增益”被用来衡量信息量之差:信息增益等于数据集划分前的信息量与划分后的信息量之差。信息增益越大,则子集包含的信息越少,即子集包含的类别越少。选择最优划分属性,实际上是在寻找信息增益最大的分割条件。具体方法如下:对数据集包含的特征进行遍历,使用该特征将数据集划分成多个子集,计算并记录当前划分获得的信息增益。最终选择信息增益最大的特征作为分割条件。

这里对数据划分做一些额外的阐述。当特征取值离散时,划分方法比较直观。例如特征为颜色={红,黄,蓝},可以将数据集划分为三个子集,分别包含颜色为红、黄、蓝的样本。如果特征取值为连续值,如iris数据集中的花瓣宽度,则需要根据该特征值在数据集上的分布,将其划分为几个区间,再按照离散值的方法进行划分。不同的树算法有不同的连续值分割方法,此处不再深入讨论。

回顾一下,划分子集时使用的是特征,计算信息增益时观察的是类别。记住这一点,避免在下一节中迷路。

3.3 信息增益

本节介绍信息增益的计算方法,包含一些数学公式。读者可以根据阅读目标决定是否跳过本节。

3.3.1 信息熵

1948年,香农老祖将热力学中“熵(Entropy)”的概念引入信息论,用来度量信息的含量。信息熵描述了在观察“类别”时,数据集D的不确定性。数据集包含的样本越杂乱,类别越多,则信息熵越大;则数据集包含的样本越统一,类别越少,信息熵越小。把数据集按照某种方式划分为几个子集后,原数据集的信息熵和划分后子集总信息熵的差值,即为信息增益。

假设一个数据集D包含y个类别,每类样本所占比例为pk,则D的信息熵定义为:

Ent(D)=k=1ypklog2pkEnt(D) = -\sum_{k=1}^{y} p_k \log_2p_k

例如y = 1, 则Ent(D) = 0。 假设按照离散属性f将D进行划分,f包括v个取值,则D会被划分为v个子集,记为DvD^v,其信息熵为Ent(Dv)Ent(D^v)。考虑到每个子集包含的样本数不同,再给予其一个数量上的权重。划分的信息增益为:

Gain(D,f)=Ent(D)v=1VDvDEnt(Dv)Gain(D, f) = Ent(D) - \sum_{v=1}^{V}\frac{|D^v|}{|D|}Ent(D^v)

使用信息熵划分属性的决策树算法为ID3。但是信息熵会对取值更多的属性有偏好。为了解决这个问题,C4.5算法不直接使用信息熵,而是参考增益率来划分属性:

Gain_ratio(D,f)=Gain(D,f)IV(f)Gain\_ratio(D, f) = \frac{Gain(D, f)}{IV(f)}, 其中 IV(f)=1vDvDlog2DvDIV(f) = - \sum_{1}^{v}\frac{|D^v|}{ |D|}log_2\frac{|D^v|}{|D|}

特征f取值越多,则IV(f)越大。所以使用IV(f)IV(f)除信息增益能够一定程度上抵消信息增益对特征数量的偏好。

3.3.2 基尼指数

和信息熵类似,基尼指数是另一种衡量信息复杂度的方法。基尼指数可以理解为:从数据集D中随机抽取两个样本,其类别不同的概率。基尼指数越小,则数据集内的样本越统一。基尼指数的算法如下:

Gini(D)=k=1ykkpkpk=1k=1ypk2Gini(D) = \sum_{k=1}^{y}\sum_{k'\neq k}^{}p_kp_{k'}=1-\sum_{k=1}^{y}p_k^2

使用特征f划分D后,子集的基尼指数之和可以用下面的公式计算。同样,由于样本数不同,子集的基尼指数需要乘上权重。

Gini_index(D,f)=v=1VDvDGini(Dv)Gini\_index(D, f) = \sum_{v=1}^{V}\frac{|D^v|}{|D|}Gini(D^v)

在划分子集时,选择划分后基尼指数最小的特征。可以看到,使用基尼指数时并没有计算信息增益,笔者(不负责任的)认为这是由于基尼指数可以直接用于比较,而直接比较信息熵会存在上面提到的特征数量偏好的问题——需要计算增益率。

3.4 性能评估

现在假定我们已经在数据集上构造出了一颗决策树,我们怎么来评价这颗决策树的表现呢?这一节讨论对模型进行性能评估的方法。

3.4.1 训练集、验证集和测试集

机器学习的任务就是从数据中学习特征的规律。如果我们将获得的所有离线数据都直接投入到学习算法中,那么我们对模型进行的评估都只能反应它在这个数据集上的表现。我们无法预知这个模型投入到生产环境中使用时的性能。为了解决这个问题,需要将原始数据集进行分割,一部分用于训练,一部分用于测试。

通常数据集需要被分成三部分:训练集、验证集和测试集。训练集用于训练模型。验证集则用于验证模型是否过拟合及调整参数。当你获取了最终想要发布的模型时,可以用测试集提前预测模型推广的效果。一个常用的划分比例是抽取原始数据的70%作为训练集,20%作为验证集和10%作为测试集。数据集的划分应该尽可能保证等比抽样,例如原始数据中正例反例之比为8:2,那么在这三个子集中正例反例之比也应该接近8:2。

3.4.2 性能指标

混淆矩阵

首先定义混淆矩阵(基于二分类,标签分别为正和负):

真实为正真实为负
预测为正True PositiveFalse Positive
预测为负False NegativeTrue Negative
混淆矩阵

错误率和准确率(精度)

错误率描述了预测结果整体的错误情况,即预测错误的数量占样本的比例。

Error=(FN+FP)÷(TP+FN+FP+TN)Error = (FN + FP ) \div (TP + FN + FP + TN)

而准确率是错误率的反面描述:

Accuracy=(TP+TN)÷(TP+FN+FP+TN)Accuracy = (TP + TN) \div (TP + FN + FP + TN)

容易看出: Error+Accuracy=1Error + Accuracy = 1

查准率/准确率

准确率描述了模型对正例预测的准确度:预测为正的数据里有多少是真实为正的。

Precision=TP÷(TP+FP)Precision = TP \div ( TP + FP)

查全率/召回率

准确率描述了模型对正例的预测能力:真实为正的数据有多少被模型成功找出。

Recall=TP÷(TP+FN)Recall = TP \div (TP + FN)

3.4.3 评估方法

如上所述,错误率和准确率是对模型整体效果的衡量,对正例和反例一视同仁。当用户关心模型对所有分类的预测表现时,可以直接观察准确率。但是在实际的业务中,我们往往更关心模型对某一类标签的预测能力。比如在预加载业务中,业务更关心的是模型预测“进入页面”这一分类的表现,因为这个预测结果会导致端上执行预加载的动作。在这种情况下可以将“进入页面”视为positive分类,观察查准率和查全率这两个参数。

需要理解的是,查准率和查全率这两个参数在一定程度上是冲突的。具有较高查准率的模型往往查全率表现比较差。可以想象以下极限情况:预测所有的样本为positive。这种情况覆盖了样本中的所有正例,查全率为100%,但可想而知查准率会非常低。下图反映了Iris数据集上某种模型查准率和查全率的关系。

在实际应用时,可以根据业务场景对这两个参数有所倾向。例如垃圾邮件过滤场景,为了不阻拦用户的正常邮件,需要格外重视模型的查准率。可以通过计算F-score,在给予一定权重的情况下综合评估:

其中β是给予recal的权重,代表recall的重要性是precision的β倍。

除了以上的参数,机器学习中还常常使用ROC-AUC作为参数衡量分类模型的表现,这里不再进行介绍,读者可以自行检索。

3.4.4 泛化能力/过拟合

过拟合的定义

机器学习训练是模型学习训练集特征的过程。由于训练集只是真实世界的一个子集,必然包含其独有的特征。过拟合(overfitting)描述了模型对训练集特有特征过度学习的现象。举一个极端的例子:假设我们的训练集包括5条数据和1个特征,每条数据具有特征上的一个取值。决策树只要生成5条路径,每条路径描述一条数据的特征取值,即可实现训练集上100%的预测准确度。显然这样的模型应用在未知数据上时性能会很差。下图展示了过拟合的情况:

混淆矩阵

识别过拟合

在3.4.1节中我们提到过,原始数据集通常会被划分为训练集、验证集和测试集。验证集就是用来辅助判断模型是否过拟合的。随着训练进程的深入或决策树深度的增长,决策树在训练集上的性能表现会越来越好。但是如果我们将此模型用于预测验证集,其性能增长会在某一点开始停滞,甚至开始下降。下图描述了在Iris数据集上,一棵决策树随着深度增长在测试集和验证集上的得分曲线。可以看到在4层之后,测试集上的得分仍然在上升,但测试集的得分开始下降。此时决策树已经开始过拟合。

通过剪枝避免过拟合

可以采用剪枝(pruning)的手段来避免过拟合:通过减少决策树中的路径,来避免决策树包含过多的判断条件,即减少决策树学习到的特征。剪枝可以分为预剪枝和后剪枝。预剪枝是说,在决策树生长过程中,如果达到限制条件则停止生长过程。后剪枝则是首先让决策树自由生长到最大深度,再通过一些算法去除决策树中的部分结点。

预剪枝常用的限制条件包括:树的最大深度、划分结点所需的最少样本数和叶结点的最小样本数。而后剪枝遍历每个叶结点,计算将叶结点剪枝后决策树的泛化能力是否上升或不变。上升的情况不谈,即使观测到的泛化能力不变,剪枝也总是好的:如无必要,勿增实体。

4 决策树训练

上一节讨论了决策树的基本原理,性能指标和过拟合防治。在实际生产中,一般不会从零手写一个决策树的训练算法,就像我们不会在Android上手写一个图片加载库或者网络库。业界有很多优秀且久经检验的训练框架供我们使用。本章介绍如何使用scikit-learn框架训练一棵决策树,对企鹅进行分类。

4.1 框架介绍与搭建

4.1.1 Scikit-learn

Scikit-learn是一个知名的Python机器学习库,基于NumPy、SciPy和matplotlib等优秀的Python数据处理库建立。Scikit-learn支持多种类型的机器学习任务:分类、回归、聚类等,并且提供数据预处理等能力。Scikit-learn是开源的,所以不用担心商业使用的问题。更妙的是,Scikit-learn有一个翻译水平很高的中文网站,包含全部Api参考和用户指南。

4.1.2 环境搭建

完全手动搭建环境,需要先后安装Python(基础环境)、Pip(Python包管理)、Scikit-learn和其他可能需要的包。这里推荐通过Anaconda进行图形化的一站式环境配置。鼠标点几下就能安装所有依赖,并且可以轻松的建立venv虚拟环境避免包依赖之间冲突。

除了包管理之外,Anaconda还能直接带来Jupyter Notebook的支持。Jupyter Notebook是一个现代化的交互式的编程环境,相较于传统的IDE如PyCharm,Jupyter能更直观的展示上下文,大幅提高工作效率。

如果读者只是快速体验机器学习,或不涉及到数据安全问题(使用公开数据),强烈建议直接点开Google Colabtory,  无需自己配置任何基础环境,即开即用,还能蹭到免费GPU资源。本文中的Demo就是在Colab上开发的。点这里查看Demo。

4.2 训练数据获取和处理

为了追求一点新鲜感,本文中的例子没有使用经典数据集Iris,而是使用了一个类似规模的帕默群岛企鹅数据集。数据集描述了南极洲帕默群岛上三种企鹅的喙长、喙宽、蹼长、体重和性别。本文中的Demo会学习企鹅的这些特征,预测企鹅的种类。

混淆矩阵 混淆矩阵

4.2.1 数据拉取

需要注意的是,企鹅数据集不是Scikit-learn内置的,需要单独进行安装。命令如下:

# 在笔记本中执行pip指令需要增加一个!
!pip install palmerpenguins

企鹅包提供的加载方法返回的是pandas.DataFrame类型的数据,通过head函数可以快速查看数据内容。

import pandas as pd
from palmerpenguins import load_penguins
penguins = load_penguins()
penguins.head()

4.2.2 数据处理

在机器学习中,原始数据集通常需要处理后才能交给模型进行学习。这是因为原始数据可能包括一些异常值(如typo)、空值或无用属性。

可以看到,企鹅数据中包括值为空的样本(NaN)。空值的处理方法对于不同的模型有所不同,这里简单的去掉包含空值的行。

另外,部分属性在映射后更方便使用。如sex属性,原数据集中是“male/female”这样的字符串,建议将字符串映射为数字(如0/1)。这样做有两个好处,第一,如果数据中存在typo的情况,如“male”不小心录入成了“mal”,映射后可以被迅速的发现。第二,部分模型(包括决策树)无法接受字符串这样的输入。

实际上,除了对特征的格式转化外,往往还需要对特征进行相关性分析,去除和标签无关的属性。这里暂时略过相关性分析,单从数据上就可以看出,year一列描述的是这条数据收集的时间,与企鹅的品种无关。所以需要去除。

Pandas提供了强大的数据处理能力,上述的处理动作仅需要短短几行:

# 去掉空数据
penguins = penguins.dropna()
# 种类、性别、岛名到数字的映射关系
spice_dict = {'Adelie': 0, 'Chinstrap': 1, 'Gentoo': 2}
gender_dict = {'male': 0, 'female': 1}
island_dict = {'Torgersen':0, 'Biscoe':1, 'Dream':2}
# 映射数据
penguins = penguins.replace(spice_dict).replace(gender_dict).replace(island_dict)
# 丢掉无效输入, axis=1表示选取的是列而非行
penguins = penguins.drop('year', axis=1)
penguins.head()

处理后的数据长这样:

4.2.3 数据集划分

前面提到过,原始数据集需要分为训练集、验证集和测试集。Scikit-learn提供了快速拆分数据集的能力。这里通过留出法展示拆分的原理,见下方代码。需要额外说明的是,由于拆分可能带来潜在的样本偏差,实践中往往会采用K-fold的方式,即将测试集分成容量相同的K组,每次取其中一组作为验证集,剩余组作为训练集合。验证时遍历K次,计算平均分数。

# 数据集划分
from sklearn.model_selection import train_test_split
# 首先留出测试集, test_size代表拆分比例,random_state则是用于保证每次分割结果都一致的种子
train_and_validate, test = train_test_split(penguins, test_size=0.1, random_state=1)
# 简单示意拆出验证集,实践中常采取K-fold的方式
train, validate = train_test_split(train_and_validate, test_size=0.3, random_state=1)

下面的方法可以将特征和标签从Dataframe中取出来:

# 拆分特征和标签的方法
def getFeatureAndLabelFromDF(dataset):
    feature_cols = ['bill_length_mm', 'bill_depth_mm', 'flipper_length_mm', 'body_mass_g', 'sex', 'island']
    X = dataset.loc[:, feature_cols]
    y = dataset.species
    return X, y

4.3 决策树训练

4.3.1 决策树拟合

现在万事俱备,我们可以开始训练决策树了。决策树拟合的代码非常简单,只需要以下几句:

from sklearn import tree
from sklearn.metrics import f1_score
# 创建决策树
clf = tree.DecisionTreeClassifier()
# 分离特征和标签,X为特征,y为标签
tX, ty = getFeatureAndLabelFromDF(train)
# 给定特征和标签,拟合决策树
clf = clf.fit(tX, ty)
# 看下表现,计算方法为f-measure,权重为1。
py = clf.predict(tX)
f1_score(ty, py, average='weighted')

=== 输出 ===
1.0

Wow,模型在训练集上拿到了1.0的满分。这么高的分数立刻会让我们警惕起来:模型是否已经过拟合了?看下同个模型在验证集上的表现:

# 验证集验证决策树表现
vX, vy = getFeatureAndLabelFromDF(test)
py = clf.predict(vX)
f1_score(vy, py, average='weighted')

=== 输出 ===
0.8808654496281271

果然,在验证集上,决策树的效果就没有那么好了。在多数情况下,决策树的默认参数无法为我们提供性能最好的模型。下一节我们讨论如何对决策树进行调参。

4.3.2 决策树的调参

前面提到过,对决策树进行预剪枝一般会调整三个参数:树的最大深度、划分结点所需的最少样本数和叶结点的最小样本数。在Scikit-learn中,这三个参数对应的属性分别为:max_depth, min_samples_split, min_samples_leaf。对这三个参数的调整一般按照以下步骤进行调整:

  1. 搜索max_depth的最佳区间。
  2. 在max_depth的最佳区间中任取一个值作为深度,在此基础上搜索min_samples_split的最佳区间。
  3. 和第二步类似,固定max_depth和min_samples_split,搜索min_samples_leaf的最佳区间。
  4. 使用网格查找,遍历max_depth,min_samples_split和min_samples_leaf的取值组合,获取最佳参数。

由于1,2,3这三步的逻辑几乎一模一样,这里就以深度为例,介绍如何进行搜索:

# 训练决策树, 取平衡权重
def trainModel(X, y, depth, split=2, leaf=1):
    clf = tree.DecisionTreeClassifier(max_depth = depth, min_samples_split = split, min_samples_leaf = leaf, class_weight = "balanced")
    # 使用K-fold方式交叉验证模型
    scores = cross_val_score(clf, X, y, cv=10, scoring='f1_weighted')
    ave_score = sum(scores) / len(scores)
    return clf, ave_score
  
 # 搜索Depth
def searchDepth(X, y):
  	# 遍历Depth区间,记录得分。一般可以首先设定较大的scope,如[10, 100]。
    # 本例数据量较小,无需设置太大的深度(过深容易引起过拟合)
    depth_options = list(range(3,20))
    scores = []
    for depth in depth_options:
        clf, score = trainModel(X, y, depth)
        scores.append(score)
    # 绘制折线图
    drawCurve(depth_options, scores, 'depth', 'train&validation f1-scores', 'Depth performance')
    
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.model_selection import cross_val_score
X, y = getFeatureAndLabelFromDF(train_and_validate)
searchDepth(X, y)

下面是结果的曲线图:

可以看到,随着depth变深,模型表现有所波动,这代表着过深的层数是不可靠的。所以我们取波动前的[5,6]为深度搜索区间。在搜索其他参数时,可设定depth为5。下面是另外两个参数的搜索结果,区间分别为[4,5]和[2,3]。

获取区间后,可以使用网格搜索对最终参数进行获取。网格搜索的内在逻辑为:遍历所提供参数取值区间中的所有组合,依次设定的评估方式计算分数,最终提供得分最高的参数组合。

from sklearn.model_selection import GridSearchCV
params = [{'max_depth': [5, 6],
         'min_samples_split': [4, 5],
         'min_samples_leaf': [2, 3]}]
gridSearchClf = GridSearchCV(tree.DecisionTreeClassifier(),param_grid=params, scoring='f1_weighted', cv=10)
gridSearchClf.fit(X, y)
gridSearchClf.best_params_

=== 输出 ===
{'max_depth': 5, 'min_samples_leaf': 2, 'min_samples_split': 5}

至此,对决策树的调参告一段落,我们可以用测试集评估一下模型的泛化表现:

testX, testy = getFeatureAndLabelFromDF(test)
pTesty = gridSearchClf.best_estimator_.predict(testX)
f1_score(testy, pTesty, average='weighted')

=== 输出 ===
0.9081741150297716

5 移动平台

在上一章的最后,我们得到了一棵在离线环境训练好的决策树,现在我们要讨论如何把这棵决策树带到Android端用于推理。限于篇幅,本文在移动端部署的部分没有给出step by step的教程。如果读者对这一部分有需求,可以在评论区留言,后续可以单开一篇文章给出部署的详细流程。

5.1 Android执行机器学习模型

首先需要声明的是,对于深度学习而言,市面上的主流框架,如TensorFlow和Pytorch,已经提供了较为成熟的端上推理方案。而很多国产端智能框架,如字节的Pitaya还包括模型分发的能力。很遗憾,本文讨论的机器学习框架Scikit-learn并没有这么强大的能力。我们需要将模型导出为Android平台可以执行的东西。

有两种思路可供参考。第一种方案是将模型导出为程序代码,编译后分发给安卓端。例如可以使用sklearn-porter将模型导出为Java/C等语言的代码,编译成jar包或.so文件后进行分发。由于导出的是代码,可操作性会比较高,因为用户可以轻易的修改这些代码。如果导出为.so文件的话,还可以潜在的提升代码运行效率。但这种方案带来的问题是,用户需要单独开发特征处理的部分。

第二种方案是将模型转化为中间格式的描述文件,如ONNXPMML格式,端上拉取到描述文件后重新解析为模型,再进行推理。这种方案有几个优点。第一,ONNX和PMML都是开源的文件格式,可供多种终端使用。第二,这种方式可以把数据处理相关的流程一并导出,使用方无需单独开发特征处理的代码。但是由于导出的描述文件包括一些标记信息,潜在的导致文件体积可能较大。而且端上在加载模型时有一些额外的消耗。

5.2 C/CPP方案

5.2.1 导出代码

将Scikit-learn的决策树导出为C++代码非常简单。首先你需要在Colab中安装sklearn-porter库。需要注意的是,由于sklearn新版本的包结构有调整,sklearn-porter无法支持最新版本sklearn。可以安装0.19.2版本的sklearn,或者修改sklearn-porter的源码。

# 安装sklearn-porter(和sklearn,如果需要的话)
!pip install --no-cache-dir https://github.com/nok/sklearn-porter/zipball/master
!pip install scikit-learn==0.19.2

导出代码非常简单。

from sklearn_porter import Porter
# 导出模型
def transformModel(clf):
    porter = Porter(clf, language='c')
    output = porter.export(embed_data=True)
    file_name = "c_model.cpp"
    with open(file_name,'w') as model_file:
        model_file.write(output)
        print("===TRANSFORMATION===")
        print("model saved in: ", file_name)
# 把上面获取的最佳决策树导出
transformModel(gridSearchClf.best_estimator_)

如果你用colab的话,在页面左侧的文件 -> Content下就可以看到c_model.cpp文件了。

5.2.2 打包、下发和端上执行

由于NDK的知识过于繁杂,限于篇幅和笔者的水平,这里仅作简单的描述。

  1. 当你从上一步获取了模型对应的cpp文件后,可以将其加入到安卓项目的NDK代码中,编译为对应平台的.so文件(这里的平台指的是armv7/v8等)。
  2. 在实际应用中,模型常常需要进行迭代。所以需要一个.so文件管理系统,确保.so文件可以上传到后端进行管理,App可以从后端拉取到.so文件。这个系统可能还需要处理一些.so路径管理和版本管理的工作。
  3. 使用类似Tinker热修的方案,App启动时将包含最新版本.so文件的路径插入到BaseDexClassLoader中搜索nativelib的列表中。
  4. 安卓端上通过.so库的路径加载native库,调用相关代码。
  5. 安卓端接收.so库的回调,执行后续业务逻辑。

5.3 PMML方案

openscoring.io提供了一系列PMML执行工具,囊括了Scikit-learn导出PMML模型、Java平台加载PMML模型和Java平台执行模型的全链路能力。然而由于Andriod平台天生缺乏对JXAB(用于解析PMML的运行时)的支持,笔者在尝试最短路径运行PMML模型时失败了。更遗憾的是,笔者在尝试绕路的过程中遇到了大量的问题,例如Jpmml-android无法支持最新版本的PMML文件、jpmml-evaluator版本无法匹配jpmml-android等。如果读者有强需求,或者有兴趣进一步探索的话,可以考虑以下的流程:

  1. 在Scikit-learn的Python代码中引入SkLearn2PMML库,导出PMML格式的模型
  2. 通过Jpmml-android库,把PMML文件转化为Android平台支持的Java Serialization(SER)格式。
  3. Android端引入jpmml-evaluator库,加载和执行SER文件。

6 尾声:It's all about the data

虽然本文绝大部分篇幅都在讨论决策树的原理、训练和部署。但是从更宏观的角度来看,一个机器学习模型的表现最终是由训练数据决定的。数据对标签的描述能力决定了机器学习的上限,任何对模型的优化都仅仅是在逼近这个上限。

我们在考虑为应用增加机器学习能力驱动的业务时,首先要考虑的就是有什么数据可供训练,评估数据能够在多大程度上描述预期的推理结果。此外,数据收集部分也常常是工程中最耗时和最容易出错的地方,需要同时考虑离线端和移动端拉取数据的方案和时间成本。

预祝大家编码愉快。

Reference

  1. Scikit-learn中文官网: scikit-learn.org.cn/
  2. Sklearn提供的模型选择地图:scikit-learn.org/stable/tuto…
  3. Jupyter Notebook主页: jupyter.org/
  4. Android NDK: developer.android.com/ndk
  5. Python模型导出PMML SkLearn2pmml: github.com/jpmml/sklea…
  6. PMML模型转化为SER Jpmml-android:github.com/jpmml/jpmml…
  7. Java平台执行PMML模型 Jpmml-evaluator:github.com/jpmml/jpmml…
  8. Colab Demo: colab.research.google.com/drive/1jLFB…

hi, 我是快手电商的Ryver

快手电商无线技术团队正在招贤纳士🎉🎉🎉! 我们是公司的核心业务线, 这里云集了各路高手, 也充满了机会与挑战. 伴随着业务的高速发展, 团队也在快速扩张. 欢迎各位高手加入我们, 一起创造世界级的电商产品~

热招岗位: Android/iOS 高级开发, Android/iOS 专家, Java 架构师, 产品经理(电商背景), 测试开发... 大量 HC 等你来呦~

内部推荐请发简历至 >>>我们的邮箱: hr.ec@kuaishou.com <<<, 备注我的花名成功率更高哦~ 😘