Numpy 高维空间中的轴

Numpy 高维空间中的轴 (axis)

完成日期: 2024-03-01
更新日期: 2024-03-01

问题

Numpy 中有众多操作会涉及到一个参数 axis, 也就是 . 这到底是什么? 沿着某轴操作 (例如 np.sum(axis=0)) 又是什么意思?

对于低维数组, 或许可以按 来理解, 但如果上升到了四维、五维乃至更高, 就变得十分抽象了. 因为这里要讨论更高维度的情况, 所以就不使用 "行" 或 "列" 之类的词语描述数组的几何意义了

一维数组

我们先来看一维数组的情况, 也就是 np.shape = (1,) 的情况 (这里的等号 = 是数学上的符号, 不是赋值的意思), arange() 方法可以接受一个参数 n, 生成从 0 到 n 的整数序列

>>> a = np.arange(8)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])

我们现在想要计算这个序列的和, 使用 np.sum(axis=0) 方法, 由于序列 a 只有一个维度, axis 参数只能为 0

>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])
>>> a.shape
(8,)
>>> a.sum(axis=0)
28

输出值是 28, np.shape 方法可以打印数组的维度信息, 这里序列 a 只有一个, 这个维度有 8 个元素. axis=0 指在第 0 维上, 遍历所有元素, 执行求和 sum() 的操作, 也就是把索引为 0, 1, 2, ..., 7 的元素累加起来, 因为它们的层级都是第 0 维, 得到 [28], 写成伪代码的形式就是:

for i in range(8)
    sum = sum + i

Numpy 中 sum()min() 等方法还会有一个降维的效果, 也就是 reduce, 于是用于表示维度的方括号被抽离, [28] 变成 28, 成为 0 维的标量

二维数组

axis=0

>>> a = np.arange(6).reshape(2,3)
>>> a.shape
(2, 3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.sum(axis=0)
array([3, 5, 7])

reshape() 方法可以调整数组维度, arange(6) 生成了有 6 个元素的一维数组, reshape(2,3) 将它调整成 2*3 的二维数组

在Numpy显示多维数组的方式中, 可以通过数方括号来确定维度, 第一个方括号是第 axis=0 维, 第二个方括号是第 axis=1 维. 我们调整一下上面 Numpy 显示二维数组的方式, 方便指示维度

>>> a
array([ -----------> 表示第 axis=0 维
        [  --------> 表示第 axis=1 维
         0, 1, 2],
        [3, 4, 5]])

那么, axis=0 要求按第 0 轴去求和, 我们把数组 a 再换个表示方式

>>>a
array([[0, 1, 2],
       [3, 4, 5]])

array([ A,
        B ])

这里, 将 [0, 1, 2] 看作 A, 将 [3, 4, 5] 看作 B (还记得矩阵中的子矩阵吗?)

那么, 我们要遍历 axis=0 轴的元素执行求和, 就是要计算 [A+B], 也就是 [[0, 1, 2] + [3, 4, 5]], 结果是 [[3, 5, 7]], 由于 np.sum() 会降维, 抽离最外边第 axis=0 的方括号, 于是变成 [3, 5, 7]

如果写成伪代码, 即使有 N 行

for i in range(A, B, C, ... , N)
    sum = sum + i

如果用切片的方式来表示, 是在计算 a[0, ...] + a[1, ...], 符号 ... 表示此维度不作指定, 会自动推导选择全部

>>> a[0, ...]
array([0, 1, 2])

>>> a[1, ...]
array([3, 4, 5])

[0, 1, 2] + [3, 4, 5] = [3, 5, 7]

数组 ashape(2,3), 执行完 sum(axis=0) 后, 变成 (3,), 第一个维度没有了

axis=1

>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.shape
(2, 3)
>>> a.sum(axis=1)
array([ 3, 12])

如果 np.sum(axis=1), 我们需要遍历同属于 axis=1 层级的元素并求和

>>> a
array([
        [0, 1, 2],  左 ---> 右, 遍历求和, 得 [3]
        [3, 4, 5]   左 ---> 右, 遍历求和, 的 [12]
                 ])

array([
        [3],
        [12]
            ])

最终, 得到数组 [[3], [12]], 由于降维, 抽离里面 axis=1 的方括号, 变成 [3, 12]

如果用切片的方式查看, 实质上是在计算 a[..., 0] + a[..., 1] + a[..., 2]

>>> a[..., 0]
array([0, 3])

>>> a[..., 1]
array([1, 4])

>>> a[..., 2]
array([2, 5])

[0, 3] + [1, 4] + [2, 5] = [3, 12]

数组 ashape(2,3), 执行完 sum(axis=1) 后, 变成 (2,)

三维数组

到这里, 我们可以看到:

对于一维数组, axis=0, 我们需要遍历数组中的每一个元素, 也就是 a[...]

对于二维数组, axis=0, 我们需要遍历数组第 0 轴的每一个元素, 也就是 a[0, ...] + a[1, ...] + a[2, ...] ... a[n, ...]

如果 axis=1, 我们需要遍历数组第 1 轴的每一个元素, 也就是 a[..., 0] + a[..., 1] + a[..., 2] ... a[..., n]

那么, 是不是可以认为, 对某一轴操作, 就是去遍历那个轴的切片?

axis=0
for i in {a[0, :, :, ...], a[1, :, :, ...], ..., a[n, :, :, ...]}

axis=1
for i in {a[:, 0, :, ...], a[:, 1, :, ...], ..., a[:, n, :, ...]}

axis=2
for i in {a[:, :, 0, ...], a[:, :, 1, ...], ..., a[:, :, n, ...]}

...
...

我们用三维数组验证一下

>>> a = np.arange(24).reshape(2,3,4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> a.sum(axis=0)
array([[12, 14, 16, 18],
       [20, 22, 24, 26],
       [28, 30, 32, 34]])

# 打印切片
>>> a[0,...]
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1,...]
array([[12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])

[[ 0,  1,  2,  3],     [[12, 13, 14, 15],       [[12, 14, 16, 18],
 [ 4,  5,  6,  7],  +   [16, 17, 18, 19],   =    [20, 22, 24, 26],
 [ 8,  9, 10, 11]]      [20, 21, 22, 23]]        [28, 30, 32, 34]]

同一层级 axis=0 的元素有两个, 是切片 a[0, ...] 与 切片 a[1, ...], 它们的和正好就是 a.sum(axis=0)

现在我们在来看看 axis=1 的情况

>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> a.sum(axis=1)
array([[12, 15, 18, 21],
       [48, 51, 54, 57]])

# 打印切片
>>> a[:, 0, :]
array([[ 0,  1,  2,  3],
       [12, 13, 14, 15]])

>>> a[:, 1, :]
array([[ 4,  5,  6,  7],
       [16, 17, 18, 19]])

>>> a[:, 2, :]
array([[ 8,  9, 10, 11],
       [20, 21, 22, 23]])

>>> a[:,0,:] + a[:,1,:] + a[:,2,:] == a.sum(axis=1)
array([[ True,  True,  True,  True],
       [ True,  True,  True,  True]])

切片的和也确实等于 a.sum(axis=1), 回到数组 a 本身, 第一个维度(也就是 axis=0) 不动, axis=1 的元素有 3 组, 需要遍历它们求和

array([[  ------------------------> axis=1 的轴
        [ 0,  1,  2,  3],       |
                +               |
        [ 4,  5,  6,  7],       |   把这三组数加起来, 有
                +               |   [12, 15, 18, 21]
        [ 8,  9, 10, 11]],      V


       [
        [12, 13, 14, 15],       |
                +               |
        [16, 17, 18, 19],       |   同上操作, 有
                +               |   [48, 51, 54, 57]
        [20, 21, 22, 23]]])     V

可以看到, 切片出来元素再相加, 和我们用这种方式取出元素再相加, 它们本质上是相同的

现在再来看看 axis=2 的情况, 它们都是一致的

array([[
        [ 0,  1,  2,  3],  左---->右, 遍历求和, 有 [6]
        [ 4,  5,  6,  7],                      有 [22]
        [ 8,  9, 10, 11]],                     有 [38]

       [
        [12, 13, 14, 15],       有 [54]
        [16, 17, 18, 19],       有 [70]
        [20, 21, 22, 23]        有 [86]
                        ]])
>>> a.sum(axis=2)
array([[ 6, 22, 38],
       [54, 70, 86]])

# 打印切片
>>> a[:,:,0]
array([[ 0,  4,  8],
       [12, 16, 20]])

>>> a[:,:,1]
array([[ 1,  5,  9],
       [13, 17, 21]])

>>> a[:,:,2]
array([[ 2,  6, 10],
       [14, 18, 22]])

>>> a[:,:,3]
array([[ 3,  7, 11],
       [15, 19, 23]])

>>> a[:,:,0] + a[:,:,1] + a[:,:,2] + a[:,:,3] == a.sum(axis=2)
array([[ True,  True,  True],
       [ True,  True,  True]])

切片相加, 也确实等于 sum(axis=2)

四维数组

我们再来简单验证一下四维数组

>>> a
array([[[  -----------------------------> axis=2 的轴
         [  0,   1,   2,   3,   4],   |
         [  5,   6,   7,   8,   9],   | 遍历求和, 有
         [ 10,  11,  12,  13,  14],   | [ 30,  34,  38,  42,  46]
         [ 15,  16,  17,  18,  19]],  V

        [[ 20,  21,  22,  23,  24],   |
         [ 25,  26,  27,  28,  29],   | 同上, 有
         [ 30,  31,  32,  33,  34],   | [110, 114, 118, 122, 126]
         [ 35,  36,  37,  38,  39]],  V

        [[ 40,  41,  42,  43,  44],   |
         [ 45,  46,  47,  48,  49],   | 同上, 以下不再赘述
         [ 50,  51,  52,  53,  54],   |
         [ 55,  56,  57,  58,  59]]], V


       [[[ 60,  61,  62,  63,  64],
         [ 65,  66,  67,  68,  69],
         [ 70,  71,  72,  73,  74],
         [ 75,  76,  77,  78,  79]],

        [[ 80,  81,  82,  83,  84],
         [ 85,  86,  87,  88,  89],
         [ 90,  91,  92,  93,  94],
         [ 95,  96,  97,  98,  99]],

        [[100, 101, 102, 103, 104],
         [105, 106, 107, 108, 109],
         [110, 111, 112, 113, 114],
         [115, 116, 117, 118, 119]]]])
>>> a.sum(axis=2)
array([[[ 30,  34,  38,  42,  46],
        [110, 114, 118, 122, 126],
        [190, 194, 198, 202, 206]],

       [[270, 274, 278, 282, 286],
        [350, 354, 358, 362, 366],
        [430, 434, 438, 442, 446]]])

# 打印切片
>>> a[:,:,0,:]
array([[[  0,   1,   2,   3,   4],
        [ 20,  21,  22,  23,  24],
        [ 40,  41,  42,  43,  44]],

       [[ 60,  61,  62,  63,  64],
        [ 80,  81,  82,  83,  84],
        [100, 101, 102, 103, 104]]])
>>> a[:,:,1,:]
array([[[  5,   6,   7,   8,   9],
        [ 25,  26,  27,  28,  29],
        [ 45,  46,  47,  48,  49]],

       [[ 65,  66,  67,  68,  69],
        [ 85,  86,  87,  88,  89],
        [105, 106, 107, 108, 109]]])
>>> a[:,:,2,:]
array([[[ 10,  11,  12,  13,  14],
        [ 30,  31,  32,  33,  34],
        [ 50,  51,  52,  53,  54]],

       [[ 70,  71,  72,  73,  74],
        [ 90,  91,  92,  93,  94],
        [110, 111, 112, 113, 114]]])
>>> a[:,:,3,:]
array([[[ 15,  16,  17,  18,  19],
        [ 35,  36,  37,  38,  39],
        [ 55,  56,  57,  58,  59]],

       [[ 75,  76,  77,  78,  79],
        [ 95,  96,  97,  98,  99],
        [115, 116, 117, 118, 119]]])

>>> a[:,:,0,:] + a[:,:,1,:] + a[:,:,2,:] + a[:,:,3,:] == a.sum(axis=2)
array([[[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True]],

       [[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True]]])

可以看到结果仍然是正确的, 后续五维、六维乃至更高维度, 也是如此

高维数组的切片

高维数组的第 axis 维度切片过程, 需要将高于 axis 的维度保持不动, 将低于它的维度看作整体 (或者说是子矩阵), 我们还是以四维数组为例, 切片 a[:,1:2, :, 2:3], axis=11:2, aixs=32:3

>>> a = np.arange(120).reshape(2,3,4,5)
array([ --------------------------------> axis=0, 全选
        [[[  0,   1,   2,   3,   4],
         [  5,   6,   7,   8,   9],
         [ 10,  11,  12,  13,  14],
         [ 15,  16,  17,  18,  19]],

        [  -----------------------------> axis=1, 选取
         [ 20,  21,  22,  23,  24],
         [ 25,  26,  27,  28,  29],
         [ 30,  31,  32,  33,  34],
         [ 35,  36,  37,  38,  39]],

        [[ 40,  41,  42,  43,  44],
         [ 45,  46,  47,  48,  49],
         [ 50,  51,  52,  53,  54],
         [ 55,  56,  57,  58,  59]]],

        # 下面是 axis=0 的第二个子矩阵
       [[[ 60,  61,  62,  63,  64],
         [ 65,  66,  67,  68,  69],
         [ 70,  71,  72,  73,  74],
         [ 75,  76,  77,  78,  79]],

        [  -----------------------------> axis=1, 选取
         [ 80,  81,  82,  83,  84],
         [ 85,  86,  87,  88,  89],
         [ 90,  91,  92,  93,  94],
         [ 95,  96,  97,  98,  99]],

        [[100, 101, 102, 103, 104],
         [105, 106, 107, 108, 109],
         [110, 111, 112, 113, 114],
         [115, 116, 117, 118, 119]]]])

axis=0 需要全选, 它的维度是 2, 所以这两组子矩阵中, 都需要进行选取. axis=11:2, 需要选择第 1 到第 2 个子矩阵, 但不包含第 2 个, 得到

[[[  -----------------------------> axis=1, 选取
    [ 20,  21,  22,  23,  24],  ----> 选择 [22]
    [ 25,  26,  27,  28,  29],  ----> 选择 [27]
    [ 30,  31,  32,  33,  34],  ----> 选择 [32]
    [ 35,  36,  37,  38,  39]]],----> 选择 [37]

[[  -----------------------------> axis=1, 选取
    [ 80,  81,  82,  83,  84],  ----> 同上, 不再赘述
    [ 85,  86,  87,  88,  89],
    [ 90,  91,  92,  93,  94],
    [ 95,  96,  97,  98,  99]]]],

axis=2 是全选, 所以上面每一行都需要保留, 进入下一层选择. axis=32:3, 也就是选择第 2 个元素, 最终的结果如下:

>>> a[:,1:2, :, 2:3]
array([[[[22],
         [27],
         [32],
         [37]]],


       [[[82],
         [87],
         [92],
         [97]]]])

可以看到, 切片后的数组, 仍然有 4 个轴, 切片不会降维

posted @ 2024-03-01 02:51  Asnelin  阅读(208)  评论(0编辑  收藏  举报