TensorFlow、Numpy中的axis的理解
TensorFlow中有很多函数涉及到axis,比如tf.reduce_mean(),其函数原型如下:
1 def reduce_mean(input_tensor, 2 axis=None, 3 keepdims=None, 4 name=None, 5 reduction_indices=None, 6 keep_dims=None):
其中axis表示的是,对该维度进行求均值(默认情况下,是对所有值求均值)。
除了TensorFlow中,numpy中也经常遇到很多对矩阵操作的函数会涉及axis操作。比如np.mean(),其函数原型如下:
1 def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue):
想要弄清楚如何处理涉及axis(维度)的操作,必须先明白axis是什么。
首先axis是维度,如果axis=0则对应着高; 如果axis=1则对应着行处理;如果axis=2则对应着列;如果axis=3…n(无法用直观的图来表示)。我相信很多人看到这还是会一头雾水。什么是高,行还有列。为了说明这个问题,我举个列子:
data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]] data_np=np.array(data) print(data_np) [[[ 1 2 3] [ 11 22 33]] [[ 4 5 6] [ 44 55 66]] [[ 10 11 12] [100 110 120]] [[ 7 8 9] [ 77 88 99]]] 如上面,可以将最外层[ ]去掉,可以发现有4组元素(这里的元素是矩阵),你可以将其理解为高。 再从这3组元素中选取一组,比如选择的是 [[ 1 2 3] [ 11 22 33]] 然后将该组的最外层[ ]去掉,可以发现有2组元素分别为[ 1 2 3]和 [ 11 22 33],此时对应的是行。 在从这两组元素中选组一组,比如选择的是 [ 11 22 33] 现在无需去掉最外层的[ ]了,一眼就能看出里面有3个元素。这就是对应的列。 理解了上面的分析后,很容易就知道(高,行,列)对应的其实就是改矩阵的shape. print(data_np.shape): (4,2,3)
现在弄清楚了axis的值与(高,行,列)的关系后,再来分析tf.reduce_mean()或者np.mean()等函数是如何对axis进行操作的。
1 data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]] 2 3 data_tensor=tf.constant(data,dtype=tf.float32) 4 5 mean_axis0=tf.reduce_mean(data_tensor,axis=0) 6 mean_axis1=tf.reduce_mean(data_tensor,axis=1) 7 mean_axis2=tf.reduce_mean(data_tensor,axis=2) 8 9 with tf.Session() as sess: 10 print(sess.run(mean_axis0)) 11 print(sess.run(mean_axis1)) 12 print(sess.run(mean_axis2))
针对上述代码,我们先对axis=0维度的数据处理进行分析。
首先对上述data数据进行立体化变换,如下图(本人本想用软件来绘制3D的矩阵叠加效果,可惜找了很多软件都不适合,也许是本人寻找的还不够,欢迎有知道可以绘制3D的矩阵叠加效果的朋友们,能够分享一下。感激…)
如上如,axis=0的维度数据求均值,
[[(1+4+10+7)/4 (2+5+11+8)/4 (3+6+12+9)/4] [(11+44+100+77)/4 (22+55+110+88)/4 (33+66+120+99)/4]] = [[ 5.5 6.5 7.5 ] [58. 68.75 79.5 ]]
同理,对axis=1的维度数据求均值,
[[(1+11)/2 (2+22)/2 (3+33)/2] [(4+44)/2 (5+55)/2 (6+66)/2] [(10+100)/2 (11+110)/2 (12+120)/2] [(7+77)/2 (8+88)/2 (9+99)/2]] = [[ 6. 12. 18. ] [24. 30. 36. ] [55. 60.5 66. ] [42. 48. 54. ]]
同理可得axis=2维度的数据平均值为(过程留给读者去推,运算结果如下):
[[ 2. 22.] [ 5. 55.] [ 11. 110.] [ 8. 88.]]
在python的世界里,有很多时候都需要对数据进行维度的操作,如果对axis理解的不透的话,很容易找不着方向。
更多干货请关注: