Tensorflow 学习笔记 -----tf.where
在之前版本对应函数tf.select
官方解释:
1 tf.where( input , name = None )` 2 Returns locations of true values in a boolean tensor. 3 4 This operation returns the coordinates of true elements in input . The coordinates are returned in a 2 - D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input . Indices are output in row - major order. 5 6 For example: 7 # 'input' tensor is [[True, False] 8 # [True, False]] 9 # 'input' has two true values, so output has two coordinates. 10 # 'input' has rank of 2, so coordinates have two indices. 11 where( input ) = = > [[ 0 , 0 ], 12 [ 1 , 0 ]] 13 14 # `input` tensor is [[[True, False] 15 # [True, False]] 16 # [[False, True] 17 # [False, True]] 18 # [[False, False] 19 # [False, True]]] 20 # 'input' has 5 true values, so output has 5 coordinates. 21 # 'input' has rank of 3, so coordinates have three indices. 22 where( input ) = = > [[ 0 , 0 , 0 ], 23 [ 0 , 1 , 0 ], 24 [ 1 , 0 , 1 ], 25 [ 1 , 1 , 1 ], 26 [ 2 , 1 , 1 ]] |
有两种用法:
1、tf.where(tensor)
tensor 为一个bool 型张量,where函数将返回其中为true的元素的索引。如上图官方注释
2、tf.where(tensor,a,b)
a,b为和tensor相同维度的tensor,将tensor中的true位置元素替换为a中对应位置元素,false的替换为b中对应位置元素。
例:
1 2 3 4 5 6 7 8 9 | import tensorflow as tf import numpy as np sess = tf.Session() a = np.array([[ 1 , 0 , 0 ],[ 0 , 1 , 1 ]]) a1 = np.array([[ 3 , 2 , 3 ],[ 4 , 5 , 6 ]]) print (sess.run(tf.equal(a, 1 ))) print (sess.run(tf.where(tf.equal(a, 1 ),a1, 1 - a1))) |
>>[[true,false,false],[false,true,true]]
>>[[3,-1,-2],[-3,5,6]]
本文作者:love小酒窝
本文链接:https://www.cnblogs.com/lyc-seu/p/8565997.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
分类:
Tensorflow
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步