用ML Model Binding 简化 TensorFlowLite 使用流程

270 阅读6分钟

前言

借助 Android Studio 编译时期提供的 MLModelBinding 特性,可以快速帮我们生成推理模型的模板代码,提供更高层级的 API,屏蔽数据处理的细节,简化调用流程。

LiteRT 的痛点

无论使用 PyTorchLite 还是 LiteRT(TensorFlowLite),面对不同结构的推理模型,需要构造不同的输入数据,处理不同的输出结果。尤其是 Bitmap 类型的数据,涉及数据格式转换,归一化处理等各种繁琐的步骤,而这些步骤稍有不慎就会出错,同时由于推理接口完全是黑盒般的存在,根本无法进行 debug ,只能一行行检查传给模型的数据和返回结果的处理是否有问题。其实,我们可以借助 MLModelBinding 这个特性,绕过这些繁琐的步骤,用更简洁的方式使用推理接口。

在上一篇 Google 端侧 AI 框架 LiteRT 初探学习和了解 Google 官方提供的 LiteRT 在 Android - 掘金 手写数字识别这样一个简单的场景中,依然需要我们完成模型的加载,获取输入输出数据的结构(即Shape),然后根据这些内容将我们的原始数据转换为模型可以接受的数据。

mnist_classify.png

虽然流程不算复杂,但是为了保障输入输出的正确,我们需要提前获取一些前置信息。

而对于像是风格迁移这类更加复杂的场景,我们需要处理更多的东西

style-process.png

对于输入内容需要进行图片裁剪、归一化、数据类型转换,对于输出需要将原始的数据还原为使用与 Bitmap 的 0~255 数据。这些步骤都需要我们对神经网络模型有一些了解,否则完全不知道在干什么以及为什么要这么做。

MLModelBinding

什么是 MLModelBinding

ML Model Binding 是 Android Studio(和 Android Gradle Plugin)提供的功能,它可以让你自动将 .tflite 模型绑定成 Java/Kotlin 类,这样你就不需要手动处理 Tensor 输入输出。使用它,你可以像使用普通对象那样使用 TFLite 模型。

下面介绍如何使用 MLModelBinding

使用 MLModelBinding

  • 修改 build.gradle.kts , 添加 mlModelBinding 功能
    buildFeatures {
        ...
        mlModelBinding = true
    }
  • 在项目 src 目录下创建和 java 目录平级的目录 ml,将适用于 TensorFlow Lite 的模型放在该文件夹下

model_list.png

  • 然后进行一次 build (注意是 build 不是 sync) ,执行完成后就可以在 build 目录下看到有 ml_source_out 文件夹,其中会生成基于模型的模板文件,但是我们并不用关注这些文件

model_result.png

whitebox_cartoon_gan_dr.tflite 这个模型为例,这是一个可以实现卡通动画风格迁移的模型,可以将输入图片转换为特定的风格,借助 MLModelBinding 特性,可以非常快捷的使用这个模型。

    fun transStyle(context: Context,bitmap: Bitmap) {
        val model = WhiteboxCartoonGanDr.newInstance(context)
        val tensorImage = TensorImage.fromBitmap(bitmap)
        val out = model.process(tensorImage)
        val result = out.cartoonizedImageAsTensorImage
        val bitmap = result.bitmap
        // show bitmap
            
        model.close()
    }

可以看到,就是这么简单,无需关心其中复杂的转换,只需要 4 个步骤

  1. 创建模型的实例
  2. 将输入的 bitmap 转换为 tensorImage
  3. 调用 process 进行推理
  4. 从返回结果中直接获取 bitmap

甚至以代码都不用写,模板代码都已经生成好了,我们只需要做微调 。点击 WhiteboxCartoonGanDr 这个类,Android Studio 会自动打开这样一个文件。

code-style.png

其中包含了对于模型介绍,输入输出参数的协议,以及最后使用这个模型的模板代码,甚至还提供了 Kotlin 和 Java 两种语言的版本。

我们可以看一下这个模型的效果

综合来说是形似,但是也有点卡通风格的那个味道了。

MLModelBinding 原理

MLModelBinding 和我们之前熟悉的各类 Binding (DataBinding/ViewBinding)类似,通过编译期生成的中间代码,提供了更上层的 API,将一些模板代码封装到了底层,方便开发者聚焦于实际业务,而不是各种数据的封装和转换。

我们可以看一下 build 过程中生成的文件

public final class WhiteboxCartoonGanDr {
  @NonNull
  private final ImageProcessor sourceImageProcessor;

  private int sourceImageHeight;

  private int sourceImageWidth;

  @NonNull
  private final ImageProcessor cartoonizedImagePostProcessor;

  private int cartoonizedImageHeight;

  private int cartoonizedImageWidth;

  @NonNull
  private final Model model;

  private WhiteboxCartoonGanDr(@NonNull Context context, @NonNull Model.Options options) throws
      IOException {
    model = Model.createModel(context, "whitebox_cartoon_gan_dr.tflite", options);
    MetadataExtractor extractor = new MetadataExtractor(model.getData());
    ImageProcessor.Builder sourceImageProcessorBuilder = new ImageProcessor.Builder()
      .add(new ResizeOp(512, 512, ResizeMethod.NEAREST_NEIGHBOR))
      .add(new NormalizeOp(new float[] {127.5f}, new float[] {127.5f}))
      .add(new QuantizeOp(0f, 0.0f))
      .add(new CastOp(DataType.FLOAT32));
    sourceImageProcessor = sourceImageProcessorBuilder.build();
    ImageProcessor.Builder cartoonizedImagePostProcessorBuilder = new ImageProcessor.Builder()
      .add(new DequantizeOp((float)0, (float)0.0))
      .add(new NormalizeOp(new float[] {-1.0f}, new float[] {0.00784313f}))
      .add(new CastOp(DataType.UINT8));
    cartoonizedImagePostProcessor = cartoonizedImagePostProcessorBuilder.build();
  }

  @NonNull
  public static WhiteboxCartoonGanDr newInstance(@NonNull Context context) throws IOException {
    return new WhiteboxCartoonGanDr(context, (new Model.Options.Builder()).build());
  }


  @NonNull
  public Outputs process(@NonNull TensorImage sourceImage) {
    sourceImageHeight = sourceImage.getHeight();
    sourceImageWidth = sourceImage.getWidth();
    TensorImage processedsourceImage = sourceImageProcessor.process(sourceImage);
    Outputs outputs = new Outputs(model);
    model.run(new Object[] {processedsourceImage.getBuffer()}, outputs.getBuffer());
    return outputs;
  }


  @NonNull
  public Outputs process(@NonNull TensorBuffer sourceImage) {
    TensorBuffer processedsourceImage = sourceImage;
    Outputs outputs = new Outputs(model);
    model.run(new Object[] {processedsourceImage.getBuffer()}, outputs.getBuffer());
    return outputs;
  }

  public class Outputs {
    private TensorImage cartoonizedImage;

    private Outputs(Model model) {
      this.cartoonizedImage = new TensorImage(DataType.FLOAT32);
      cartoonizedImage.load(TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32));
    }

    @NonNull
    public TensorImage getCartoonizedImageAsTensorImage() {
      return cartoonizedImagePostProcessor.process(cartoonizedImage);
    }

  }
}

可以看到对于输入数据 Bitmap 按照模型约束进行缩放、归一化的操作一点也没有减少。只不过将细节隐藏在了这里。 TensorImage 不仅支持 Bitmap, 还提供了对其他格式图片的支持,还支持不同类型的数据,Float32,Int8 等等都只支持的。

因此,借助 MLBinding 可以大大简化使用 TensorFlow Lite 模型的成本。

任意风格转换

我们借助 MLModelBinding 实现一个任意风格类型图像转换的能力,这里我们使用 dr-predict 提供的风格检测模型和风格转换模型

    fun tansFree(context: Context, styleBitmap: Bitmap, contentBitmap: Bitmap) {

        val styleModel = ArbitraryImageStylizationV1Tflite256Int8TransferV1.newInstance(context)
        val predict = PredictInt8.newInstance(context)

        val styleImage = TensorImage.fromBitmap(styleBitmap)
        val tensorImage = TensorImage.fromBitmap(contentBitmap)

        // Runs model inference and gets result.
        val outputs = predict.process(styleImage)

        val styleBottleneck = outputs.styleBottleneckAsTensorBuffer
        Log.i(TAG, "styleBottleneck ${styleBottleneck}")


        val out = styleModel.process(tensorImage, styleBottleneck)

        val result = out.styledImageAsTensorImage
        val bitmap = result.bitmap

        predict.close()
        styleModel.close()
    }

由于我们需要将 styleBitmap 的风格迁移到 contentBitmap 上,因此这里需要两个模型,一个用于获取风格,一个用于进行转换。

  • PredictInt8 用于获取风格特征
  • ArbitraryImageStylizationV1Tflite256Int8TransferV1 基于特征实现对内容图片的风格迁移

可以看到使用 MLModelBinding 之后,不用再纠结于模型转换的各种细节了,相反如果使用传统的方式,需要按以下步骤实现。

hardcode-trans.png

这还不包括获取模型原始输入输出信息的步骤。使用 MLModelBinding 我们只需要关心最核心的实现,即 processs 的调用流程,根据模板代码提供的输出,可以方便实现具体的业务流程。

可以看看任意风格转换的效果

示例1示例2示例3
转存失败,建议直接上传图片文件转存失败,建议直接上传图片文件转存失败,建议直接上传图片文件

可以看到任意风格转换模型,输出结果的大小是固定的。这一点,我们从 MLBinding 输出的模板文件就能看到

trans-inner.png

其实这里从输入就做了限制,模型输入只支持 384x383 的三通道彩色图片。而输出也是同样的格式。

总结

使用 MLModelBinding 这个特性,通过编译其生成的中间件封装的接口,屏蔽了底层繁琐的细节,简化了模型的使用流程,让开发者可以按照自己熟悉的内容去使用,而不用过分关心实现细节,提升了效率,也避免了无谓的 bug。

参考