tf.split() 和 tf.unstack()

import tensorflow as tf

# Compare tf.split and tf.unstack on the same 2x3 matrix.
A = [[1, 2, 3], [4, 5, 6]]

# tf.split keeps the rank: 3 pieces along axis 1, each still 2-D (e.g. [[1], [4]]).
a0 = tf.split(A, num_or_size_splits=3, axis=1)
# tf.unstack drops the split axis: 3 pieces along axis 1, each 1-D (e.g. [1, 4]).
a1 = tf.unstack(A, num=3, axis=1)
# Splitting along axis 0 into 2 pieces keeps the rank (e.g. [[1, 2, 3]]).
a2 = tf.split(A, num_or_size_splits=2, axis=0)
# Unstacking along axis 0 into 2 pieces removes that axis (e.g. [1, 2, 3]).
a3 = tf.unstack(A, num=2, axis=0)

with tf.Session() as sess:
    # Evaluate each result on its own so every print shows one list of arrays.
    for pieces in (a0, a1, a2, a3):
        print(sess.run(pieces))

[array([[1],[4]]), array([[2],[5]]), array([[3],[6]])]

[array([1, 4]), array([2, 5]), array([3, 6])]
[array([[1, 2, 3]]), array([[4, 5, 6]])]
[array([1, 2, 3]), array([4, 5, 6])]

posted on 2017-12-17 21:32  mdumpling  阅读(3386)  评论(0)  编辑  收藏  举报

导航