之前在训练好模型之后,都会用model.save()直接保存模型,保存的类型都是.h5格式,也并没有觉得有什么不妥。直到今天在保存一个含有自定义层(Attention)的模型的时候发现报错了,很是疑惑,看了一眼报错信息,发现重写什么get_config()和from_config()方法,我跟个没头苍蝇一样鼓捣了半天也没弄好,经过我的不懈努力瞎折腾之下总算是保存下来了,不过不是.h5格式的,是一个文件夹,里面包含了各种信息。很开心!结果在加载模型做F1-Score测试的时候发现又出错误了!!!
ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all layers implement `get_config` and `from_config`.。
属实想吐。
0、保存和加载Keras模型
这里先解释一下保存模型的方法吧。
如同前文所说,以前一直用model.save('model_name.h5')来保存模型,今天保存模型出错之后才想起自己还没有研究过Keras的保存方法。
以下内容来源于keras官方文档:
Keras 模型由多个组件组成:
- 架构或配置,它指定模型包含的层及其连接方式。
- 一组权重值(即“模型的状态”)。
- 优化器(通过编译模型来定义)。
- 一组损失和指标(通过编译模型或通过调用 `add_loss()` 或 `add_metric()` 来定义)。
您可以通过 Keras API 将这些片段一次性保存到磁盘,或仅选择性地保存其中一些片段:
- 将所有内容以 TensorFlow SavedModel 格式(或较早的 Keras H5 格式)保存到单个归档。这是标准做法。
- 仅保存架构/配置,通常保存为 JSON 文件。
- 仅保存权重值。通常在训练模型时使用。
保存和加载整个模型
您可以将整个模型保存到单个工件中。它将包括:
- 模型的架构/配置
- 模型的权重值(在训练过程中学习)
- 模型的编译信息(如果调用了 `compile()`)
- 优化器及其状态(如果有的话,使您可以从上次中断的位置重新开始训练)
API
model.save()或tf.keras.models.save_model()tf.keras.models.load_model()
可以使用两种方法将整个模型保存到本地,TensorFlow SavedModel 格式和较早的 Keras H5 格式。Keras推荐使用 SavedModel 格式。它是使用 model.save() 时的默认格式。
如果想要保存为.h5格式,可以通过以下两种方法:
- 将
save_format='h5'传递给save()。 - 将以
.h5或.keras结尾的文件名传递给save()。
如果使用的是model.save('my_model'),在程序执行完毕后会生成一个名为my_model的文件夹,其包含以下内容:
asserts saved_model.pb variables
模型架构和训练配置(包括优化器、损失和指标)存储在 saved_model.pb 中。权重保存在 variables/ 目录下。
SavedModel 处理自定义对象的方式
保存模型和模型的层时,SavedModel 格式会存储类名称、调用函数、损失和权重(如果已实现,则还包括配置)。调用函数会定义模型/层的计算图。
如果没有模型/层配置,调用函数会被用来创建一个与原始模型类似的模型,该模型可以被训练、评估和用于推断。
尽管如此,在编写自定义模型或层类时,对 get_config 和 from_config 方法进行定义始终是一种好的做法。这样您就可以稍后根据需要轻松更新计算。有关详细信息,请参阅自定义对象。
1、加载模型报错
ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom
layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all
layers implement `get_config` and `from_config`.
【解决方法】
如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制:
model_attention_bilstm = tf.keras.models.load_model('models/attention_bilstm', custom_objects={'metric_F1score': metric_F1score})
因为在该模型中有自定义的模型损失函数metric_F1score,所以在加载模型时需要加上custom_objects={'metric_F1score': metric_F1score}
更多信息移步Tensorflow-保存和加载 Keras 模型