【译】Effective TensorFlow Chapter13——在TensorFlow中利用learn API构建神经网络框架
本文翻译自: 《Building a neural network training framework with learn API》, 如有侵权请联系删除, 仅限于学术交流, 请勿商用。 如有谬误, 请联系指出。 为了简单起见, 在之前的大多数示例中, 我们都是手动创建一个会话(session), 并不关心保存和加载检查点, 但在实践中通常不是这样做的。 在这我推荐你使用 learn API 来进行会话管理和日志记录(session management and logging)。 我们使用 TensorFlow 提供了一个简单而实用的框架来训练神经网络。 在这一节中, 我们将解释这个框架是如何工作的。 当利用神经网络训练模型进行实验时, 通常需要分割训练集和测试集。 你需要利用训练集训练你的模型, 并在测试集中计算一些指标来评估模型的好坏。 你还需要将模型参数存储为一个检查点(checkpoint), 因为你需要可以随时停止并重启训练过程。 TensorFlow 的 learn API 旨在简化这项工作, 使我们能够专注于开发实际模型。 使用 tf.learn API 的最简单的方式是直接使用 tf.Estimator 对象。 你需要定义一个模型函数, 该模型函数包含一个损失函数(loss function)、 一个训练操作(train op)、 一个或一组预测, 以及一组可选的用于评估的度量操作: import tensorflow as tf def model_fn(features, labels, mode, params): predictions = ....