本文翻译自: 《Control flow operations: conditionals and loops》, 如有侵权请联系删除, 仅限于学术交流, 请勿商用。 如有谬误, 请联系指出。
当我们在构建一个复杂模型如 RNN(循环神经网络)的时候, 你可能需要通过条件和循环来控制操作流程。 在这一节, 我们介绍一些在 TensorFlow 中常用的控制流。
假设我们现在需要通过一个条件判断来决定我们是否相加还是相乘两个变量。 这个可以通过调用 tf.cond()
简单实现, 它表现出像 python 中 if...else...
相似的功能。
a = tf.constant(1)
b = tf.constant(2)
p = tf.constant(True)
x = tf.cond(p, lambda: a + b, lambda: a * b)
print(tf.Session().run(x))
因为这个条件判断为 True, 所以这个输出应该是加法输出, 也就是输出 3。
在使用 TensorFlow 的过程中, 大部分时间你都会使用大型的张量, 并且在一个批次(a batch)中进行操作。 一个与之相关的条件操作符是 tf.where()
, 它需要提供一个条件判断, 就和 tf.cond()
一样, 但是 tf.where()
将会根据这个条件判断, 在一个批次中选择输出, 如:
a = tf.constant([1, 1])
b = tf.constant([2, 2])
p = tf.constant([True, False])
x = tf.where(p, a + b, a * b)
print(tf.Session().run(x))
返回的结果是 [3, 2]
。
另一个广泛使用的控制流操作是 tf.while_loop()
。 它允许在 TensorFlow 中构建动态的循环, 这个可以实现对一个序列的变量进行操作。 让我们看看我们如何通过 tf.while_loops
函数生成一个斐波那契数列吧:
n = tf.constant(5)
def cond(i, a, b):
return i < n
def body(i, a, b):
return i + 1, b, a + b
i, a, b = tf.while_loop(cond, body, (2, 1, 1))
print(tf.Session().run(b))
这段代码将会返回 5。 tf.while_loop()
需要一个条件函数和一个循环体函数, 除此之外还需初始化循环变量。 这些循环变量在每一次循环体函数调用完之后都会被更新一次, 直到这个条件返回 False 为止。
现在想象我们想要保存这个斐波那契序列, 我们可能更新我们的循环体函数以纪录当前值的历史记录:
n = tf.constant(5)
def cond(i, a, b, c):
return i < n
def body(i, a, b, c):
return i + 1, b, a + b, tf.concat([c, [a + b]], 0)
i, a, b, c = tf.while_loop(cond, body, (2, 1, 1, tf.constant([1, 1])))
print(tf.Session().run(c))
当你尝试运行这个程序的时候, TensorFlow 将会“抱怨”说第四个循环变量的形状正在改变。 所以你必须指明这件事后有意为之的:
i, a, b, c = tf.while_loop(
cond, body, (2, 1, 1, tf.constant([1, 1])),
shape_invariants=(tf.TensorShape([]),
tf.TensorShape([]),
tf.TensorShape([]),
tf.TensorShape([None])))
这使得代码变得丑陋不堪, 而且效率极低。 请注意, 我们正在构建许多我们不使用的中间张量。 TensorFlow 对这种增长式的数组, 其实有一个更好的解决方案: tf.TensorArray
。 让我们用张量数组做同样的事情:
n = tf.constant(5)
c = tf.TensorArray(tf.int32, n)
c = c.write(0, 1)
c = c.write(1, 1)
def cond(i, a, b, c):
return i < n
def body(i, a, b, c):
c = c.write(i, a + b)
return i + 1, b, a + b, c
i, a, b, c = tf.while_loop(cond, body, (2, 1, 1, c))
c = c.stack()
print(tf.Session().run(c))
TensorFlow 中的 while_loop
和张量数组是构建复杂的循环神经网络(RNN)的基本工具。 作为练习, 你可以尝试使用 tf.while_loops
实现 beam search 。 你可以再尝试使用张量数组提高效率吗?