本文翻译自: 《Building a neural network training framework with learn API》, 如有侵权请联系删除, 仅限于学术交流, 请勿商用。 如有谬误, 请联系指出。
为了简单起见, 在之前的大多数示例中, 我们都是手动创建一个会话(session), 并不关心保存和加载检查点, 但在实践中通常不是这样做的。 在这我推荐你使用 learn API
来进行会话管理和日志记录(session management and logging)。 我们使用 TensorFlow 提供了一个简单而实用的框架来训练神经网络。 在这一节中, 我们将解释这个框架是如何工作的。
当利用神经网络训练模型进行实验时, 通常需要分割训练集和测试集。 你需要利用训练集训练你的模型, 并在测试集中计算一些指标来评估模型的好坏。 你还需要将模型参数存储为一个检查点(checkpoint), 因为你需要可以随时停止并重启训练过程。 TensorFlow 的 learn API 旨在简化这项工作, 使我们能够专注于开发实际模型。
使用 tf.learn
API 的最简单的方式是直接使用 tf.Estimator
对象。 你需要定义一个模型函数, 该模型函数包含一个损失函数(loss function)、 一个训练操作(train op)、 一个或一组预测, 以及一组可选的用于评估的度量操作:
import tensorflow as tf
def model_fn(features, labels, mode, params):
predictions = ...
loss = ...
train_op = ...
metric_ops = ...
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops=metric_ops)
params = ...
run_config = tf.contrib.learn.RunConfig(model_dir=FLAGS.output_dir)
estimator = tf.estimator.Estimator(
model_fn=model_fn, config=run_config, params=params)
要训练模型, 你只需调用 Estimator.train()
函数, 同时提供一个输入函数来读取数据即可:
def input_fn():
features = ...
labels = ...
return features, labels
estimator.train(input_fn=input_fn, max_steps=...)
如果想要评估模型, 只需要调用 Estimator.evaluate()
:
estimator.evaluate(input_fn=input_fn)
对于一些简单的情况, Estimator 对象就已经足够应付了, 但是 TensorFlow 还提供了一个更高级别的对象, 称为** Experiment
** , 它提供了一些额外的实用功能。 创建一个 experiment 对象非常简单:
experiment = tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn)
现在我们可以调用 train_and_evaluate
函数来计算训练时的指标:
experiment.train_and_evaluate()
运行 experiment
的另一种更为高级的方法是使用 learn_runner.run()
函数。 下面是我们在框架中提供的主要功能:
import tensorflow as tf
tf.flags.DEFINE_string("output_dir", "", "Optional output dir.")
tf.flags.DEFINE_string("schedule", "train_and_evaluate", "Schedule.")
tf.flags.DEFINE_string("hparams", "", "Hyper parameters.")
FLAGS = tf.flags.FLAGS
def experiment_fn(run_config, hparams):
estimator = tf.estimator.Estimator(
model_fn=make_model_fn(),
config=run_config,
params=hparams)
return tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=make_input_fn(tf.estimator.ModeKeys.TRAIN, hparams),
eval_input_fn=make_input_fn(tf.estimator.ModeKeys.EVAL, hparams))
def main(unused_argv):
run_config = tf.contrib.learn.RunConfig(model_dir=FLAGS.output_dir)
hparams = tf.contrib.training.HParams()
hparams.parse(FLAGS.hparams)
estimator = tf.contrib.learn.learn_runner.run(
experiment_fn=experiment_fn,
run_config=run_config,
schedule=FLAGS.schedule,
hparams=hparams)
if __name__ == "__main__":
tf.app.run()
调度标志(schedule flag)决定 Experiment
对象的哪个成员函数被调用。 因此, 如果你将 schedule 设置为 “train_and_evaluate”
, experiment.train_and_evaluate()
这个函数将会被调用。
def input_fn():
features = ...
labels = ...
return features, labels
有关如何使用数据集 API 读取数据的示例, 请参见mnist .py。 要了解在 TensorFlow 中读取数据的各种方法, 可以参考这段代码。
该框架还提供了一个简单的卷积网络分类器, 详见alexnet.py, 其中包括一个示例模型。
这就是开始使用 TensorFlow learn API 所需要的全部内容。 我建议查看框架源码并查看官方 python API, 以了解更多关于 learn API 的信息。