本文翻译自: 《Building a neural network training framework with learn API》, 如有侵权请联系删除,仅限于学术交流,请勿商用。如有谬误,请联系指出。
为了简单起见,在之前的大多数示例中,我们都是手动创建一个会话(session),并不关心保存和加载检查点,但在实践中通常不是这样做的。在这我推荐你使用learn API来进行会话管理和日志记录(session management and logging)。我们使用TensorFlow提供了一个简单而实用的框架来训练神经网络。在这一节中,我们将解释这个框架是如何工作的。
当利用神经网络训练模型进行实验时,通常需要分割训练集和测试集。你需要利用训练集训练你的模型,并在测试集中计算一些指标来评估模型的好坏。你还需要将模型参数存储为一个检查点(checkpoint),因为你需要可以随时停止并重启训练过程。TensorFlow的learn API旨在简化这项工作,使我们能够专注于开发实际模型。
使用tf.learnAPI的最简单的方式是直接使用tf.Estimator对象。你需要定义一个模型函数,该模型函数包含一个损失函数(loss function)、一个训练操作(train op)、一个或一组预测,以及一组可选的用于评估的度量操作:
import tensorflow as tf
def model_fn(features, labels, mode, params):
predictions = ...
loss = ...
train_op = ...
metric_ops = ...
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops=metric_ops)
params = ...
run_config = tf.contrib.learn.RunConfig(model_dir=FLAGS.output_dir)
estimator = tf.estimator.Estimator(
model_fn=model_fn, config=run_config, params=params)
要训练模型,你只需调用Estimator.train()函数,同时提供一个输入函数来读取数据即可:
def input_fn():
features = ...
labels = ...
return features, labels
estimator.train(input_fn=input_fn, max_steps=...)
如果想要评估模型,只需要调用Estimator.evaluate():
estimator.evaluate(input_fn=input_fn)
对于一些简单的情况,Estimator对象就已经足够应付了,但是TensorFlow还提供了一个更高级别的对象,称为**Experiment** ,它提供了一些额外的实用功能。创建一个experiment对象非常简单:
experiment = tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn)
现在我们可以调用train_and_evaluate函数来计算训练时的指标:
experiment.train_and_evaluate()
运行experiment的另一种更为高级的方法是使用learn_runner.run()函数。下面是我们在框架中提供的主要功能:
import tensorflow as tf
tf.flags.DEFINE_string("output_dir", "", "Optional output dir.")
tf.flags.DEFINE_string("schedule", "train_and_evaluate", "Schedule.")
tf.flags.DEFINE_string("hparams", "", "Hyper parameters.")
FLAGS = tf.flags.FLAGS
def experiment_fn(run_config, hparams):
estimator = tf.estimator.Estimator(
model_fn=make_model_fn(),
config=run_config,
params=hparams)
return tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=make_input_fn(tf.estimator.ModeKeys.TRAIN, hparams),
eval_input_fn=make_input_fn(tf.estimator.ModeKeys.EVAL, hparams))
def main(unused_argv):
run_config = tf.contrib.learn.RunConfig(model_dir=FLAGS.output_dir)
hparams = tf.contrib.training.HParams()
hparams.parse(FLAGS.hparams)
estimator = tf.contrib.learn.learn_runner.run(
experiment_fn=experiment_fn,
run_config=run_config,
schedule=FLAGS.schedule,
hparams=hparams)
if __name__ == "__main__":
tf.app.run()
调度标志(schedule flag)决定Experiment对象的哪个成员函数被调用。因此,如果你将schedule设置为“train_and_evaluate”,experiment.train_and_evaluate()这个函数将会被调用。
def input_fn():
features = ...
labels = ...
return features, labels
有关如何使用数据集API读取数据的示例,请参见mnist .py。要了解在TensorFlow中读取数据的各种方法,可以参考这段代码。
该框架还提供了一个简单的卷积网络分类器,详见alexnet.py,其中包括一个示例模型。
这就是开始使用TensorFlow learn API所需要的全部内容。我建议查看框架源码并查看官方python API,以了解更多关于learn API的信息。