tf.gather()、tf.gather_nd()、tf.batch_gather()、tf.where()和tf.slice()
分类:
TensorFlow
1.tf.gather
1 | tf.gather(params, indices, validate_indices = None , name = None , axis = 0 ) |
功能:根据提供的
indices
在axis
这个轴上对params
进行索引,拼接成一个新的张量。
参数:
- params:需要被索引的张量
- indices:必须为整数类型,如int32,int64等,注意检查不要越界了,因为如果越界了,如果使用的
CPU
,则会报错,如果在GPU
上进行操作的,那么相应的输出值将会被置为0,而不会报错,因此认真检查是否越界。 - name:返回张量名称
返回维度: params.shape[:axis] + indices.shape + params.shape[axis + 1:]
举例:
1 2 3 4 5 6 7 | import tensorflow as tf temp4 = tf.reshape(tf. range ( 0 , 20 ) + tf.constant( 1 ,shape = [ 20 ]),[ 2 , 2 , 5 ])temp4: [[[ 1 2 3 4 5 ] [ 6 7 8 9 10 ]] [[ 11 12 13 14 15 ] [ 16 17 18 19 20 ]]] |
(1)当indices是向量时,输出的形状和输入形状相同,不改变
1 2 3 4 5 6 7 8 9 10 | temp5 = tf.gather(temp4,[ 0 , 1 ],axis = 0 ) #indices是向量temp5:<br>[[[ 1 2 3 4 5] [ 6 7 8 9 10 ]] [[ 11 12 13 14 15 ] [ 16 17 18 19 20 ]]]<br><br>temp7 = tf.gather(temp4,[ 1 , 4 ],axis = 2 )<br> # (2,2,5)[:2]+(2,)+(2,2,5)[3:]=(2,2,2)temp7: [[[ 2 5 ] [ 7 10 ]] [[ 12 15 ] [ 17 20 ]]] |
(2)当indices是数值时,输出的形状比输入的形状少一维
1 2 | temp6 = tf.gather(temp4, 1 ,axis = 1 ) #indices是数值<br># (2,2,5)[:1]+()+(2,2,5)[2:]=(2,5)temp:<br>[[ 6 7 8 9 10] [ 16 17 18 19 20 ]] |
(3)当indices是多维时
1 2 3 4 5 6 7 8 9 10 11 | temp8 = tf.gather(temp4,[[ 0 , 1 ],[ 3 , 4 ]],axis = 2 ) #indices是多维的<br># (2,2,5)[:2]+(2,2)+(2,2,5)[3:]=(2,2,2,2)<br>temp8:[[[[ 1 2] [ 4 5 ]] [[ 6 7 ] [ 9 10 ]]] [[[ 11 12 ] [ 14 15 ]] [[ 16 17 ] [ 19 20 ]]]] |
bert源码:
1 2 3 4 5 6 | flat_input_ids = tf.reshape(input_ids, [ - 1 ]) #【batch_size*seq_length*input_num】 if use_one_hot_embeddings: one_hot_input_ids = tf.one_hot(flat_input_ids, depth = vocab_size) output = tf.matmul(one_hot_input_ids, embedding_table) else : output = tf.gather(embedding_table, flat_input_ids) |
2.tf.gather_nd
1 2 3 4 5 | tf.gather_nd( params, indices, name = None , batch_dims = 0 ) |
功能:类似于tf.gather
,不过后者只能在一个维度上进行索引,而前者可以在多个维度上进行索引,
参数:
- params:待索引输入张量
- indices:索引,int32,int64,indices将切片定义为params的前N个维度,其中N = indices.shape [-1]
- 通常要求indices.shape[-1] <= params.rank(可以用np.ndim(params)查看)
- 如果等号成立是在索引具体元素
- 如果等号不成立是在沿params的indices.shape[-1]轴进行切片
- name=None:操作的名称(可选)
返回维度: indices.shape[:-1] + params.shape[indices.shape[-1]:],前面的indices.shape[:-1]代表索引后的指定形状
举例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | indices = [[ 0 , 0 ], [ 1 , 1 ]] params = [[ 'a' , 'b' ], [ 'c' , 'd' ]] # (2,2)[:-1]+(2,2)[(2,2)[-1]:]=(2,) output = [ 'a' , 'd' ] 表示将params对应第一行第一列的 'a' 和第二行第二列的 'd' 取出来 indices = [[ 1 ], [ 0 ]] params = [[ 'a' , 'b' ], [ 'c' , 'd' ]] # (2,1)[:-1]+(2,2)[(2,1)[-1]:]=(2,)+(2,)=(2,2) output = [[ 'c' , 'd' ], [ 'a' , 'b' ]] 表示将params对应第二行和第一行取出来 ''' 功能:T是一个二维tensor,我们想要根据另外一个二维tensor value的最后一维最大元素的下标选出tensor T中 最后一维最大的元素,组成一个新的一维的tensor,那么就可以首先选出最后一维度的下标[1,2,3], 然后将其扩展成[[0,1],[1,2],[2,3]],然后使用这个函数选择即可。 ''' import tensorflow as tf sess = tf.InteractiveSession() values = tf.constant([[ 0 , 0 , 0 , 1 ], [ 0 , 1 , 0 , 0 ], [ 0 , 0 , 1 , 0 ]]) T = tf.constant([[ 0 , 1 , 2 , 3 ], [ 4 , 5 , 6 , 7 ], [ 8 , 9 , 10 , 11 ]]) max_indices = tf.argmax(values, axis = 1 ) # 行 print ( 'max_indices' ,max_indices. eval ()) # [3 1 2] # If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0]. print (tf.stack((tf. range (T.get_shape()[ 0 ],dtype = max_indices.dtype),max_indices),axis = 1 ). eval ()) print (tf. range (T.get_shape()[ 0 ]). eval ()) result = tf.gather_nd(T, tf.stack((tf. range (T.get_shape()[ 0 ], dtype = max_indices.dtype), max_indices), axis = 1 )) print (result. eval ()) |
3.tf.batch_gather
作用:支持对张量的批量索引.注意因为是批处理,所以indices要有和params相同的第0个维度。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | import tensorflow as tf tensor_a = tf.Variable([[ 1 , 2 , 3 ],[ 4 , 5 , 6 ],[ 7 , 8 , 9 ]]) tensor_b = tf.Variable([[ 0 ],[ 1 ],[ 2 ]],dtype = tf.int32) tensor_c = tf.Variable([[ 0 ],[ 0 ],[ 0 ]],dtype = tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print ( 'gather' ) print (sess.run(tf.gather(tensor_a,tensor_b))) print (sess.run(tf.gather(tensor_a,tensor_c))) print ( 'gather_nd' ) print (sess.run(tf.gather_nd(tensor_a, tensor_b))) print (sess.run(tf.gather_nd(tensor_a, tensor_c))) print ( 'batch_gather' ) print (sess.run(tf.batch_gather(tensor_a, tensor_b))) print (sess.run(tf.batch_gather(tensor_a, tensor_c))) |
4.tf.where
1 | tf.where(condition, x = None , y = None , name = None ) |
作用: 返回condition为True的元素坐标(x=y=None)
- condition:布尔型张量,True/False
- x:与y具有相同类型的张量,可以使用条件和y进行广播。
- y:与x具有相同类型的张量,可以在条件和x的条件下进行广播。
- name:操作名称(可选)
返回维度: (num_true, dim_size(condition)),其中dim_size为condition的维度。
(1)tf.where(condition)
- condition是bool型值,True/False
- 返回值,是condition中元素为True对应的索引
1 2 3 4 5 6 7 8 9 10 | import tensorflow as tf a = [[ 1 , 2 , 3 ],[ 4 , 5 , 6 ]] b = [[ 1 , 0 , 3 ],[ 1 , 5 , 1 ]] condition1 = [[ True , False , False ], [ False , True , True ]] condition2 = [[ True , False , False ], [ False , True , False ]] with tf.Session() as sess: print (sess.run(tf.where(condition1))) print (sess.run(tf.where(condition2))) |
(2)tf.where(condition, x=None, y=None, name=None)
- condition, x, y 相同维度,condition是bool型值,True/False
- 返回值是对应元素,condition中元素为True的元素替换为x中的元素,为False的元素替换为y中对应元素
- x只负责对应替换True的元素,y只负责对应替换False的元素,x,y各有分工
- 由于是替换,返回值的维度,和condition,x , y都是相等的。
1 2 3 4 5 6 7 8 9 10 | import tensorflow as tf x = [[ 1 , 2 , 3 ],[ 4 , 5 , 6 ]] y = [[ 7 , 8 , 9 ],[ 10 , 11 , 12 ]] condition3 = [[ True , False , False ], [ False , True , True ]] condition4 = [[ True , False , False ], [ True , True , False ]] with tf.Session() as sess: print (sess.run(tf.where(condition3,x,y))) print (sess.run(tf.where(condition4,x,y))) |
5.tf.slice()
1 | tf. slice (inputs, begin, size, name) |
作用:用来进行切片操作,实现在python
中的a[:,2:3,5:6]
类似的操作,从列表、数组、张量等对象中抽取一部分数据
- begin和size是两个多维列表,他们共同决定了要抽取的数据的开始和结束位置
- begin表示从inputs的哪几个维度上的哪个元素开始抽取
- size表示在inputs的各个维度上抽取的元素个数
- 若begin[]或size[]中出现-1,表示抽取对应维度上的所有元素
1 2 3 4 5 6 7 8 9 | import tensorflow as tf t = tf.constant([[[ 1 , 1 , 1 ], [ 2 , 2 , 2 ]], [[ 3 , 3 , 3 ], [ 4 , 4 , 4 ]], [[ 5 , 5 , 5 ], [ 6 , 6 , 6 ]]]) tf. slice (t, [ 1 , 0 , 0 ], [ 1 , 1 , 3 ]) # [[[3, 3, 3]]] tf. slice (t, [ 1 , 0 , 0 ], [ 1 , 2 , 3 ]) # [[[3, 3, 3], # [4, 4, 4]]] tf. slice (t, [ 1 , 0 , 0 ], [ 2 , 1 , 3 ]) # [[[3, 3, 3]], # [[5, 5, 5]]] |
bert源码:
1 2 3 4 5 | # 这里position embedding是可学习的参数,[max_position_embeddings, width] # 但是通常实际输入序列没有达到max_position_embeddings # 所以为了提高训练速度,使用tf.slice取出句子长度的embedding # full_position_embeddings:[max_position_embeddings, width] position_embeddings = tf. slice (full_position_embeddings, [ 0 , 0 ],[seq_length, - 1 ]) |
参考文献:
【1】tf.gather, tf.gather_nd和tf.slice_机器学习杂货铺1号店-CSDN博客
【2】tf.where/tf.gather/tf.gather_nd - 知乎
【3】tenflow 入门 tf.where()用法_ustbbsy的博客-CSDN博客
【4】tf.gather tf.gather_nd 和 tf.batch_gather 使用方法_张冰洋的天空-CSDN博客
分类:
TensorFlow
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现