10 tensorflow在循环体中用tf.print输出节点内容

代码

i=tf.constant(0,dtype=tf.int32)
batch_len=tf.constant(10,dtype=tf.int32)
loop_cond = lambda a,b: tf.less(a,batch_len)
#yy=tf.Print(batch_len,[batch_len],"batch_len:")
yy=tf.constant(0)
loop_vars=[i,yy]
def _recurrence(i,yy):
    c=tf.constant(2,dtype=tf.int32)
    x=tf.multiply(i,c)
    print_info=tf.Print(x,[x],"x:")
    yy=yy+print_info
    i=tf.add(i,1)
    return i,yy
i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理
sess = tf.Session()
sess.run(yy)

输出信息

为什么会这样,因为执行sess.run(yy)的时候,会有数据流过循环体中的所有tf.Print节点,此时就会执行tf.Print中指定的输出。最关键的操作就是yy=yy+print_info

存在的问题(与Spyder有关)

在spyder中使用调试模式的时候,无法输出上面的信息。

上面的代码是使用‘python 测试程序__在循环中使用tf.print.py’的方式在命令行执行才会输出。

如何不断的输出tf.Print信息

除了上述使用yy=yy+print_info。

如果print_info是这样的,比如:

print_info=tf.Print(constructionErrorMatrix,[constructionErrorMatrix],"constructionErrorMatrix:")#专门为了调试用,输出相关信息。
tfPrint=tfPrint+tf.to_int32(print_info[0])#一种不断输出tf.Print的方式,注意tf.Print的返回值。
constructionErrorMatrix是一个(?,)类型的float64 Tensor。我们可以用上述代码,继续进行tf.Print的输出。

此外,tf.Print中的第二个参数[]中放入的内容,也必须是能够转为Tensor。否则会提示

TypeError: Tensors in list passed to 'data' of 'Print' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid.

比如,一个Tensor的shape中如果有“?,就不能转换为Tensor。对于这种不能Tensor,我们不能用get_shape()[i].value去获取?的维度,但是我们可以用tf.shape获取有数据流入以后的动态维度。就是?最终确定的维度。

 

posted @ 2018-12-05 20:09  秦皇汉武  阅读(4409)  评论(0编辑  收藏  举报