拒绝for循环,从take_along_axis开始

技术背景

在前一篇文章中,我们提到了关于Numpy中的各种取index的方法,可以用于取出数组里面的元素,也可以用于做切片,甚至可以用来做排序。但是遇到对于高维矩阵的某一个维度取多个值的时候,单纯的使用下标已经无法完成相关的操作了。如果找不到相应的接口,对于性能要求不高的场景可以使用一个for循环进行替代,但是对于性能要求比较高的场景下,我们还是尽可能的使用Numpy本身自带的接口,比如本文将要提到的take_along_axis操作。

使用案例

我们考虑这样的一个场景,给定一个维度为(4,11,3)的矩阵a作为数据,和一个维度为(4,2)的矩阵b作为下标,意味着从a中第二条轴的11个元素中每次取两个元素,也就是希望得到一个维度为(4,2,3)的结果:

In [11]: a = np.arange(132).reshape((4,11,3))

In [12]: a
Out[12]: 
array([[[  0,   1,   2],
        [  3,   4,   5],
        [  6,   7,   8],
        [  9,  10,  11],
        [ 12,  13,  14],
        [ 15,  16,  17],
        [ 18,  19,  20],
        [ 21,  22,  23],
        [ 24,  25,  26],
        [ 27,  28,  29],
        [ 30,  31,  32]],

       [[ 33,  34,  35],
        [ 36,  37,  38],
        [ 39,  40,  41],
        [ 42,  43,  44],
        [ 45,  46,  47],
        [ 48,  49,  50],
        [ 51,  52,  53],
        [ 54,  55,  56],
        [ 57,  58,  59],
        [ 60,  61,  62],
        [ 63,  64,  65]],

       [[ 66,  67,  68],
        [ 69,  70,  71],
        [ 72,  73,  74],
        [ 75,  76,  77],
        [ 78,  79,  80],
        [ 81,  82,  83],
        [ 84,  85,  86],
        [ 87,  88,  89],
        [ 90,  91,  92],
        [ 93,  94,  95],
        [ 96,  97,  98]],

       [[ 99, 100, 101],
        [102, 103, 104],
        [105, 106, 107],
        [108, 109, 110],
        [111, 112, 113],
        [114, 115, 116],
        [117, 118, 119],
        [120, 121, 122],
        [123, 124, 125],
        [126, 127, 128],
        [129, 130, 131]]])

In [13]: b = np.array([[0,1],[1,2],[2,3],[3,4]])

In [14]: b
Out[14]: 
array([[0, 1],
       [1, 2],
       [2, 3],
       [3, 4]])

为了方便展示我们就定义了这样两个比较简单的矩阵a和b,那么在这个结果中,我们理想的结果应该是:

[[[  0,   1,   2],
  [  3,   4,   5]],

 [[ 36,  37,  38],
  [ 39,  40,  41]],

 [[ 72,  73,  74],
  [ 75,  76,  77]],

 [[108, 109, 110],
  [111, 112, 113]]]

这样的一个矩阵。关于这个结果的来源,可以对b这个定义进行展开解释,b的值为:

[[0, 1],
 [1, 2],
 [2, 3],
 [3, 4]]

它所表示的是在a[0]下取第0个元素和第1个元素,在a[1]下取第1个元素和第2个元素,以此类推。然而如果我们直接把定义好的b放到a的索引中或者直接使用numpy.take的方法的话,得到的结果是这样的:

In [16]: a[:,b]
Out[16]: 
array([[[[  0,   1,   2],
         [  3,   4,   5]],

        [[  3,   4,   5],
         [  6,   7,   8]],

        [[  6,   7,   8],
         [  9,  10,  11]],

        [[  9,  10,  11],
         [ 12,  13,  14]]],


       [[[ 33,  34,  35],
         [ 36,  37,  38]],

        [[ 36,  37,  38],
         [ 39,  40,  41]],

        [[ 39,  40,  41],
         [ 42,  43,  44]],

        [[ 42,  43,  44],
         [ 45,  46,  47]]],


       [[[ 66,  67,  68],
         [ 69,  70,  71]],

        [[ 69,  70,  71],
         [ 72,  73,  74]],

        [[ 72,  73,  74],
         [ 75,  76,  77]],

        [[ 75,  76,  77],
         [ 78,  79,  80]]],


       [[[ 99, 100, 101],
         [102, 103, 104]],

        [[102, 103, 104],
         [105, 106, 107]],

        [[105, 106, 107],
         [108, 109, 110]],

        [[108, 109, 110],
         [111, 112, 113]]]])

显然这不是我们想要的结果。需要额外申明的是,这个执行操作中,最后一个维度的冒号加与不加是一样的效果,跟numpy.take本质上也是同样的操作,因此就需要使用到numpy中的另外一个接口:take_along_axis,如下是其官方的API文档:

还有相关的使用案例:

需要注意的是,输入的indices必须要跟原始的数据矩阵保持同样的维度,因此在我们自己的案例中,对b进行了扩维,最终的代码如下所示:

In [23]: np.take_along_axis(a,b[:,:,None],axis=1)
Out[23]: 
array([[[  0,   1,   2],
        [  3,   4,   5]],

       [[ 36,  37,  38],
        [ 39,  40,  41]],

       [[ 72,  73,  74],
        [ 75,  76,  77]],

       [[108, 109, 110],
        [111, 112, 113]]])

最后得到的就是我们想要的结果了,并且是直接使用下标无法实现的操作(当然,也可能是我还没研究出来这样的操作)。这里axis设置为1,就表示a的第0个维度和b的第0个维度是一致的取法,也可以理解成全取的意思。

总结概要

Numpy是在Python中用于各种矩阵运算非常强大的工具之一,而快速的通过下标取出所需位置的元素也是numpy所支持的强大功能之一。常规的元素取法都可以通过numpy的下标或者是numpy.take函数来实现,比如array[0,:]可用于取第一条轴的所有元素,array[:,0]可以用于取第二条轴的所有第二个元素,放在一个2维的矩阵里面就分别是取第一行的所有元素和取第一列的所有元素。但是本文更加关注于更高维的矩阵,当我们想从多个维度中取多个元素时,是不太容易直接用下标去取的,比如同时取a[0][0],a[0][1],a[1][1],a[1][2]的话,那么就只能使用numpy所支持的另外一个函数numpy.take_along_axis来实现。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/take_along_axis.html

作者ID:DechinPhy

更多原著文章请参考:https://www.cnblogs.com/dechinphy/

打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958

参考链接

  1. https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html#numpy.take_along_axis
posted @ 2022-02-24 17:32  DECHIN  阅读(1356)  评论(0编辑  收藏  举报