tf.gather

gather就是按行取值:

a1 = [[1,2], [3, 4], [5, 6]]
a2 = tf.gather(tf.constant(a1), [0, 1])
print(a2)

输出:

tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)

相当于:

a1[:2]
posted @ 2020-11-26 19:27  oaksharks  阅读(220)  评论(0编辑  收藏  举报