关于Keras在加载或保存含有自定义层的模型时遇到的问题的记录

1,592 阅读3分钟

之前在训练好模型之后,都会用model.save()直接保存模型,保存的类型都是.h5格式,也并没有觉得有什么不妥。直到今天在保存一个含有自定义层(Attention)的模型的时候发现报错了,很是疑惑,看了一眼报错信息,发现重写什么get_config()from_config()方法,我跟个没头苍蝇一样鼓捣了半天也没弄好,经过我的不懈努力瞎折腾之下总算是保存下来了,不过不是.h5格式的,是一个文件夹,里面包含了各种信息。很开心!结果在加载模型做F1-Score测试的时候发现又出错误了!!!
ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all layers implement `get_config` and `from_config`.

属实想吐。

0、保存和加载Keras模型

这里先解释一下保存模型的方法吧。 如同前文所说,以前一直用model.save('model_name.h5')来保存模型,今天保存模型出错之后才想起自己还没有研究过Keras的保存方法。

以下内容来源于keras官方文档:

    Keras 模型由多个组件组成:

        -   架构或配置,它指定模型包含的层及其连接方式。
        -   一组权重值(即“模型的状态”)。
        -   优化器(通过编译模型来定义)。
        -   一组损失和指标(通过编译模型或通过调用 `add_loss()``add_metric()` 来定义)。

        您可以通过 Keras API 将这些片段一次性保存到磁盘,或仅选择性地保存其中一些片段:

        -   将所有内容以 TensorFlow SavedModel 格式(或较早的 Keras H5 格式)保存到单个归档。这是标准做法。
        -   仅保存架构/配置,通常保存为 JSON 文件。
        -   仅保存权重值。通常在训练模型时使用。
       

        保存和加载整个模型

        您可以将整个模型保存到单个工件中。它将包括:

        -   模型的架构/配置
        -   模型的权重值(在训练过程中学习)
        -   模型的编译信息(如果调用了 `compile()`        -   优化器及其状态(如果有的话,使您可以从上次中断的位置重新开始训练)

API

  • model.save()tf.keras.models.save_model()
  • tf.keras.models.load_model()

可以使用两种方法将整个模型保存到本地,TensorFlow SavedModel 格式较早的 Keras H5 格式。Keras推荐使用 SavedModel 格式。它是使用 model.save() 时的默认格式。

如果想要保存为.h5格式,可以通过以下两种方法:

  • save_format='h5' 传递给 save()
  • 将以 .h5.keras 结尾的文件名传递给 save()

如果使用的是model.save('my_model'),在程序执行完毕后会生成一个名为my_model的文件夹,其包含以下内容:

asserts  saved_model.pb  variables

模型架构和训练配置(包括优化器、损失和指标)存储在 saved_model.pb 中。权重保存在 variables/ 目录下。

SavedModel 处理自定义对象的方式

保存模型和模型的层时,SavedModel 格式会存储类名称、调用函数、损失和权重(如果已实现,则还包括配置)。调用函数会定义模型/层的计算图。

如果没有模型/层配置,调用函数会被用来创建一个与原始模型类似的模型,该模型可以被训练、评估和用于推断。

尽管如此,在编写自定义模型或层类时,对 get_configfrom_config 方法进行定义始终是一种好的做法。这样您就可以稍后根据需要轻松更新计算。有关详细信息,请参阅自定义对象

1、加载模型报错

ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom 
layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all 
layers implement `get_config` and `from_config`.

【解决方法】
如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制:

model_attention_bilstm = tf.keras.models.load_model('models/attention_bilstm', custom_objects={'metric_F1score': metric_F1score})

因为在该模型中有自定义的模型损失函数metric_F1score,所以在加载模型时需要加上custom_objects={'metric_F1score': metric_F1score}
更多信息移步Tensorflow-保存和加载 Keras 模型