tf.gather, tf.gather_nd and tf.slice
Both tensorflow and pytorch provide functions for indexing tensors. Since pytorch is built on dynamic graphs, its tensors can be manipulated on the fly without running a session, so these functions are less essential there. tensorflow, on the other hand, is built on static graphs: when you need to index a tensor while constructing the computation graph, plain slicing is generally not enough, which makes these functions necessary. This post goes over them together.
tf.gather
tf.gather(
    params,                 # the tensor to be indexed
    indices,                # the indices
    validate_indices=None,
    name=None,
    axis=0
)
Its behavior is straightforward: it selects slices of params along the axis dimension according to the given indices and assembles them into a new tensor.
The indices can be a scalar, a vector, or a higher-rank tensor, but its elements must be integers (int32, int64, etc.). Also make sure the indices stay within bounds: on CPU an out-of-bound index raises an error, while on GPU the corresponding output values are silently set to 0, so check carefully.
(Note that on CPU, if an out of bound index is found, an error is returned. On GPU, if an out of bound index is found, a 0 is stored in the corresponding output value.)
Here is a code example:
import numpy as np
import tensorflow as tf

params = tf.constant(np.random.normal(size=(5, 6)))
index = tf.constant([1, 3, 6])   # 6 is out of bounds on axis 0 (size 5)
out = tf.gather(params, index)
with tf.Session() as sess:
    print(sess.run(out).shape)
    print(out.eval())
    print('-----------------------------------------')
    print(params.eval())
# ------ output
(3, 6)
[[ 0.33662994 -0.33750725 1.5710436 0.78380586 -0.39254751 0.15269514]
[ 0.35278309 0.59719792 -0.86332759 -0.88997637 0.18247123 -1.46388103]
[ 0. 0. 0. 0. 0. 0. ]]
-----------------------------------------
[[ 0.98987616 -1.62505873 0.86819471 2.3316706 -1.17362956 0.57264237]
[ 0.33662994 -0.33750725 1.5710436 0.78380586 -0.39254751 0.15269514]
[ 0.62367272 0.83599086 1.60433217 -1.4376806 0.61179675 2.07843436]
[ 0.35278309 0.59719792 -0.86332759 -0.88997637 0.18247123 -1.46388103]
[-1.76421914 0.85159247 0.18397565 0.69460119 -0.0651779 0.14643597]]
You can see that the row gathered with the out-of-bound index is filled with 0s (this output comes from a GPU run; on CPU the same code raises an error).
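If you would rather fail loudly even on GPU, one option (a sketch of my own, not from the original post) is to attach an explicit bound check to the graph before gathering:
# Sketch (assumed usage): assert that every index is within range before calling tf.gather.
params = tf.constant(np.random.normal(size=(5, 6)))
index = tf.constant([1, 3, 6])
check = tf.assert_less(index, tf.shape(params)[0], message='gather index out of bounds')
with tf.control_dependencies([check]):
    out = tf.gather(params, index)
with tf.Session() as sess:
    sess.run(out)   # raises InvalidArgumentError because index 6 >= 5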
params = tf.constant(np.random.normal(size=(5,6)))
index = tf.constant([[[0, 0], [2, 0]]])   # indices of shape (1, 2, 2)
out = tf.gather(params, index)
with tf.Session() as sess:
    print(sess.run(out).shape)
    print(index.eval().shape)
    print(out.eval())
    print('-----------------------------------------')
    print(params.eval())
# ----- output
(1, 2, 2, 6)
(1, 2, 2)
[[[[-1.52187772 -0.86557915  1.15487405  0.94154816 -0.55900061  0.28165203]
   [-1.52187772 -0.86557915  1.15487405  0.94154816 -0.55900061  0.28165203]]

  [[-0.18865467 -0.00902652  1.03637009 -1.14683883  0.20450098  0.11977745]
   [-1.52187772 -0.86557915  1.15487405  0.94154816 -0.55900061  0.28165203]]]]
-----------------------------------------
[[-1.52187772 -0.86557915 1.15487405 0.94154816 -0.55900061 0.28165203]
[-1.55437505 0.1800381 -0.06461522 0.85310958 -1.25240021 -0.64335453]
[-0.18865467 -0.00902652 1.03637009 -1.14683883 0.20450098 0.11977745]
[-0.46992654 -0.07142709 0.35710407 0.24846814 -0.13490197 -1.31201887]
[-0.24019091 0.54458599 0.3623213 -0.10113704 -1.15469468 -0.33837456]]
If indices is a matrix or a higher-rank tensor, the shape of the result is determined by the indices, as the example above shows: with axis=0, the output shape is indices.shape + params.shape[1:], i.e. (1, 2, 2) + (6,) = (1, 2, 2, 6) here.
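The axis argument selects along other dimensions in the same way. Here is a minimal sketch of my own (not from the original examples) that gathers columns instead of rows:
# Sketch (assumed usage): gather columns of a (5, 6) matrix with axis=1.
params = tf.constant(np.random.normal(size=(5, 6)))
cols = tf.constant([0, 2, 5])
out = tf.gather(params, cols, axis=1)   # keeps all rows, picks columns 0, 2 and 5
with tf.Session() as sess:
    print(sess.run(out).shape)   # (5, 3) = params.shape[:1] + cols.shape + params.shape[2:]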
tf.gather_nd
tf.gather_nd is similar to tf.gather, except that the latter can only index along a single dimension, while the former can index across multiple dimensions at once. Its API is:
tf.gather_nd(
    params,    # the tensor to be indexed
    indices,   # the indices
    name=None
)
For example:
params = tf.constant(np.random.normal(size=(5,6)))
index = tf.constant([[[0, 0], [2, 0]]])   # each innermost pair is a full (row, col) index
out = tf.gather_nd(params, index)
with tf.Session() as sess:
    print(sess.run(out).shape)
    print(index.eval().shape)
    print(out.eval())
    print('-----------------------------------------')
    print(params.eval())
# -------------- output
(1, 2)
(1, 2, 2)
[[-0.20841454 -0.12850639]]
-----------------------------------------
[[-0.20841454 0.13953242 2.26913464 1.60226729 1.51834021 -0.41044839]
[-1.3215913 -0.15647683 0.17277501 -0.60070571 0.80504465 0.85216738]
[-0.12850639 -0.34113574 0.16567085 -1.61315021 -0.29158798 0.03378417]
[ 0.67928455 0.16899565 1.0065189 -1.05949544 1.87377367 0.71542472]
[-0.2858821 -1.29338336 -0.41112208 -0.21772644 0.17873804 -0.0703971 ]]
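In general, the last dimension of indices says how many leading dimensions of params each index addresses, so the output shape is indices.shape[:-1] + params.shape[indices.shape[-1]:]. A small sketch of my own (not from the original post) contrasting row-level and element-level indexing:
# Sketch (assumed usage): the length of the last index dimension controls how deep we index.
params = tf.constant(np.arange(12).reshape(3, 4))
rows  = tf.constant([[0], [2]])          # last dim = 1 -> gathers whole rows, output shape (2, 4)
elems = tf.constant([[0, 1], [2, 3]])    # last dim = 2 -> gathers single elements, output shape (2,)
with tf.Session() as sess:
    print(sess.run(tf.gather_nd(params, rows)))    # [[ 0  1  2  3] [ 8  9 10 11]]
    print(sess.run(tf.gather_nd(params, elems)))   # [ 1 11]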
tf.slice
It performs slicing, i.e. the tensorflow counterpart of something like a[:, 2:3, 5:6] in python. For example:
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
# shape = (3, 2, 3)
tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]], shape (1,1,3)
tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
                                   #   [4, 4, 4]]], shape (1,2,3)
tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
                                   #  [[5, 5, 5]]], shape (2,1,3)
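Note that the second argument is a begin offset and the third is a size (an extent), not an end index, and a size of -1 means "take everything remaining in that dimension". As a sketch of my own (not from the original post), the first slice above can also be written with Python-style indexing:
# Sketch (assumed usage): tf.slice with size -1 vs. Python-style __getitem__ on the same tensor.
with tf.Session() as sess:
    a = sess.run(tf.slice(t, [1, 0, 0], [1, -1, -1]))  # size -1: take the rest of each dimension
    b = sess.run(t[1:2, :, :])                          # equivalent slice via __getitem__
    print((a == b).all())                               # True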