tf.where

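import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

# Values 1..16 arranged as a tensor of shape [4, 1, 2, 2]
temp = tf.reshape(tf.range(0, 16) + tf.constant(1, shape=[16]), [4, 1, 2, 2])
# Boolean mask: True where the value is greater than 6
mask = tf.greater(temp, 6)
# Single-argument tf.where: coordinates of every True element, shape [num_true, 4], dtype int64
category_index = tf.where(mask)

with tf.Session() as sess:
    # Evaluate both tensors in one run; a separate name avoids overwriting the tensor temp
    mask_val, a = sess.run([mask, category_index])
    print(mask_val)
    print(a)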
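To make the semantics concrete, here is a minimal pure-Python sketch (not TensorFlow's implementation) of what single-argument `tf.where(tf.greater(temp, 6))` computes: visit every coordinate of the tensor in row-major order (last index changing fastest) and keep the coordinates whose condition is True. The helpers `reshape`, `get` and `where_coords` are hypothetical names chosen for illustration.

```python
from itertools import product

def reshape(flat, shape):
    """Hypothetical helper: turn a flat list into nested lists of the given shape (row-major)."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]

def get(nested, coord):
    """Hypothetical helper: read the element at a multi-dimensional coordinate."""
    for i in coord:
        nested = nested[i]
    return nested

def where_coords(condition, shape):
    """Sketch of single-argument tf.where: coordinates of True elements, row-major order."""
    # itertools.product varies the last index fastest, matching row-major order
    return [list(c) for c in product(*(range(n) for n in shape)) if get(condition, c)]

shape = [4, 1, 2, 2]
flat = [v + 1 for v in range(16)]             # like tf.range(0, 16) + 1 -> 1..16
temp = reshape(flat, shape)
mask = reshape([v > 6 for v in flat], shape)  # like tf.greater(temp, 6)
coords = where_coords(mask, shape)
for c in coords:
    print(c)
```

Each printed row matches one row of the `tf.where` result shown in the output below, starting at `[1, 0, 1, 0]` (the value 7) and ending at `[3, 0, 1, 1]` (the value 16).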

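Explanation

With a single argument, tf.where returns the coordinates of the elements that are True: one row per True element, each row giving that element's index along all 4 dimensions, listed in row-major order. The program first prints the boolean mask, then these coordinates: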

[[[[False False]
   [False False]]]
 [[[False False]
   [ True  True]]]
 [[[ True  True]
   [ True  True]]]
 [[[ True  True]
   [ True  True]]]]
[[1 0 1 0]
 [1 0 1 1]
 [2 0 0 0]
 [2 0 0 1]
 [2 0 1 0]
 [2 0 1 1]
 [3 0 0 0]
 [3 0 0 1]
 [3 0 1 0]
 [3 0 1 1]]
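Each row of the result is a complete index into `temp`, so looking the rows up again returns exactly the elements that passed the test, 7 through 16. In TensorFlow this lookup is what `tf.gather_nd(temp, category_index)` performs; the pure-Python sketch below only illustrates the idea, using coordinates copied by hand from the printed output. The function `gather_nd` here is a hypothetical stand-in, not the TensorFlow op.

```python
# Coordinates copied from the printed output of tf.where above
category_index = [
    [1, 0, 1, 0], [1, 0, 1, 1],
    [2, 0, 0, 0], [2, 0, 0, 1], [2, 0, 1, 0], [2, 0, 1, 1],
    [3, 0, 0, 0], [3, 0, 0, 1], [3, 0, 1, 0], [3, 0, 1, 1],
]

# Rebuild temp (values 1..16, shape [4, 1, 2, 2]) as nested Python lists:
# element [b, 0, r, c] holds 4*b + 2*r + c + 1
temp = [[[[4 * b + 2 * r + c + 1 for c in range(2)] for r in range(2)]] for b in range(4)]

def gather_nd(params, indices):
    """Hypothetical sketch of a gather_nd-style lookup where each index row is a full coordinate."""
    out = []
    for idx in indices:
        elem = params
        for i in idx:  # descend one dimension per coordinate component
            elem = elem[i]
        out.append(elem)
    return out

values = gather_nd(temp, category_index)
print(values)  # → [7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
```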
posted @ 2018-10-18 22:23  overfitover