tensorflow对多维tensor按照指定索引重排序
背景是这样的,
比如我有一个张量data,shape是(batch_size,100,128)
我还有一个张量inc,shape是(batch_size,100)
我现在想根据这个张量地索引来对data重排序。
为什么会有这样地需求呢,是因为比如data是数据,100代表数据步长,128代表数据内units数目(维度),inc代表一个分数,这个分数表明了这100个步长当中每一步的重要性。现在我想要对data重排序一下,取top10,变成(batch_size,10,128),这样有利于后面的Attention。
操作例子见代码:
最主要的思想就是你有一个N维向量,那么就要指定一个N-1维的索引来对其重排序。例子中我们是一个(batch_size,100,128)的数据,
那么如果:
data是(batch_size,A,B,C,100,128)
inc是(batch_size,A,B,C,100,128)呢?
我的想法是先data reshape成(batch_size*A*B*C,100,128)
inc reshape成(batch_size*A*B*C,100)
后面的操作就一样了,先unstack,分别用gather取出相应切片(其实这里就已经做了个排序)
然后再stack回去
import tensorflow as tf import numpy as np data = tf.placeholder(tf.int64, [None, 5, 2]) choose = tf.placeholder(tf.int64,[None,5]) sortarg = tf.argsort(choose, direction="DESCENDING") split_data = tf.unstack(data, num=3, axis=0) split_choose = tf.unstack(sortarg, num=3, axis=0) trans_data_list = list() for i in range(3): trans_data_list.append(tf.gather(split_data[i], sortarg[i])) trans_data = tf.stack(trans_data_list, axis=0) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) feed_dict = { choose:[[5,4,3,0,1],[2,3,0,4,2],[2,3,5,4,2]], data:[[[1,2],[3,4],[5,6],[7,8],[9,10]], [[11,12],[13,14],[15,16],[17,18],[19,20]], [[21,22],[23,24],[25,26],[27,28],[29,30]]] } print(sess.run(sortarg,feed_dict=feed_dict)) print("-----------------------------------------------------") # print(sess.run(data_trans,feed_dict = feed_dict)) print(sess.run(data,feed_dict=feed_dict)) print("-----------------------------------------------------") print(sess.run(trans_data, feed_dict=feed_dict))