while 循环
def while_loop(cond, ### 一个函数,负责判断循环是否进行 body, ### 一个函数,循环体,更新变量 loop_vars, ### 初始循环变量,可以是多个,这些变量是 cond、body 的输入 和输出 shape_invariants=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None, maximum_iterations=None, return_same_structure=False):
返回 迭代后的 loop_vars
def cond(i, n): return i < n def body(i, n): i = i + 1 return i, n i = tf.get_variable("ii", dtype=tf.int32, shape=[], initializer=tf.ones_initializer()) # i = 1 # 也可以 # i = tf.constant(1) # 也可以 n = tf.constant(10) i, n = tf.while_loop(cond, body, [i, n]) with tf.Session() as sess: tf.global_variables_initializer().run() res = sess.run([i, n]) print(res) # [10, 10]
注意:cond 和 body 的输入和输出要相同,且等于 loop_vars,即使在函数中没有用到全部的 loop_vars,也要做为输入和输出
参考资料: