在TensorFlow Lite中进行设备上的训练

677 阅读8分钟

由TensorFlow Lite团队发布

TensorFlow Lite是谷歌的机器学习框架,用于在多种设备和表面上部署机器学习模型,如移动(iOS和Android)、台式机和其他边缘设备。最近,我们也增加了对在浏览器中运行TensorFlow Lite模型的支持。为了使用TensorFlow Lite构建应用程序,你可以使用TensorFlow Hub的现成模型,或者使用转换器将现有的TensorFlow模型转换成TensorFlow Lite模型。一旦模型被部署在应用程序中,你可以根据输入数据在模型上运行推理

TensorFlow Lite现在支持在设备上训练你的模型,除了运行推理之外。设备上的训练实现了有趣的个性化用例,其中模型可以根据用户需求进行微调。例如,你可以部署一个图像分类模型,并允许一个用户微调该模型以使用迁移学习识别鸟类,同时允许另一个用户重新训练同一模型以识别水果。这个新功能在TensorFlow 2.7和更高版本中可用,目前可用于Android应用程序。(iOS的支持将在未来添加)。

设备上的训练也是联邦学习用例的必要基础,可以在分散的数据上训练全球模型。这篇博文不涉及联邦学习,而是专注于帮助你在安卓应用中整合设备上的训练。

在这篇文章的后面,我们将参考Colab安卓样本应用程序,指导你通过设备上学习的端到端实施路径来微调图像分类模型。

对早期方法的改进

在2019年的博文中,我们介绍了设备上的训练概念和TensorFlow Lite中的一个设备上训练的例子。然而,有几个限制。例如,定制模型结构和优化器并不容易。你还必须处理多个物理TensorFlow Lite(.tflite)模型,而不是单个TensorFlow Lite模型。同样,也没有简单的方法来存储和更新训练权重。我们最新的TensorFlow Lite版本通过为设备上的训练提供更方便的选项,简化了这一过程,如下文所述。

它是如何工作的?

为了部署一个内置设备上训练的TensorFlow Lite模型,以下是高水平的步骤。

  • 建立一个TensorFlow模型进行训练和推理
  • 将TensorFlow模型转换为TensorFlow Lite格式
  • 在你的安卓应用中集成该模型
  • 在应用程序中调用模型训练,类似于你调用模型推理的方式。

这些步骤解释如下。

建立一个TensorFlow模型进行训练和推理

TensorFlow Lite模型不仅应该支持模型推理,还应该支持模型训练,这通常涉及到将模型的权重保存到文件系统中,并从文件系统中恢复权重。这样做是为了在每个训练纪元后保存训练权重,以便下一个训练纪元可以使用前一个纪元的权重,而不是从头开始训练。

我们建议的方法是实现这些tf.函数来表示训练、推理、保存权重和加载权重。

  • 一个*训练*函数,使用训练数据训练模型。下面的train函数进行预测,计算损失(或误差),并使用tf.GradientTape()来记录自动区分的操作,更新模型的参数。

    # The `train` function takes a batch of input images and labels.@tf.function(input_signature=[     tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),     tf.TensorSpec([None, 10], tf.float32), ])def train(self, x, y):   with tf.GradientTape() as tape:     prediction = self.model(x)     loss = self._LOSS_FN(prediction, y)   gradients = tape.gradient(loss, self.model.trainable_variables)   self._OPTIM.apply_gradients(       zip(gradients, self.model.trainable_variables))   result = {"loss": loss}   for grad in gradients:     result[grad.name] = grad   return result
    
  • 一个调用模型推理的*inferpredict*函数。这类似于你目前使用TensorFlow Lite进行推理的方式。

    @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)]) def predict(self, x):   return {       "output": self.model(x)   }
    
  • 一个*保存/恢复*函数,将训练权重(即模型使用的参数)以Checkpoints格式保存到文件系统。保存函数的代码如下所示。

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) def save(self, checkpoint_path):   tensor_names = [weight.name for weight in self.model.weights]   tensors_to_save = [weight.read_value() for weight in self.model.weights]   tf.raw_ops.Save(       filename=checkpoint_path, tensor_names=tensor_names,       data=tensors_to_save, name='save')   return {       "checkpoint_path": checkpoint_path   }
    

转换为TensorFlow Lite格式

你可能已经熟悉了你的TensorFlow模型转换为TensorFlow Lite格式的工作流程。一些用于设备上训练的低级功能(例如,存储模型参数的变量)仍然是实验性的,其他的(例如,权重序列化)目前依赖于TF选择操作符,所以你需要在转换过程中设置这些标志。你可以在Colab中找到一个你需要设置的所有标志的例子。

# Convert the modelconverter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)converter.target_spec.supported_ops = [   tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.   tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.]converter.experimental_enable_resource_variables = Truetflite_model = converter.convert()

在你的安卓应用中整合模型

一旦你将你的模型转换为TensorFlow Lite格式,你就可以将模型集成到你的应用程序中了!请参考Android应用程序样本。更多细节请参考安卓应用样本。

在应用程序中调用模型训练和推理

在Android上,TensorFlow Lite的设备上的训练可以使用Java或C++ APIs进行。你可以创建一个TensorFlow Lite解释器的实例来加载一个模型并驱动模型训练任务。我们之前定义了多个tf.函数:这些函数可以使用TensorFlow Lite对签名的支持来调用,它允许一个TensorFlow Lite模型支持多个 "入口 "点。例如,我们为设备上的训练定义了一个train函数,它是模型的签名之一。通过指定签名的名称('train'),可以使用TensorFlow Lite的runSignature方法调用train函数。

// Run training for a few steps. float[] losses = new float[NUM_EPOCHS]; for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) { for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) { Map<String, Object> inputs = new HashMap<>>(); inputs.put("x", trainImageBatches.get(batchIdx)); inputs.put("y", trainLabelBatches.get(batchIdx)); Map<String, Object> outputs = new HashMap<>(); FloatBuffer loss = FloatBuffer.allocate(1); outputs.put("loss", loss); interpreter.runSignature(inputs, outputs, "train"); // Record the last loss. if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0); } }

同样地,下面的例子显示了如何使用模型的'infer'签名来调用推理。

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {    // Restore the weights from the checkpoint file.    int NUM_TESTS = 10;    FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());    FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());    // Fill the test data.    // Run the inference.    Map<String, Object> inputs = new HashMap<>>();    inputs.put("x", testImages.rewind());    Map<String, Object> outputs = new HashMap<>();    outputs.put("output", output);    anotherInterpreter.runSignature(inputs, outputs, "infer");    output.rewind();    // Process the result to get the final category values.    int[] testLabels = new int[NUM_TESTS];    for (int i = 0; i < NUM_TESTS; ++i) {        int index = 0;        for (int j = 1; j < 10; ++j) {            if (output.get(i * 10 + index) < output.get(i * 10 + j))                index = testLabels[j];        }        testLabels[i] = index;    }}

就这样!你现在有一个TensorFlow Lite模型,能够使用设备上的训练。我们希望这个代码演练能让你对如何在TensorFlow Lite中运行设备上的训练有一个很好的想法,我们很期待看到你的进展。

实际考虑

理论上,你应该能够将TensorFlow Lite中的设备上训练应用于TensorFlow所支持的任何用例。然而,在现实中,有一些实际的考虑,你需要在你的应用程序中部署设备上的训练之前记住。

  • 使用案例。Colab的例子显示了一个视觉用例的设备上训练的例子。如果你遇到特定模型或用例的问题,请在GitHub上告诉我们。
  • 性能。根据不同的使用情况,设备上的训练可能需要几秒钟到更长的时间。如果你运行设备上的训练作为面向用户的功能的一部分(例如,你的最终用户正在与该功能进行交互),你应该测量你的应用程序中各种可能的训练输入的时间,以限制训练时间。如果你的用例需要很长的设备上的训练时间,可以考虑先用桌面或云端训练一个模型,然后在设备上进行微调。
  • 电池的使用。就像模型推理一样,在设备上调用模型训练可能会导致电池消耗。如果模型训练是一个不面向用户的功能的一部分,我们建议遵循Android的指导方针来实现后台任务。
  • 从头开始训练与重新训练。理论上,应该可以使用上述功能在设备上从头开始训练一个模型。然而,在现实中,从头开始训练涉及大量的训练数据,即使在拥有强大处理器的服务器上也可能需要几天时间。因此,对于设备上的应用,我们建议在已经训练好的模型上进行再训练(即转移学习),如Colab的例子中所示。

路线图

未来的工作包括(但不限于)在iOS上支持设备上的训练,性能改进以利用设备上的加速器(如GPU)进行设备上的训练,通过在TensorFlow Lite中实现更多的训练操作来减少二进制大小,更高级别的API支持(如通过TensorFlow Lite任务库)来抽象出实现细节和涵盖其他设备上训练用例(如NLP)。我们的长期路线图包括可能提供设备上的端到端联合学习解决方案。

接下来的步骤

谢谢您的阅读!我们很高兴看到你使用设备上的学习来建立什么。再一次,这里有样本应用程序和Colab的链接。如果你有任何反馈,请在TensorFlow论坛GitHub上告诉我们

鸣谢

这篇文章反映了谷歌TensorFlow Lite团队中许多人的重大贡献,包括Michelle Carney, Lawrence Chan, Jaesung Chung, Jared Duke, Terry Heo, Jared Lim, Yu-Cheng Ling, Thai Nguyen, Karim Nosseir, Arun Venkatesan, Haoliang Zhang, 其他TensorFlow Lite团队成员,以及我们在谷歌研究院的合作者。