tensorflow运行mnist的一些常用函数整理(2)
主要有:
tf.equal()
tf.argmax()
tf.cast()
详细介绍:
tf.equal(A, B)
这个函数主要是用于比较两个矩阵或者向量,返回的矩阵维度与A矩阵一样
注意:返回的矩阵里面的元素是布尔类型
B的形状不一定要与A的形状一样
运行如下代码:
import tensorflow as tf import numpy as np A = [[1,3,4],[2,5,6]] B = [[1,2,4]] with tf.Session() as sess: print(sess.run(tf.equal(A, B)))
输出结果:
[[ True False True]
[False False False]]
tf.argmax(A, axis)
这个函数返回矩阵最大值的索引号,对于矩阵来说,返回的是一个向量
axis=1:按列比较----每一列对应元素比较
axis=0:按行比较(第二个参数不填默认为axis=0)
如果某一行或者一列有相同的数,则返回最靠前的
import tensorflow as tf import numpy as np A = [[1,3,4],[2,5,6],[6,5,4]] with tf.Session() as sess: #修改这里的axis print(sess.run(tf.argmax(A, 1)))
axis = 0:[2 1 1]
axis = 1:[2 2 0]
tf.cast(x, dtype, name=None)
将x的数据格式转换为dtype,比较常用的是将bool转换为0、1
注意:要用tensorflow里面定义的类型,如果将下面代码中tf.float32换为float
则报错 Cannot convert value <class 'float'> to a TensorFlow DType.
import tensorflow as tf import numpy as np #用tensorflow来定义变量也可,后面需要初始化 #不初始化出现如下错误 #Attempting to use uninitialized value Variable #A = tf.Variable([[True,True,False],[False,False,False],[True,True,True]]) A = [[True,True,False],[False,False,False],[True,True,True]] with tf.Session() as sess: #sess.run(tf.initialize_all_variables()) print(sess.run(tf.cast(A, tf.float32)))
输出:
[[ 1. 1. 0.]
[ 0. 0. 0.]
[ 1. 1. 1.]]