小知识,大挑战!本文正在参与“程序员必备小知识”创作活动。
Keras是由纯python编写的基于theano/tensorflow的深度学习框架。Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结果。 当我们通过代码去理解一个模型的网络结构的时候,对于比较复杂的结构不太好理解,但如果将这个结构以图片的形式展示,对于我们能比较直观、快速的理解,这篇文章就利用Keras框架绘制出Bi-LSTM模型的网络结构。
一、前期相关准备
1、安装pydot
pip install pydot
2、安装graphviz
graphviz需要在官网安装:graphvizgraphviz.org/
安装后需要添加程序所在目录的bin文件夹加入系统变量
二、编写代码
1、导入相关包
load_model:用于加载网络模型
CRF:网络模型中存在CRF模型层
plot_model:生成网络模型结构,并将其保存为图片
pyplot :加载网络模型结构图片
from keras.models import load_model
from keras_contrib.layers import CRF
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
2、生成网络模型结构
plot_model接口参数:
to_file:网络模型结构图片存储路径和名称
show_shapes:是否显示形状(神经层输入和输出)
show_layer_names:是否显示神经层的名称
rankdir:神经层之间的方向,“TB”代表上下,“LR”代表左右
model_path = "./model/ch_ner_model.h5"
# 模型文件
model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)
plot_model(model,to_file='./model/nerbilstm.png',show_shapes=True,show_layer_names='False',rankdir="TB")
3、加载网络模型结构
使用matplotlib包中的pyplot方法,将生成的网络模型结构图片加载出来。
plt.figure(figsize=(10,10))
img = plt.imread("./model/nerbilstm.png")
plt.imshow(img)
plt.axis("off")
plt.show()