用TensorFlow Lite搜索器库进行设备上的文本到图像搜索

244 阅读6分钟

发布者:软件工程师李宗霖、王璐、Maxime Brénon和李玉琪

今天,我们很高兴地宣布一个新的基于设备嵌入的搜索库,它可以让你在几毫秒内从数以百万计的数据样本中快速找到相似的图像、文本或音频。

它的工作原理是使用一个模型将搜索查询嵌入到代表查询的语义的高维向量中。然后,它使用ScaNN(可扩展最近的邻居)从预定义的数据库中搜索类似的项目。为了将其应用于你的数据集,你需要使用Model Maker Searcher API(视觉/文本)来建立一个自定义的TFLite Searcher模型,然后使用任务库Searcher API(视觉/文本)将其部署到设备上。

例如,使用在COCO上训练的搜索器模型,搜索查询,A passenger plane on the runway ,将返回以下图像。

图1:所有图片都来自COCO 2014训练和验证数据集。图片1由Mark Jones Jr.根据Attribution License提供。图片2由305 Seahill根据Attribution-NoDerivs License提供。图片3由tataquax在署名-相同方式共享许可下提供。

在这篇文章中,我们将引导你通过一个端到端的例子,使用新的TensorFlow Lite搜索器库构建一个文本到图像的搜索功能(检索给定的文本查询的图像)。以下是主要的步骤。

  1. 使用COCO数据集为图像和文本查询编码训练一个双编码器模型。
  2. 使用Model Maker Searcher API创建一个文本到图像的搜索器模型。
  3. 使用任务库搜索器API,用文本查询检索图像。

训练一个双编码器模型

图2:用点积相似性距离训练双编码器模型。损失鼓励相关的图像和文本有更大的点积(阴影中的绿色方块)。

双重编码器模型由一个图像编码器和一个文本编码器组成。这两个编码器分别将图像和文本映射到高维空间的嵌入中。该模型计算图像和文本嵌入之间的点积,损失鼓励相关的图像和文本有较大的点积(更接近),不相关的有较小的点积(相距更远)。

训练过程受到CLIP论文和这个Keras例子的启发。图像编码器是基于一个预先训练好的EfficientNet模型,而文本编码器是基于一个预先训练好的Universal Sentence Encoder模型。然后,两个编码器的输出被投射到一个128维的空间,并被L2归一化。对于数据集,我们选择使用COCO,因为它的训练和验证部分的每张图片都有人类生成的标题。请看一下配套的Colab笔记本,了解训练过程的细节。

双重编码器模型使得从数据库中检索没有标题的图像成为可能,因为一旦经过训练,图像嵌入器可以直接从图像中提取语义,而不需要人类生成的标题。

使用Model Maker创建文本-图像搜索器模型

图3:使用图像编码器生成图像嵌入,并使用Model Maker来创建TFLite Searcher模型。

一旦双编码器模型被训练好,我们就可以用它来创建TFLite搜索器模型,它可以根据文本查询从图像数据集中搜索出最相关的图像。这可以通过以下三个步骤完成。

  1. 使用TensorFlow图像编码器生成图像数据集的嵌入。ScaNN能够搜索非常大的数据集,因此我们将2014年COCO的训练和验证部分结合起来,共12.3万多张图片,以证明其能力。请看这里的代码。
  2. 将TensorFlow文本编码器模型转换成TFLite格式。请看这里的代码。
  3. 使用Model Maker从TFLite文本编码器和图像嵌入中使用下面的代码创建TFLite Searcher模型。
# Configure ScaNN options. See the API doc (todo: link) for how to configure ScaNN. scann_options = searcher.ScaNNOptions(      distance_measure='dot_product',      tree=searcher.Tree(num_leaves=351, num_leaves_to_search=4),      score_ah=searcher.ScoreAH(1, anisotropic_quantization_threshold=0.2))# Load the image embeddings and corresponding metadata if any.data = searcher.DataLoader(tflite_embedder_path, image_embeddings, metadata)# Create the TFLite Searcher model.model = searcher.Searcher.create_from_data(data, scann_options)# Export the TFLite Searcher model.model.export(      export_filename='searcher.tflite',      userinfo='',      export_format=searcher.ExportFormat.TFLITE)

当创建搜索器模型时,Model Maker利用ScaNN来索引嵌入向量。嵌入数据集首先被划分为多个子集。在每个子集中,ScaNN存储嵌入向量的量化表示。在检索时,ScaNN选择几个最相关的分区,并以快速、近似的距离对量化的表示进行评分。这个过程既节省了模型的大小(通过量化),又实现了速度的提高(通过分区选择)。请看深入检查,了解更多关于ScaNN算法的信息。

在上面的例子中,我们将数据集分为351个分区(大约是我们拥有的嵌入数量的平方根),并在检索过程中搜索其中的4个分区,这大约是数据集的1%。我们还将128维的浮点嵌入量化为128个int8值,以节省空间。

使用任务库运行推理

图4:使用任务库和TFLite搜索器模型运行推理。它接受查询文本并返回顶部邻居的元数据。从那里我们可以找到相应的图像。

要使用搜索器模型查询图像,你只需要像下面这样使用任务库的几行代码。

from tflite_support.task import text# Initialize a TextSearcher objectsearcher = text.TextSearcher.create_from_file('searcher.tflite')# Search the input queryresults = searcher.search(query_text)# Show the resultsfor rank in range(len(results.nearest_neighbors)):  print('Rank #', rank, ':')  image_id = results.nearest_neighbors[rank].metadata  print('image_id: ', image_id)  print('distance: ', results.nearest_neighbors[rank].distance)  show_image_by_id(image_id)

试试Colab中的代码。另外,请看更多关于如何使用任务库Java和C++ API整合模型的信息,特别是在Android上。一般来说,每次查询在Pixel 6上只需要6毫秒。

下面是一些例子的结果。

查询。A man riding a bike

结果根据近似的相似性距离进行排序。这里是一个检索到的图片的样本。请注意,我们只在图片的许可证允许的情况下显示图片。

图5:所有的图片都来自COCO 2014训练和验证数据集。图片1由Reuel Mark Delez根据署名许可制作。图片2:Richard Masoner / Cyclelicious在署名-相同方式共享许可下拍摄。图片3由Julia在署名-相同方式共享许可下拍摄。图片4:Aaron Fulkerson在Attribution-ShareAlike License下拍摄。图片5:Richard Masoner / Cyclelicious 在Attribution-ShareAlike License下拍摄。图片6:Richard Masoner / Cyclelicious采用署名-相同方式共享许可

鸣谢

我们要感谢Khanh LeViet, Chuo-Ling Chang, Ruiqi Guo, Lawrence Chan, Laurence Moroney, Yu-Cheng Ling, Matthias Grundmann, 以及Robby Neale, Chung-Ching Chang 和 Khalid Salama对这项工作的积极支持。我们也要感谢整个ScaNN团队。David Simcha, Erik Lindgren, Felix Chern, Phil Sun和Sanjiv Kumar。