10 tensorflow在循环体中用tf.print输出节点内容
代码
i=tf.constant(0,dtype=tf.int32) batch_len=tf.constant(10,dtype=tf.int32) loop_cond = lambda a,b: tf.less(a,batch_len) #yy=tf.Print(batch_len,[batch_len],"batch_len:") yy=tf.constant(0) loop_vars=[i,yy] def _recurrence(i,yy): c=tf.constant(2,dtype=tf.int32) x=tf.multiply(i,c) print_info=tf.Print(x,[x],"x:") yy=yy+print_info i=tf.add(i,1) return i,yy i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理 sess = tf.Session() sess.run(yy)
输出信息
为什么会这样,因为执行sess.run(yy)的时候,会有数据流过循环体中的所有tf.Print节点,此时就会执行tf.Print中指定的输出。最关键的操作就是yy=yy+print_info
存在的问题(与Spyder有关)
在spyder中使用调试模式的时候,无法输出上面的信息。
上面的代码是使用‘python 测试程序__在循环中使用tf.print.py’的方式在命令行执行才会输出。
如何不断的输出tf.Print信息
除了上述使用yy=yy+print_info。
如果print_info是这样的,比如:
print_info=tf.Print(constructionErrorMatrix,[constructionErrorMatrix],"constructionErrorMatrix:")#专门为了调试用,输出相关信息。 tfPrint=tfPrint+tf.to_int32(print_info[0])#一种不断输出tf.Print的方式,注意tf.Print的返回值。
constructionErrorMatrix是一个(?,)类型的float64 Tensor。我们可以用上述代码,继续进行tf.Print的输出。
此外,tf.Print中的第二个参数[]中放入的内容,也必须是能够转为Tensor。否则会提示
TypeError: Tensors in list passed to 'data' of 'Print' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid.
比如,一个Tensor的shape中如果有“?”,就不能转换为Tensor。对于这种不能Tensor,我们不能用get_shape()[i].value去获取?的维度,但是我们可以用tf.shape获取有数据流入以后的动态维度。就是?最终确定的维度。
你永远不知道未来会有什么,做好当下。技术改变世界,欢迎交流。