tf.nn.top_k(input, k, name=None)和tf.nn.in_top_k(predictions, targets, k, name=None)
tf.nn.top_k(input, k, name=None)
这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引。
input: 一个张量,数据类型必须是以下之一:float32、float64、int32、int64、uint8、int16、int8。数据维度是 batch_size 乘上 x 个类别。
k: 一个整型,必须 >= 1。在每行中,查找最大的 k 个值。
name: 为这个操作取个名字。
输出:一个元组 Tensor ,数据元素是 (values, indices),具体如下:
values: 一个张量,数据类型和 input 相同。数据维度是 batch_size 乘上 k 个最大值。
indices: 一个张量,数据类型是 int32 。每个最大值在 input 中的索引位置。
tf.nn.in_top_k(predictions, targets, k, name=None)
就是对比predictions和targets是否一样,一样的返回true,不一样的返回false,接下来用tf.cast(correct,tf.floatxx) 可以计算准确率
predictions:预测的结果,预测矩阵大小为样本数×标注的label类的个数的二维矩阵。
targets:实际的标签,大小为样本数。
k:每个样本的预测结果的前k个最大的数里面是否包含targets预测中的标签,一般都是取1,即取预测最大概率的索引与标签对比。
name:名字。