tf.argmax

tf.argmax(input, axis=None, name=None, dimension=None)
此函数是对矩阵按行或列计算最大值

参数
input:输入Tensor
axis:0表示按列,1表示按行
name:名称
dimension:和axis功能一样,默认axis取值优先。新加的字段
返回:Tensor 一般是行或列的最大值下标

import tensorflow as tf  


a=tf.Variable(tf.random_uniform([3,4],minval=-1,maxval=1))  
b=tf.argmax(input=a,axis=0)  
c=tf.argmax(input=a,dimension=1)   #此处用dimesion或用axis是一样的  
sess = tf.Session() 
sess.run(tf.initialize_all_variables())  
print(sess.run(a))  
#[[ 0.04261756 -0.34297419 -0.87816691 -0.15430689]  
# [ 0.18663144  0.86972666 -0.06103253  0.38307118]  
# [ 0.84588599 -0.45432305 -0.39736366  0.38526249]]  
print(sess.run(b))  
#[2 1 1 2]  
print(sess.run(c))  
#[0 1 0]  
posted @ 2022-08-19 22:59  luoganttcc  阅读(12)  评论(0编辑  收藏  举报