前言
借助 Android Studio 编译时期提供的 MLModelBinding 特性,可以快速帮我们生成推理模型的模板代码,提供更高层级的 API,屏蔽数据处理的细节,简化调用流程。
LiteRT 的痛点
无论使用 PyTorchLite 还是 LiteRT(TensorFlowLite),面对不同结构的推理模型,需要构造不同的输入数据,处理不同的输出结果。尤其是 Bitmap 类型的数据,涉及数据格式转换,归一化处理等各种繁琐的步骤,而这些步骤稍有不慎就会出错,同时由于推理接口完全是黑盒般的存在,根本无法进行 debug ,只能一行行检查传给模型的数据和返回结果的处理是否有问题。其实,我们可以借助 MLModelBinding 这个特性,绕过这些繁琐的步骤,用更简洁的方式使用推理接口。
在上一篇 Google 端侧 AI 框架 LiteRT 初探学习和了解 Google 官方提供的 LiteRT 在 Android - 掘金 手写数字识别这样一个简单的场景中,依然需要我们完成模型的加载,获取输入输出数据的结构(即Shape),然后根据这些内容将我们的原始数据转换为模型可以接受的数据。
虽然流程不算复杂,但是为了保障输入输出的正确,我们需要提前获取一些前置信息。
而对于像是风格迁移这类更加复杂的场景,我们需要处理更多的东西
对于输入内容需要进行图片裁剪、归一化、数据类型转换,对于输出需要将原始的数据还原为使用与 Bitmap 的 0~255 数据。这些步骤都需要我们对神经网络模型有一些了解,否则完全不知道在干什么以及为什么要这么做。
MLModelBinding
什么是 MLModelBinding
ML Model Binding 是 Android Studio(和 Android Gradle Plugin)提供的功能,它可以让你自动将 .tflite 模型绑定成 Java/Kotlin 类,这样你就不需要手动处理 Tensor 输入输出。使用它,你可以像使用普通对象那样使用 TFLite 模型。
下面介绍如何使用 MLModelBinding
使用 MLModelBinding
- 修改 build.gradle.kts , 添加 mlModelBinding 功能
buildFeatures {
...
mlModelBinding = true
}
- 在项目 src 目录下创建和 java 目录平级的目录
ml,将适用于 TensorFlow Lite 的模型放在该文件夹下
- 然后进行一次 build (注意是 build 不是 sync) ,执行完成后就可以在 build 目录下看到有 ml_source_out 文件夹,其中会生成基于模型的模板文件,但是我们并不用关注这些文件
以 whitebox_cartoon_gan_dr.tflite 这个模型为例,这是一个可以实现卡通动画风格迁移的模型,可以将输入图片转换为特定的风格,借助 MLModelBinding 特性,可以非常快捷的使用这个模型。
fun transStyle(context: Context,bitmap: Bitmap) {
val model = WhiteboxCartoonGanDr.newInstance(context)
val tensorImage = TensorImage.fromBitmap(bitmap)
val out = model.process(tensorImage)
val result = out.cartoonizedImageAsTensorImage
val bitmap = result.bitmap
// show bitmap
model.close()
}
可以看到,就是这么简单,无需关心其中复杂的转换,只需要 4 个步骤
- 创建模型的实例
- 将输入的 bitmap 转换为 tensorImage
- 调用 process 进行推理
- 从返回结果中直接获取 bitmap
甚至以代码都不用写,模板代码都已经生成好了,我们只需要做微调 。点击 WhiteboxCartoonGanDr 这个类,Android Studio 会自动打开这样一个文件。
其中包含了对于模型介绍,输入输出参数的协议,以及最后使用这个模型的模板代码,甚至还提供了 Kotlin 和 Java 两种语言的版本。
我们可以看一下这个模型的效果
综合来说是形似,但是也有点卡通风格的那个味道了。
MLModelBinding 原理
MLModelBinding 和我们之前熟悉的各类 Binding (DataBinding/ViewBinding)类似,通过编译期生成的中间代码,提供了更上层的 API,将一些模板代码封装到了底层,方便开发者聚焦于实际业务,而不是各种数据的封装和转换。
我们可以看一下 build 过程中生成的文件
public final class WhiteboxCartoonGanDr {
@NonNull
private final ImageProcessor sourceImageProcessor;
private int sourceImageHeight;
private int sourceImageWidth;
@NonNull
private final ImageProcessor cartoonizedImagePostProcessor;
private int cartoonizedImageHeight;
private int cartoonizedImageWidth;
@NonNull
private final Model model;
private WhiteboxCartoonGanDr(@NonNull Context context, @NonNull Model.Options options) throws
IOException {
model = Model.createModel(context, "whitebox_cartoon_gan_dr.tflite", options);
MetadataExtractor extractor = new MetadataExtractor(model.getData());
ImageProcessor.Builder sourceImageProcessorBuilder = new ImageProcessor.Builder()
.add(new ResizeOp(512, 512, ResizeMethod.NEAREST_NEIGHBOR))
.add(new NormalizeOp(new float[] {127.5f}, new float[] {127.5f}))
.add(new QuantizeOp(0f, 0.0f))
.add(new CastOp(DataType.FLOAT32));
sourceImageProcessor = sourceImageProcessorBuilder.build();
ImageProcessor.Builder cartoonizedImagePostProcessorBuilder = new ImageProcessor.Builder()
.add(new DequantizeOp((float)0, (float)0.0))
.add(new NormalizeOp(new float[] {-1.0f}, new float[] {0.00784313f}))
.add(new CastOp(DataType.UINT8));
cartoonizedImagePostProcessor = cartoonizedImagePostProcessorBuilder.build();
}
@NonNull
public static WhiteboxCartoonGanDr newInstance(@NonNull Context context) throws IOException {
return new WhiteboxCartoonGanDr(context, (new Model.Options.Builder()).build());
}
@NonNull
public Outputs process(@NonNull TensorImage sourceImage) {
sourceImageHeight = sourceImage.getHeight();
sourceImageWidth = sourceImage.getWidth();
TensorImage processedsourceImage = sourceImageProcessor.process(sourceImage);
Outputs outputs = new Outputs(model);
model.run(new Object[] {processedsourceImage.getBuffer()}, outputs.getBuffer());
return outputs;
}
@NonNull
public Outputs process(@NonNull TensorBuffer sourceImage) {
TensorBuffer processedsourceImage = sourceImage;
Outputs outputs = new Outputs(model);
model.run(new Object[] {processedsourceImage.getBuffer()}, outputs.getBuffer());
return outputs;
}
public class Outputs {
private TensorImage cartoonizedImage;
private Outputs(Model model) {
this.cartoonizedImage = new TensorImage(DataType.FLOAT32);
cartoonizedImage.load(TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32));
}
@NonNull
public TensorImage getCartoonizedImageAsTensorImage() {
return cartoonizedImagePostProcessor.process(cartoonizedImage);
}
}
}
可以看到对于输入数据 Bitmap 按照模型约束进行缩放、归一化的操作一点也没有减少。只不过将细节隐藏在了这里。 TensorImage 不仅支持 Bitmap, 还提供了对其他格式图片的支持,还支持不同类型的数据,Float32,Int8 等等都只支持的。
因此,借助 MLBinding 可以大大简化使用 TensorFlow Lite 模型的成本。
任意风格转换
我们借助 MLModelBinding 实现一个任意风格类型图像转换的能力,这里我们使用 dr-predict 提供的风格检测模型和风格转换模型
fun tansFree(context: Context, styleBitmap: Bitmap, contentBitmap: Bitmap) {
val styleModel = ArbitraryImageStylizationV1Tflite256Int8TransferV1.newInstance(context)
val predict = PredictInt8.newInstance(context)
val styleImage = TensorImage.fromBitmap(styleBitmap)
val tensorImage = TensorImage.fromBitmap(contentBitmap)
// Runs model inference and gets result.
val outputs = predict.process(styleImage)
val styleBottleneck = outputs.styleBottleneckAsTensorBuffer
Log.i(TAG, "styleBottleneck ${styleBottleneck}")
val out = styleModel.process(tensorImage, styleBottleneck)
val result = out.styledImageAsTensorImage
val bitmap = result.bitmap
predict.close()
styleModel.close()
}
由于我们需要将 styleBitmap 的风格迁移到 contentBitmap 上,因此这里需要两个模型,一个用于获取风格,一个用于进行转换。
- PredictInt8 用于获取风格特征
- ArbitraryImageStylizationV1Tflite256Int8TransferV1 基于特征实现对内容图片的风格迁移
可以看到使用 MLModelBinding 之后,不用再纠结于模型转换的各种细节了,相反如果使用传统的方式,需要按以下步骤实现。
这还不包括获取模型原始输入输出信息的步骤。使用 MLModelBinding 我们只需要关心最核心的实现,即 processs 的调用流程,根据模板代码提供的输出,可以方便实现具体的业务流程。
可以看看任意风格转换的效果
| 示例1 | 示例2 | 示例3 |
|---|---|---|
可以看到任意风格转换模型,输出结果的大小是固定的。这一点,我们从 MLBinding 输出的模板文件就能看到
其实这里从输入就做了限制,模型输入只支持 384x383 的三通道彩色图片。而输出也是同样的格式。
总结
使用 MLModelBinding 这个特性,通过编译其生成的中间件封装的接口,屏蔽了底层繁琐的细节,简化了模型的使用流程,让开发者可以按照自己熟悉的内容去使用,而不用过分关心实现细节,提升了效率,也避免了无谓的 bug。