本文将会介绍TensorFlow保存和恢复模型,主要讲解Saver类save保存和restore恢复方法,我们可以解决如何保存和恢复一个已经训练好的神经网络模型用于推理预测的现实需求,也可以辅助查看分析一个长时间训练的模型性能,最重要的是我们可以预防因长时间训练中途出现断电、宕机、出错退出等问题导致的训练功亏一篑问题!可见,掌握tensorflow保存和恢复模型的方法,对我们工程应用有多么大的帮助,同时,这也是我们必须要掌握的基础技能, 在tensorflow中保存和恢复模型主要通过tf.train.Saver(),具体如下:
- 保存模型
- saver = tf.train.Saver()获得一个文件句柄,将训练中的某一个快照状态保存到文件中去
- saver.save(sess, os.path.join(model_dir, ‘ckp-%05d’%(i+1))),将训练好的模型保存到文件中
- 参数1:sess,session会话,参数2:模型保存路径
- 恢复模型
- saver.restore(sess, model_path)从文件中恢复模型
- 参数1:sess,session会话;参数2:model_path需要被恢复模型
示例代码:
with tf.Session() as sess:
sess.run( init ) # 注意: 这一步必须要有!!
# 打开一个writer,向writer中写数据
train_writer = tf.summary.FileWriter(train_log_dir, sess.graph) # 参数2:显示计算图
test_writer = tf.summary.FileWriter(test_log_dir)
fixed_test_batch_data, fixed_test_batch_labels = test_data.next_batch(batch_size)
if os.path.exists(model_path + '.index'):
saver.restore(sess, model_path)
print('model restored from %s' % model_path)
else:
print('model %s dose not exist' % model_path)
# 开始训练
for i in range( train_steps ):
# 得到batch
batch_data, batch_labels = train_data.next_batch( batch_size )
eval_ops = [loss, accuracy, train_op]
should_output_summary = ((i+1) % output_summary_every_steps == 0)
if should_output_summary:
eval_ops.append(merged_summary)
# 获得 损失值, 准确率
eval_val_results = sess.run( eval_ops, feed_dict={x:batch_data, y:batch_labels} )
loss_val, acc_val = eval_val_results[0:2]
if should_output_summary:
train_summary_str = eval_val_results[-1]
train_writer.add_summary(train_summary_str, i+1)
test_summary_str = sess.run([merged_summary_test],
feed_dict = {x: fixed_test_batch_data,y: fixed_test_batch_labels} )[0]
test_writer.add_summary(test_summary_str, i+1)
# 每 500 次 输出一条信息
if ( i+1 ) % 500 == 0:
print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' % ( i+1, loss_val, acc_val ))
# 每 5000 次 进行一次 测试
if ( i+1 ) % 5000 == 0:
# 获取数据集,但不随机
test_data = CifarData( test_filename, False )
all_test_acc_val = []
for j in range( test_steps ):
test_batch_data, test_batch_labels = test_data.next_batch( batch_size )
test_acc_val = sess.run( [accuracy], feed_dict={ x:test_batch_data, y:test_batch_labels } )
all_test_acc_val.append( test_acc_val )
test_acc = np.mean( all_test_acc_val )
print('[Test ] Step: %d, acc: %4.5f' % ( (i+1), test_acc ))
if (i+1) % output_model_every_steps == 0:
# saver 机制,保存最近的5个模型
saver.save(sess, os.path.join(model_dir, 'ckp-%05d'%(i+1)))
print('model saved to ckp-%05d' % (i+1))
参考: