TensorFlow TensorArray 代码示例

 1 import tensorflow as tf 
 2 import numpy as np 
 3 
 4 B=3
 5 D=4
 6 T=5
 7 
 8 tf.reset_default_graph()
 9 xs=tf.placeholder(shape=[T,B,D],dtype=tf.float32)
10 
11 with tf.variable_scope('rnn'):
12     GRUcell = tf.nn.rnn_cell.GRUCell(num_units=D)
13     cell = tf.nn.rnn_cell.MultiRNNCell([GRUcell])
14 
15     output_ta = tf.TensorArray(size=T,dtype=tf.float32)
16     input_ta = tf.TensorArray(size=T,dtype=tf.float32)
17     input_ta = input_ta.unstack(xs)
18 
19     def body(time,output_ta_t,state):
20         xt = input_ta.read(time)
21         new_output,new_state = cell(xt,state)
22         output_ta_t = output_ta_t.write(time, new_output)
23         return (time+1,output_ta_t,new_state)
24 
25     def condition(time,output,state):
26         return time<T
27     
28     time=0
29     state=cell.zero_state(B,tf.float32)
30     time_final,output_ta_final,state_final=tf.while_loop(cond=condition,body=body,loop_vars=(time,output_ta,state))
31     output_final = output_ta_final.stack()
32 
33 x=np.random.randn(T,B,D)
34 with tf.Session() as sess:
35     sess.run(tf.global_variables_initializer())
36     output_final_,state_final_=sess.run([output_final,state_final],feed_dict={xs:x})
 1 import tensorflow as tf
 2 tf.enable_eager_execution()
 3 
 4 def condition(time,max_time, output_ta_l):
 5     return tf.less(time, max_time)
 6 
 7 def body(time,max_time, output_ta_l):
 8     output_ta_l = output_ta_l.write(time, [2.4, 3.5])
 9     return time + 1, max_time,output_ta_l
10 
11 max_time=tf.constant(3)
12 time = tf.constant(0)
13 output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
14 result = tf.while_loop(condition, body, loop_vars=[time,max_time,output_ta])
15 last_time,max_time, last_out = result
16 final_out = last_out.stack()
17 
18 
19 print(last_time)
20 print(final_out)
21 
22 
23 '''
24 ta.stack(name=None) 将TensorArray中元素叠起来当做一个Tensor输出
25 ta.unstack(value, name=None) 可以看做是stack的反操作,输入Tensor,输出一个新的TensorArray对象
26 ta.write(index, value, name=None) 指定index位置写入Tensor
27 ta.read(index, name=None) 读取指定index位置的Tensor
28 作者:加勒比海鲜 
29 原文:https://blog.csdn.net/guolindonggld/article/details/79256018 
30 '''

TensorArray 可以看作具有动态大小（dynamic size）功能的 Tensor 数组，通常与 while_loop 或 map_fn 结合使用。

提示：`[n.name for n in tf.get_default_graph().as_graph_def().node]` 可获取图中所有节点的名称。

posted @ 2018-11-06 11:05  阿夏z  阅读(680)  评论(0)  编辑  收藏  举报