WWDC 2018:初探 Create ML

3,385

本文是 WWDC 2018 Session 703的读后感,其视频及配套 PDF 文稿链接如下:Introducing Create ML 本文内容从机器学习在 iOS 平台的发展说起,借此引出 Create ML。之后详细介绍了 Create ML 的使用方法,并给出相应示范。

Create ML 的提出

2017年可谓是机器学习在 iOS 平台上的元年:苹果在 WWDC 上推出了全新的 Core ML 框架,旨在为开发者提供一套完整的机器学习部署方案,从而让 App 更加智能。2018年的 WWDC 上,我们发现使用 Core ML 的 App 多达182个,其中最多的是拍照和照片编辑应用,包括大名鼎鼎的 Snapchat。我想这其中重要的原因是图像识别的机器学习模型比较成熟,苹果官方主页上也有现成的模型可供下载。

我们知道 Core ML 的工作模式是获取模型、导入模型、生成接口、进行调用,其中导入模型、生成接口、进行调用这3步在去年就有了比较简单直接的解决方案。而机器学习模型的获取方式却乏善可陈,主要有以下两种:

  • 从苹果的官方主页进行下载。去年有4个现成的模型,今年有6个,可以说没什么进步。

  • 用第三方的框架生成模型,再用 Core ML Tools 转成 Core ML 模型。2017年苹果宣布支持的框架有5个,包括 Caffee、Keras。今年宣布支持的第三方框架增加到了11个,包括了最知名的 TensorFlow、IBM Watson、MXNet,数量和质量都有大幅提升。

官网的现成模型数量有限;而通过第三方机器学习框架生成模型再进行转换这种方式,一来比较麻烦,二来模型的数量和质量都要受制于人,三来为了生成模型,第三方框架得到了 App 开发者的训练数据,相当于是苹果在机器学习上间接成全了谷歌、亚马逊、脸书、IBM 等竞争对手。

基于效率、安全、竞争方面的考虑,苹果推出了 Create ML——生成机器学习模型的原生官方手段。

Create ML 是什么

Create ML 是苹果开发的、生成机器学习模型的框架。它有三个特点:使用 Swift 进行操作;用 Playground 训练和生成模型;在 Mac OS 上完成所有工作,可以说是集原生态和便携性于一体。使用 Create ML 的流程如下:

  1. 确认场景。在使用 Create ML 前,我们必须确定当前问题可以用机器学习来解决:即问题对应的数据存在关联和规律。例如某电商网站通过分析用户特征和历史数据来判断其购物偏好,这类问题就可以机器学习解决;而诸如通过收集彩票中奖历史记录来预判下一期中奖号码则不可以通过机器学习解决,因为彩票开奖是随机的,其历史记录并无规律。

  2. 收集数据。这一步骤中,开发者不仅需要收集大量的训练数据,还要收集部分测试数据。目前 Create ML 支持三种类型的数据:图片、文字、表格。其中图片对应的模型 API 是 MLImageClassifier,而后两者则对应MLTextClassifier (对一大段文字进行分析)和 MLWordTagger(对单个单词进行分析)。

  3. 训练模型。目前 Create ML 提供两种训练模型方式:拖拽和代码形式。拖拽是这样完成的:在 Playground 中可以用 LiveView 直接打开训练的 UI 界面,然后将准备好的训练数据放入对应的训练框中,Playground 就自动开始训练模型了。代码形式则是调用相应的 API 进行操作,我们会在稍后详谈。

  1. 评估模型。具体步骤和训练模型类似。差别在于这次用的是测试数据,我们会根据测试数据返回的准确度去判断模型的可靠性。上图中我们可以看到,在测试数据评估之后,水果分类模型的准确度为92%。

  2. 保存模型。在 UI 界面中可以直接拖拽模型将其保存在桌面或其它位置。若是 API 操作,可以手动指定存储路径,再将其保存。

注意 Create ML 生成的模型是基于现有模型和专用数据而生成的定制化模型。例如上图中我们看到的水果图片分类模型,就是基于苹果的图片分类模型,配合水果图片而生成的专用模型。它只针对具体的使用场景,所以在尺寸和时间上都优化到了极致。例如上文中的水果图片分类模型就只有83KB,训练时间也在1分钟以内。

Create ML 的使用示范

在 WWDC 上,苹果工程师展示了如何使用 Create ML 生成图片分类模型、文本分类模型、表格分类模型。除了直接拖拽的方式,我们来看看用代码如何操作生成图片分类模型:

import Foundation
import CreateML

// 定义数据源
let trainDirectory = URL(fileURLWithPath: "/Users/createml/Desktop/Fruits")
let testDirectory = URL(fileURLWithPath: "/Users/createml/Desktop/TestFruits")

// 训练模型
let model = try MLImageClassifier(trainingData: .labeledDirectories(at: trainDirectory))

// 评估模型
let evaluation = model.evaluation(on: .labeledDirectories(at: testDirectory))

// 保存模型
try model.write(to: URL(fileURLWithPath: "/Users/createml/Desktop/FruitClassifier.mlmodel"))

在第一步是导入 Create ML 框架之后,我们需要定义数据源的信息,这里训练数据和测试数据皆是一系列水果的图片,被存放在桌面对应的文件夹中。文件夹内部又包含多个子目录,而 Create ML 能帮我们从中提取出有用的图片信息。

Fruits文件夹中的内容

接下来我们就要用对应的训练数据 trainDirectorytestDirectory来训练和评估模型。由于 Playground 可以实时显示执行结果,我们可以观察到训练进度(100%为模型训练完成)和评估准确度(100%为完美匹配)。其中evaluation方法返回的是 MLClassifierMetrics 结构体,其中的confusion属性对应了评估结果和实际结果不同的数据。

最后我们将生成的模型存入指定的位置。注意这里用try的原因是有可能写入的操作会抛出异常,如磁盘已满、当前目录不允许写操作等。

除了图片分类模型,苹果还展示了文本和表格信息的模型生成方式。其中文本模型与图片模型的生成过程大同小异;我们重点看下表格信息模型生成的 Create ML 示范代码:

// 定义数据源
let trainingCSV = URL(fileURLWithPath: “/Users/createml/HouseData.csv”)
let houseData = MLDataTable(contentsOf: trainingCSV)
let (trainingData,testData) = houseData.randomSplit(by: 0.8, seed: 0)

// 训练模型
let classifier = try MLRegressor(trainingData: houseData, targetColumn: "price")

// 评估模型
let metrics = try classifier.testingMetrics(on: testData)

// 保存模型
try classifier.write(to: URL(fileURLWithPath: "/Users/createml/HousePricer.mlmodel"))

在定义数据源上,我们可以用 CSV 格式的表格数据,也可以用 JSON 数据。另外,表格数据由专用的 MLDataTable 结构体来处理。最后,使用randomSplit(by:seed:)方法可以随机得将原 MLDataTable 根据比例拆分成两个 MLDataTable,来分别对应训练和测试数据。

在训练模型上,苹果提供了线性回归、决策树、随机森林等多种算法来生成模型,而 MLRegressor 则是将算法分析交给苹果,系统会自动选择较好的算法生成模型。虽然 MLRegressor 在普适性上高于其他方法,但相对的精度和效率比之其他具体算法会略显不足。

// 用线性回归算法生成模型
let classifier = try MLLinearRegressor(trainingData: houseData, targetColumn: "price")

// 用随机森林算法生成模型
let classifier = try MLRandomForestRegressor(trainingData: houseData, targetColumn: "price")

// 用 MLRegressor 生成模型
let classifier = try MLRegressor(trainingData: houseData, targetColumn: "price")

在评估模型上,不同于 MLImageClassifier 返回的 MLClassifierMetrics,这里表格模型评估返回的是 MLRegressorMetrics 结构体,其中的maximumError属性对应的是最坏情况下,评估结果和实际结果的方差;相应的均方误差则用rootMeanSquaredError属性来查看。对于如何利用 Metrics 来提高模型的准确性,苹果则给出了官方说明:Improving Your Model’s Accuracy

总结

Create ML 的出现解决了 iOS 平台上机器学习模型数量少的问题。其灵活的 API 和原生系统的支持使得 App 开发者可以更自由得定义和使用机器学习。然而相比于 TensorFlow,Create ML 不够成熟:模型生成的局限在特定的数据,对于其他数据诸如声音、图像依然无法支持,而且也无法处理图片、文字混合数据。尽管如此,Create ML 提供了简洁易用的 API,与 Core ML 一起构成了苹果的机器学习生态,展示了机器学习在 iOS 开发上的巨大潜力。