tensorflow模型的持久化:保存与测试pb文件

1,621 阅读2分钟

介绍
最近在做模型的量化,量化的模型是人脸检测网络mtcnn,我从Onet开始入手,原先这个模型使用的权重文件是ckpt,这种存储格式适合训练,如果要做量化的话,需要先转化为pb文件,把其中的变量都持久化。再进一步做量化
生成的思路是给加载ckpt文件的onet网络导入一张48x48的人头图像,输出softmax值和box数值,再把网络加载方式换成生成的pb文件,再送一样的一幅图进去,查看输出结果,一样则转化成功。然后接下来就可以在生成的pb文件上做int8量化。

pb文件
pb是protocol(协议) buffer(缓冲)的缩写。TensorFlow训练模型后存成的pb文件,是一种表示模型(神经网络)结构的二进制文件,不带有源代码。
pb文件中可以只存参数,也可以存参数加网络结构,我们这里要生成的是存参数+网络结构,这样在推断的时候,可以不用重新在代码中定义网络结构,直接送入图像就可以输出结果,很方便。google现在也推荐这种文件格式。

把模型保存成pb文件
我们在原网络中加载ckpt模型,然后回复成sess,再从sess保存到pb文件
代码如下:

import sys
import argparse
import time
import os
os.environ['CUDA_VISIBLE_DEVICES']='3'
import tensorflow as tf
import cv2
import numpy as np
from tensorflow.python.framework import graph_util
from src.mtcnn import PNet, RNet, ONet
from tools import detect_face, get_model_filenames

def main(args):
out_pb_path="onet_trained2.pb"
img = cv2.imread(args.image_path)
img48 = (img - 127.5) * (1. / 128.0)
img_x = np.expand_dims(img48, 0)
file_paths = get_model_filenames(args.model_dir)
with tf.device('/gpu:3'):
with tf.Graph().as_default():
config = tf.ConfigProto(allow_soft_placement=True)
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
with tf.Session(config=config) as sess:
if len(file_paths) == 3:
image_onet = tf.placeholder(tf.float32, [None, 48, 48, 3])
onet = ONet({'data': image_onet}, mode='test')
out_tensor_onet = onet.get_all_output()
saver_onet = tf.train.Saver(
[v for v in tf.global_variables()
if v.name[0:5] == "onet/"])
saver_onet.restore(sess, file_paths[2])
sess.run(out_tensor_onet, feed_dict={image_onet: img_x})
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
# for op in graph.get_operations():
# print(op.name, op.values())

output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess = sess,
input_graph_def = input_graph_def,# 等于:sess.graph_def input_graph_def
output_node_names = ['softmax/softmax','onet/conv6-2/onet/conv6-2'])# 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(out_pb_path, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

代码中最关键的是输出节点名称的确定,只要写对了程序基本没有问题,我在这一块卡了好久。查节点的方法有直接看原网络的输出节点名称、可视化工具tensorflow、netron。我使用的是netron,很方便,在网页中上次模型文件即可

更多学习资料可关注:gzitcast获取