由TensorFlow Lite团队发布
TensorFlow Lite是谷歌的机器学习框架,用于在多种设备和表面上部署机器学习模型,如移动(iOS和Android)、台式机和其他边缘设备。最近,我们也增加了对在浏览器中运行TensorFlow Lite模型的支持。为了使用TensorFlow Lite构建应用程序,你可以使用TensorFlow Hub的现成模型,或者使用转换器将现有的TensorFlow模型转换成TensorFlow Lite模型。一旦模型被部署在应用程序中,你可以根据输入数据在模型上运行推理。
TensorFlow Lite现在支持在设备上训练你的模型,除了运行推理之外。设备上的训练实现了有趣的个性化用例,其中模型可以根据用户需求进行微调。例如,你可以部署一个图像分类模型,并允许一个用户微调该模型以使用迁移学习识别鸟类,同时允许另一个用户重新训练同一模型以识别水果。这个新功能在TensorFlow 2.7和更高版本中可用,目前可用于Android应用程序。(iOS的支持将在未来添加)。
设备上的训练也是联邦学习用例的必要基础,可以在分散的数据上训练全球模型。这篇博文不涉及联邦学习,而是专注于帮助你在安卓应用中整合设备上的训练。
在这篇文章的后面,我们将参考Colab和安卓样本应用程序,指导你通过设备上学习的端到端实施路径来微调图像分类模型。
对早期方法的改进
在2019年的博文中,我们介绍了设备上的训练概念和TensorFlow Lite中的一个设备上训练的例子。然而,有几个限制。例如,定制模型结构和优化器并不容易。你还必须处理多个物理TensorFlow Lite(.tflite)模型,而不是单个TensorFlow Lite模型。同样,也没有简单的方法来存储和更新训练权重。我们最新的TensorFlow Lite版本通过为设备上的训练提供更方便的选项,简化了这一过程,如下文所述。
它是如何工作的?
为了部署一个内置设备上训练的TensorFlow Lite模型,以下是高水平的步骤。
- 建立一个TensorFlow模型进行训练和推理
- 将TensorFlow模型转换为TensorFlow Lite格式
- 在你的安卓应用中集成该模型
- 在应用程序中调用模型训练,类似于你调用模型推理的方式。
这些步骤解释如下。
建立一个TensorFlow模型进行训练和推理
TensorFlow Lite模型不仅应该支持模型推理,还应该支持模型训练,这通常涉及到将模型的权重保存到文件系统中,并从文件系统中恢复权重。这样做是为了在每个训练纪元后保存训练权重,以便下一个训练纪元可以使用前一个纪元的权重,而不是从头开始训练。
我们建议的方法是实现这些tf.函数来表示训练、推理、保存权重和加载权重。
-
一个*训练*函数,使用训练数据训练模型。下面的train函数进行预测,计算损失(或误差),并使用tf.GradientTape()来记录自动区分的操作,更新模型的参数。
# The `train` function takes a batch of input images and labels.@tf.function(input_signature=[ tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32), tf.TensorSpec([None, 10], tf.float32), ])def train(self, x, y): with tf.GradientTape() as tape: prediction = self.model(x) loss = self._LOSS_FN(prediction, y) gradients = tape.gradient(loss, self.model.trainable_variables) self._OPTIM.apply_gradients( zip(gradients, self.model.trainable_variables)) result = {"loss": loss} for grad in gradients: result[grad.name] = grad return result -
一个调用模型推理的*infer或predict*函数。这类似于你目前使用TensorFlow Lite进行推理的方式。
@tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)]) def predict(self, x): return { "output": self.model(x) } -
一个*保存/恢复*函数,将训练权重(即模型使用的参数)以Checkpoints格式保存到文件系统。保存函数的代码如下所示。
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) def save(self, checkpoint_path): tensor_names = [weight.name for weight in self.model.weights] tensors_to_save = [weight.read_value() for weight in self.model.weights] tf.raw_ops.Save( filename=checkpoint_path, tensor_names=tensor_names, data=tensors_to_save, name='save') return { "checkpoint_path": checkpoint_path }
转换为TensorFlow Lite格式
你可能已经熟悉了将你的TensorFlow模型转换为TensorFlow Lite格式的工作流程。一些用于设备上训练的低级功能(例如,存储模型参数的变量)仍然是实验性的,其他的(例如,权重序列化)目前依赖于TF选择操作符,所以你需要在转换过程中设置这些标志。你可以在Colab中找到一个你需要设置的所有标志的例子。
# Convert the modelconverter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.]converter.experimental_enable_resource_variables = Truetflite_model = converter.convert()
在你的安卓应用中整合模型
一旦你将你的模型转换为TensorFlow Lite格式,你就可以将模型集成到你的应用程序中了!请参考Android应用程序样本。更多细节请参考安卓应用样本。
在应用程序中调用模型训练和推理
在Android上,TensorFlow Lite的设备上的训练可以使用Java或C++ APIs进行。你可以创建一个TensorFlow Lite解释器的实例来加载一个模型并驱动模型训练任务。我们之前定义了多个tf.函数:这些函数可以使用TensorFlow Lite对签名的支持来调用,它允许一个TensorFlow Lite模型支持多个 "入口 "点。例如,我们为设备上的训练定义了一个train函数,它是模型的签名之一。通过指定签名的名称('train'),可以使用TensorFlow Lite的runSignature方法调用train函数。
// Run training for a few steps. float[] losses = new float[NUM_EPOCHS]; for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) { for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) { Map<String, Object> inputs = new HashMap<>>(); inputs.put("x", trainImageBatches.get(batchIdx)); inputs.put("y", trainLabelBatches.get(batchIdx)); Map<String, Object> outputs = new HashMap<>(); FloatBuffer loss = FloatBuffer.allocate(1); outputs.put("loss", loss); interpreter.runSignature(inputs, outputs, "train"); // Record the last loss. if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0); } }
同样地,下面的例子显示了如何使用模型的'infer'签名来调用推理。
try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) { // Restore the weights from the checkpoint file. int NUM_TESTS = 10; FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder()); FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder()); // Fill the test data. // Run the inference. Map<String, Object> inputs = new HashMap<>>(); inputs.put("x", testImages.rewind()); Map<String, Object> outputs = new HashMap<>(); outputs.put("output", output); anotherInterpreter.runSignature(inputs, outputs, "infer"); output.rewind(); // Process the result to get the final category values. int[] testLabels = new int[NUM_TESTS]; for (int i = 0; i < NUM_TESTS; ++i) { int index = 0; for (int j = 1; j < 10; ++j) { if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = testLabels[j]; } testLabels[i] = index; }}
就这样!你现在有一个TensorFlow Lite模型,能够使用设备上的训练。我们希望这个代码演练能让你对如何在TensorFlow Lite中运行设备上的训练有一个很好的想法,我们很期待看到你的进展。
实际考虑
理论上,你应该能够将TensorFlow Lite中的设备上训练应用于TensorFlow所支持的任何用例。然而,在现实中,有一些实际的考虑,你需要在你的应用程序中部署设备上的训练之前记住。
- 使用案例。Colab的例子显示了一个视觉用例的设备上训练的例子。如果你遇到特定模型或用例的问题,请在GitHub上告诉我们。
- 性能。根据不同的使用情况,设备上的训练可能需要几秒钟到更长的时间。如果你运行设备上的训练作为面向用户的功能的一部分(例如,你的最终用户正在与该功能进行交互),你应该测量你的应用程序中各种可能的训练输入的时间,以限制训练时间。如果你的用例需要很长的设备上的训练时间,可以考虑先用桌面或云端训练一个模型,然后在设备上进行微调。
- 电池的使用。就像模型推理一样,在设备上调用模型训练可能会导致电池消耗。如果模型训练是一个不面向用户的功能的一部分,我们建议遵循Android的指导方针来实现后台任务。
- 从头开始训练与重新训练。理论上,应该可以使用上述功能在设备上从头开始训练一个模型。然而,在现实中,从头开始训练涉及大量的训练数据,即使在拥有强大处理器的服务器上也可能需要几天时间。因此,对于设备上的应用,我们建议在已经训练好的模型上进行再训练(即转移学习),如Colab的例子中所示。
路线图
未来的工作包括(但不限于)在iOS上支持设备上的训练,性能改进以利用设备上的加速器(如GPU)进行设备上的训练,通过在TensorFlow Lite中实现更多的训练操作来减少二进制大小,更高级别的API支持(如通过TensorFlow Lite任务库)来抽象出实现细节和涵盖其他设备上训练用例(如NLP)。我们的长期路线图包括可能提供设备上的端到端联合学习解决方案。
接下来的步骤
谢谢您的阅读!我们很高兴看到你使用设备上的学习来建立什么。再一次,这里有样本应用程序和Colab的链接。如果你有任何反馈,请在TensorFlow论坛 或GitHub上告诉我们。
鸣谢
这篇文章反映了谷歌TensorFlow Lite团队中许多人的重大贡献,包括Michelle Carney, Lawrence Chan, Jaesung Chung, Jared Duke, Terry Heo, Jared Lim, Yu-Cheng Ling, Thai Nguyen, Karim Nosseir, Arun Venkatesan, Haoliang Zhang, 其他TensorFlow Lite团队成员,以及我们在谷歌研究院的合作者。