前端工程师如何排查和修复keras2onnx源码问题

avatar
阿里巴巴 前端委员会智能化小组 @阿里巴巴

文/ 阿里淘系 F(x) Team - 雷姆

很多人可能觉得做一个开源项目,自然是希望用户越多越好,但对于我来说,这并不是我期望的。

我更希望我正在做的开源项目Pipcook 是一本书,看过它的人,能从中学习到一些机器学习的知识。当然它有时候也是一本工具书,帮大家完成一些基本的任务,但并非前端智能化就一定得使用 Pipcook,最重要的是前端工程师们能真正掌握机器学习的能力来协助他们更高效的完成工作。

接下来,我就来分享一下,我——作为一个前端工程师,是如何从零开始,从算法同学获取一个模型,到通过 keras2onnx 将模型转成 onnx 格式的,最后再完成了一次对 keras2onnx 的一次开源贡献。

简单解释:ONNX 是一种算法模型的格式,有点类似于 JS 的 Bytecode 或者 WebAssembly,而 Keras 则有点像 TypeScript/CoffeeScript,属于一种前端方言。

故事的起源

事情要从一个遥不可及的需求说起,有一天某老师跟我说,算力就是未来,给我推荐了阿里云自家的 AliNPU 芯片,于是我就找到了 NPU 的接口人老师。

一开始可能因为我职位(前端)的关系,所以没能吸引到老师的注意,但是架不住我每天一 Ping,于是老师让我提供一个模型给他们去评估,于是我找到了某算法同学,他给了我一个开源的模型,我就“转发”给了 NPU 的接口人,但是他说这个模型他那边的工具解析不了,需要我给他 ONNX 的模型格式。

于是我就开始了我的模型转换之旅,由于工作(Pipcook)关系,我之前有简单了解过 ONNX 和 Keras 这些概念,不熟悉的可以自行谷歌一下。

模型转换初尝试

然后接下来讲讲模型转换是个什么概念,一般来说用 Keras 训练的模型会有一个自己的格式,就跟使用 TypeScript 或者 CoffeeScript 编译出来生成的 JS 代码是不同的(但表达的意思可能一样),然后模型转换呢,就是把 Keras 的模型的格式呢,转成 ONNX 这种通用的格式,换言之,就是把 TypeScript 或者 CoffeeScript 构建生成的 JS 代码转成一种更为通用的 WebAssembly 表示(例子确实不是很严谨,但意思就凑活着看吧)。

简单解释:Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。

在了解了模型转换,一个 ONNX 与 Keras 七七八八之后,我就开始使用 GitHub 大法,果然搜到了一个由 ONNX 官方提供的 keras2onnx 库(一个 Python 库),于是我尝试按着例子,去将模型进行转换,结果我一运行就遇到了一个错误:

AssertionError: conv_7b_bn/cond/input_0:01is disconnected, check the parsing log for more details.

问题解决尝试方法之 —— Debug大法好

我看到,就很懵,这是啥错误,完全没有头绪,于是我问了一些算法同学,他们也不知道。

于是作为一个好学的前端工程师,不就是 Python 代码吗,跟 JavaScript 不也一样 Debug 吗,于是我通过错误堆栈信息,找到了出错的代码,发现是在 keras2onnx 库种有一个叫 remove_unused_nodes 的函数报的错,我认真一步步读了一下这个函数的代码,了解到这个函数是在转换结束后,用来把没有用到的一些节点剔除掉使用的,于是我就把一些关键的变量打印出来。

然后发现在表示一个模型的时候,其实就是一个一个节点组成,然后每个节点都有一个输入和输出,而上面错误中的 conv_7b_bn/cond/input_0:01 则是一个节点中输入的名称,在我发现这个规律后,我就查阅了一下 Keras 的文档,发现提供了 model.to_json() 的方法,可以把模型的信息转成 JSON,于是我通过这个方法 dump 成 json 文件来对比:

"class_name":"BatchNormalization",
"config":{
  "name":"conv_7b_bn",
  "trainable":true,
  "dtype":"float32",
  "axis":[
    3
  ],
  "momentum":0.99,
  "epsilon":0.001,
  "center":true,
  "scale":false,
  "beta_initializer":{
    "class_name":"Zeros",
    "config":{

    }
  },
}

但是我看了这个之后还是没有找到问题的原因,而且中间因为我看到 keras2onnx 自己的例子中,将 include_top 这个参数都是设置为 True,所以我一度怀疑是因为这个参数的原因导致的,但在细读了相关文档以及结合我们自己的模型代码后,发现这个参数并没有对结果产生任何影响。

问题解决尝试方法之 —— 源码阅读

此时我开始怀疑人生,在无奈之下,我继续谷歌大法,无果。好,最后我耐下性子,决定继续读一下 keras2onnx 到底是如何解析模型,以及转换代码的。

于是我开始通过加日志 Debug + 读部分代码的方式,一步步了解实现机制,大抵是通过遍历 Keras 的模型对象,然后根据不同类型的节点去转换和生成目标树来完成的,看到这里仍然没有发现有任何异常,但是我继续阅读代码,发现了一个叫 filter_out_input 的函数,其中有一行注释是这样的:

tf.keras BN layer sometimes create a placeholder node 'scale' in tf 2.x. It creates 'cond/input' since tf 2.2. Given bn layer will be converted in a whole layer, it's fine to just filter this node out.

这里的 BN 是指 BatchNormalization 层,具体函数在这里定义。大概的意思就是说 BN 层会创建一个变量,然后这些层在转换的过程中会出现在最外层,这样就导致了一开始的问题,在移除节点时,这个变量不会出现在模型中的任何输入/输出集中,就报错了。

于是这个函数所做的就是将这些节点过滤掉:

r"batch_normalization_\d+\/cond/input",

我发现上面的正则与 conv_7b_bn/cond/input_0:01 实在是很像,如果将 batch_normalization 换成 conv_7b_bn 就能匹配上了,于是我直接在过滤条件中加入 r"conv_7b_bn/cond/input",然后又执行了一次,结果是执行正常,生成了 ONNX 格式,然后我先把模型给了接口人老师,最后结果也是验证 OK。

回过头来,我想看看 conv_7b_bn 和 batch_normalization 到底什么关系,于是我找到 tensorflow 中 keras 的代码,发现:

x = conv2d_bn(x, 1536, 1, name='conv_7b')

然后我又看了下 conv2d_bn() 函数实现:

bn_name = Noneif name isNoneelse name + '_bn'
x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)

结合起来一看,就发现了原来是 Keras 中的模型实现中,给最后这一层起了一个别名,而如果没有设置这个别名,最终模型中的名字就是 batch_normalization 开头了,于是发现这个问题的源头就是:keras2onnx 没有考虑到 Keras 内置模型中别名的情况,于是我就给 keras2onnx 提了一个 Pull Request #567,然后最终在维护者的指导下添加了测试用例和改进代码后,被合并到了主分支。

总结一下

作为前端工程师遇到迷惑的时候,还是要深入阅读代码和 Debug,加深自己对系统的认识,千万不要觉得机器学习相关的代码就写得很好,或者就看不懂,很玄学,这是大忌。

比如这个例子中,其实就是一个正则导致的问题,深度学习再怎么难以理解,但是代码中还不是由 if-else、switch-case 以及我们熟知的语法组成的,关键还是要耐下心来,一步步顺藤摸瓜地发现问题,这才是作为工程师存在于此的优势。