Numpy 高维空间中的轴
完成日期: 2024-03-01
更新日期: 2024-03-01
问题
Numpy 中有众多操作会涉及到一个参数 axis
, 也就是 轴. 这到底是什么? 沿着某轴操作 (例如 np.sum(axis=0)
) 又是什么意思?
对于低维数组, 或许可以按 行 和 列 来理解, 但如果上升到了四维、五维乃至更高, 就变得十分抽象了. 因为这里要讨论更高维度的情况, 所以就不使用 "行" 或 "列" 之类的词语描述数组的几何意义了
一维数组
我们先来看一维数组的情况, 也就是 np.shape = (1,)
的情况 (这里的等号 =
是数学上的符号, 不是赋值的意思), arange()
方法可以接受一个参数 n
, 生成从 0 到 n 的整数序列
>>> a = np.arange(8)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])
我们现在想要计算这个序列的和, 使用 np.sum(axis=0)
方法, 由于序列 a
只有一个维度, axis
参数只能为 0
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])
>>> a.shape
(8,)
>>> a.sum(axis=0)
28
输出值是 28, np.shape
方法可以打印数组的维度信息, 这里序列 a
只有一个, 这个维度有 8 个元素. axis=0
指在第 0
维上, 遍历所有元素, 执行求和 sum()
的操作, 也就是把索引为 0, 1, 2, ..., 7
的元素累加起来, 因为它们的层级都是第 0
维, 得到 [28]
, 写成伪代码的形式就是:
for i in range(8)
sum = sum + i
Numpy 中 sum()
、min()
等方法还会有一个降维的效果, 也就是 reduce, 于是用于表示维度的方括号被抽离, [28]
变成 28
, 成为 0
维的标量
二维数组
axis=0
>>> a = np.arange(6).reshape(2,3)
>>> a.shape
(2, 3)
>>> a
array([[0, 1, 2],
[3, 4, 5]])
>>> a.sum(axis=0)
array([3, 5, 7])
reshape()
方法可以调整数组维度, arange(6)
生成了有 6 个元素的一维数组, reshape(2,3)
将它调整成 2*3 的二维数组
在Numpy显示多维数组的方式中, 可以通过数方括号来确定维度, 第一个方括号是第 axis=0
维, 第二个方括号是第 axis=1
维. 我们调整一下上面 Numpy 显示二维数组的方式, 方便指示维度
>>> a
array([ -----------> 表示第 axis=0 维
[ --------> 表示第 axis=1 维
0, 1, 2],
[3, 4, 5]])
那么, axis=0
要求按第 0
轴去求和, 我们把数组 a
再换个表示方式
>>>a
array([[0, 1, 2],
[3, 4, 5]])
array([ A,
B ])
这里, 将 [0, 1, 2]
看作 A
, 将 [3, 4, 5]
看作 B
(还记得矩阵中的子矩阵吗?)
那么, 我们要遍历 axis=0
轴的元素执行求和, 就是要计算 [A+B]
, 也就是 [[0, 1, 2] + [3, 4, 5]]
, 结果是 [[3, 5, 7]]
, 由于 np.sum()
会降维, 抽离最外边第 axis=0
的方括号, 于是变成 [3, 5, 7]
如果写成伪代码, 即使有 N 行
for i in range(A, B, C, ... , N)
sum = sum + i
如果用切片的方式来表示, 是在计算 a[0, ...] + a[1, ...]
, 符号 ...
表示此维度不作指定, 会自动推导选择全部
>>> a[0, ...]
array([0, 1, 2])
>>> a[1, ...]
array([3, 4, 5])
[0, 1, 2] + [3, 4, 5] = [3, 5, 7]
数组 a
的 shape
从 (2,3)
, 执行完 sum(axis=0)
后, 变成 (3,)
, 第一个维度没有了
axis=1
>>> a
array([[0, 1, 2],
[3, 4, 5]])
>>> a.shape
(2, 3)
>>> a.sum(axis=1)
array([ 3, 12])
如果 np.sum(axis=1)
, 我们需要遍历同属于 axis=1
层级的元素并求和
>>> a
array([
[0, 1, 2], 左 ---> 右, 遍历求和, 得 [3]
[3, 4, 5] 左 ---> 右, 遍历求和, 的 [12]
])
array([
[3],
[12]
])
最终, 得到数组 [[3], [12]]
, 由于降维, 抽离里面 axis=1
的方括号, 变成 [3, 12]
如果用切片的方式查看, 实质上是在计算 a[..., 0] + a[..., 1] + a[..., 2]
>>> a[..., 0]
array([0, 3])
>>> a[..., 1]
array([1, 4])
>>> a[..., 2]
array([2, 5])
[0, 3] + [1, 4] + [2, 5] = [3, 12]
数组 a
的 shape
从 (2,3)
, 执行完 sum(axis=1)
后, 变成 (2,)
三维数组
到这里, 我们可以看到:
对于一维数组, axis=0
, 我们需要遍历数组中的每一个元素, 也就是 a[...]
对于二维数组, axis=0
, 我们需要遍历数组第 0
轴的每一个元素, 也就是 a[0, ...] + a[1, ...] + a[2, ...] ... a[n, ...]
如果 axis=1
, 我们需要遍历数组第 1
轴的每一个元素, 也就是 a[..., 0] + a[..., 1] + a[..., 2] ... a[..., n]
那么, 是不是可以认为, 对某一轴操作, 就是去遍历那个轴的切片?
axis=0
for i in {a[0, :, :, ...], a[1, :, :, ...], ..., a[n, :, :, ...]}
axis=1
for i in {a[:, 0, :, ...], a[:, 1, :, ...], ..., a[:, n, :, ...]}
axis=2
for i in {a[:, :, 0, ...], a[:, :, 1, ...], ..., a[:, :, n, ...]}
...
...
我们用三维数组验证一下
>>> a = np.arange(24).reshape(2,3,4)
>>> a
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> a.sum(axis=0)
array([[12, 14, 16, 18],
[20, 22, 24, 26],
[28, 30, 32, 34]])
# 打印切片
>>> a[0,...]
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> a[1,...]
array([[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]])
[[ 0, 1, 2, 3], [[12, 13, 14, 15], [[12, 14, 16, 18],
[ 4, 5, 6, 7], + [16, 17, 18, 19], = [20, 22, 24, 26],
[ 8, 9, 10, 11]] [20, 21, 22, 23]] [28, 30, 32, 34]]
同一层级 axis=0
的元素有两个, 是切片 a[0, ...]
与 切片 a[1, ...]
, 它们的和正好就是 a.sum(axis=0)
现在我们在来看看 axis=1
的情况
>>> a
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> a.sum(axis=1)
array([[12, 15, 18, 21],
[48, 51, 54, 57]])
# 打印切片
>>> a[:, 0, :]
array([[ 0, 1, 2, 3],
[12, 13, 14, 15]])
>>> a[:, 1, :]
array([[ 4, 5, 6, 7],
[16, 17, 18, 19]])
>>> a[:, 2, :]
array([[ 8, 9, 10, 11],
[20, 21, 22, 23]])
>>> a[:,0,:] + a[:,1,:] + a[:,2,:] == a.sum(axis=1)
array([[ True, True, True, True],
[ True, True, True, True]])
切片的和也确实等于 a.sum(axis=1)
, 回到数组 a
本身, 第一个维度(也就是 axis=0
) 不动, axis=1
的元素有 3 组, 需要遍历它们求和
array([[ ------------------------> axis=1 的轴
[ 0, 1, 2, 3], |
+ |
[ 4, 5, 6, 7], | 把这三组数加起来, 有
+ | [12, 15, 18, 21]
[ 8, 9, 10, 11]], V
[
[12, 13, 14, 15], |
+ |
[16, 17, 18, 19], | 同上操作, 有
+ | [48, 51, 54, 57]
[20, 21, 22, 23]]]) V
可以看到, 切片出来元素再相加, 和我们用这种方式取出元素再相加, 它们本质上是相同的
现在再来看看 axis=2
的情况, 它们都是一致的
array([[
[ 0, 1, 2, 3], 左---->右, 遍历求和, 有 [6]
[ 4, 5, 6, 7], 有 [22]
[ 8, 9, 10, 11]], 有 [38]
[
[12, 13, 14, 15], 有 [54]
[16, 17, 18, 19], 有 [70]
[20, 21, 22, 23] 有 [86]
]])
>>> a.sum(axis=2)
array([[ 6, 22, 38],
[54, 70, 86]])
# 打印切片
>>> a[:,:,0]
array([[ 0, 4, 8],
[12, 16, 20]])
>>> a[:,:,1]
array([[ 1, 5, 9],
[13, 17, 21]])
>>> a[:,:,2]
array([[ 2, 6, 10],
[14, 18, 22]])
>>> a[:,:,3]
array([[ 3, 7, 11],
[15, 19, 23]])
>>> a[:,:,0] + a[:,:,1] + a[:,:,2] + a[:,:,3] == a.sum(axis=2)
array([[ True, True, True],
[ True, True, True]])
切片相加, 也确实等于 sum(axis=2)
四维数组
我们再来简单验证一下四维数组
>>> a
array([[[ -----------------------------> axis=2 的轴
[ 0, 1, 2, 3, 4], |
[ 5, 6, 7, 8, 9], | 遍历求和, 有
[ 10, 11, 12, 13, 14], | [ 30, 34, 38, 42, 46]
[ 15, 16, 17, 18, 19]], V
[[ 20, 21, 22, 23, 24], |
[ 25, 26, 27, 28, 29], | 同上, 有
[ 30, 31, 32, 33, 34], | [110, 114, 118, 122, 126]
[ 35, 36, 37, 38, 39]], V
[[ 40, 41, 42, 43, 44], |
[ 45, 46, 47, 48, 49], | 同上, 以下不再赘述
[ 50, 51, 52, 53, 54], |
[ 55, 56, 57, 58, 59]]], V
[[[ 60, 61, 62, 63, 64],
[ 65, 66, 67, 68, 69],
[ 70, 71, 72, 73, 74],
[ 75, 76, 77, 78, 79]],
[[ 80, 81, 82, 83, 84],
[ 85, 86, 87, 88, 89],
[ 90, 91, 92, 93, 94],
[ 95, 96, 97, 98, 99]],
[[100, 101, 102, 103, 104],
[105, 106, 107, 108, 109],
[110, 111, 112, 113, 114],
[115, 116, 117, 118, 119]]]])
>>> a.sum(axis=2)
array([[[ 30, 34, 38, 42, 46],
[110, 114, 118, 122, 126],
[190, 194, 198, 202, 206]],
[[270, 274, 278, 282, 286],
[350, 354, 358, 362, 366],
[430, 434, 438, 442, 446]]])
# 打印切片
>>> a[:,:,0,:]
array([[[ 0, 1, 2, 3, 4],
[ 20, 21, 22, 23, 24],
[ 40, 41, 42, 43, 44]],
[[ 60, 61, 62, 63, 64],
[ 80, 81, 82, 83, 84],
[100, 101, 102, 103, 104]]])
>>> a[:,:,1,:]
array([[[ 5, 6, 7, 8, 9],
[ 25, 26, 27, 28, 29],
[ 45, 46, 47, 48, 49]],
[[ 65, 66, 67, 68, 69],
[ 85, 86, 87, 88, 89],
[105, 106, 107, 108, 109]]])
>>> a[:,:,2,:]
array([[[ 10, 11, 12, 13, 14],
[ 30, 31, 32, 33, 34],
[ 50, 51, 52, 53, 54]],
[[ 70, 71, 72, 73, 74],
[ 90, 91, 92, 93, 94],
[110, 111, 112, 113, 114]]])
>>> a[:,:,3,:]
array([[[ 15, 16, 17, 18, 19],
[ 35, 36, 37, 38, 39],
[ 55, 56, 57, 58, 59]],
[[ 75, 76, 77, 78, 79],
[ 95, 96, 97, 98, 99],
[115, 116, 117, 118, 119]]])
>>> a[:,:,0,:] + a[:,:,1,:] + a[:,:,2,:] + a[:,:,3,:] == a.sum(axis=2)
array([[[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True]],
[[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True]]])
可以看到结果仍然是正确的, 后续五维、六维乃至更高维度, 也是如此
高维数组的切片
高维数组的第 axis
维度切片过程, 需要将高于 axis
的维度保持不动, 将低于它的维度看作整体 (或者说是子矩阵), 我们还是以四维数组为例, 切片 a[:,1:2, :, 2:3]
, axis=1
为 1:2
, aixs=3
为 2:3
>>> a = np.arange(120).reshape(2,3,4,5)
array([ --------------------------------> axis=0, 全选
[[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[ 10, 11, 12, 13, 14],
[ 15, 16, 17, 18, 19]],
[ -----------------------------> axis=1, 选取
[ 20, 21, 22, 23, 24],
[ 25, 26, 27, 28, 29],
[ 30, 31, 32, 33, 34],
[ 35, 36, 37, 38, 39]],
[[ 40, 41, 42, 43, 44],
[ 45, 46, 47, 48, 49],
[ 50, 51, 52, 53, 54],
[ 55, 56, 57, 58, 59]]],
# 下面是 axis=0 的第二个子矩阵
[[[ 60, 61, 62, 63, 64],
[ 65, 66, 67, 68, 69],
[ 70, 71, 72, 73, 74],
[ 75, 76, 77, 78, 79]],
[ -----------------------------> axis=1, 选取
[ 80, 81, 82, 83, 84],
[ 85, 86, 87, 88, 89],
[ 90, 91, 92, 93, 94],
[ 95, 96, 97, 98, 99]],
[[100, 101, 102, 103, 104],
[105, 106, 107, 108, 109],
[110, 111, 112, 113, 114],
[115, 116, 117, 118, 119]]]])
axis=0
需要全选, 它的维度是 2, 所以这两组子矩阵中, 都需要进行选取. axis=1
是 1:2
, 需要选择第 1
到第 2
个子矩阵, 但不包含第 2
个, 得到
[[[ -----------------------------> axis=1, 选取
[ 20, 21, 22, 23, 24], ----> 选择 [22]
[ 25, 26, 27, 28, 29], ----> 选择 [27]
[ 30, 31, 32, 33, 34], ----> 选择 [32]
[ 35, 36, 37, 38, 39]]],----> 选择 [37]
[[ -----------------------------> axis=1, 选取
[ 80, 81, 82, 83, 84], ----> 同上, 不再赘述
[ 85, 86, 87, 88, 89],
[ 90, 91, 92, 93, 94],
[ 95, 96, 97, 98, 99]]]],
axis=2
是全选, 所以上面每一行都需要保留, 进入下一层选择. axis=3
是 2:3
, 也就是选择第 2
个元素, 最终的结果如下:
>>> a[:,1:2, :, 2:3]
array([[[[22],
[27],
[32],
[37]]],
[[[82],
[87],
[92],
[97]]]])
可以看到, 切片后的数组, 仍然有 4 个轴, 切片不会降维