理解 numpy.rollaxis() 函数

函数声明

先看看 numpy.rollaxis() 函数的定义形式,如下:

rollaxis(a, axis, start=0)

参数 a 通常为 numpy.ndarray 类型,则 a.ndim表示 numpy 数组的维数;

参数 axis 通常为 int 类型,范围为 [0, a.ndim);

参数 start 为 int 类型,默认值为 0,取值范围为 [-a.ndim, a.ndim],如果超过这个范围,则会 raise AxisError。

函数功能

numpy.rollaxis() 函数用于滚动(roll)指定轴(axis)到某位置。这个函数可以用更易理解的函数 numpy.moveaxis(a, source, destination) 代替。但由于 numpy.moveaxis() 函数是在 numpy v1.11 版本新增的,为了与之前的版本兼容,这个函数依旧保留。

具体来说,需要根据 axis 和 normalized start 的比较结果,选择将 axis 滚动到哪个位置上,而其他轴的位置顺序不变。如果 axis 参数值大于或等于 normalized start,则 axis 从后向前滚动,直到 start 位置;如果 axis 参数值小于 normalized start,则 axis 轴从前往后滚动,直到 start 的前一个位置,即 start-1 位置。其中 start 和 normalized start 的对应关系,如下表所示:

start

Normalized start

-(a.ndim+1)

raise AxisError

-a.ndim

0

-1

a.ndim-1

0

0

a.ndim

a.ndim

a.ndim+1

raise AxisError

从表中,可以看出 normalized start 是在 -a.ndim <= start < 0 时, start + a.ndim 的值;在  0 <= start <= a.ndim 时,start 值。

具体的示例及解释,如下所示

import numpy as np

a = np.ones((3,4,5,6))

axis, start = 3, 1
# 因为 3 > 1,所以 axis index 3 移动到 axis index 1(start位置),而其他维度位置不变
print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,6,4,5)
# np.moveaxis 的等价调用
print(np.moveaxis(a, source=axis, destination=start).shape)

axis, start = 2, 0
# 因为 2 > 0,所以 axis index 2 移动到 axis index 0(start位置),而其他维度位置不变
print(np.rollaxis(a, axis, start).shape)  # (5,3,4,6)
# np.moveaxis 的等价调用
print(np.moveaxis(a, axis, start).shape)

axis, start = 1, 4
# 因为 1 < 4,所以 axis index 1 移动到 axis index 3(start-1位置),而其他维度位置不变
print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,5,6,4)
# np.moveaxis 的等价调用
print(np.moveaxis(a, source=axis, destination=start-1).shape)

 为了更好理解这个过程,最后看看该函数在 numpy 中实现的核心代码,如下所示:

def rollaxis(a, axis, start=0):
    """
    Roll the specified axis backwards, until it lies in a given position.

    Parameters
    ----------
    a : ndarray
        Input array.
    axis : int
        The axis to be rolled. The positions of the other axes do not
        change relative to one another.
    start : int, optional
        When ``start <= axis``, the axis is rolled back until it lies in
        this position. When ``start > axis``, the axis is rolled until it
        lies before this position. The default, 0, results in a "complete"
        roll. 

    Returns
    -------
    res : ndarray
        For NumPy >= 1.10.0 a view of `a` is always returned. For earlier
        NumPy versions a view of `a` is returned only if the order of the
        axes is changed, otherwise the input array is returned.

    """
    n = a.ndim
    if start < 0:
        start += n
    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1):
        raise AxisError(msg % ('start', -n, 'start', n + 1, start))
    if axis < start:
        start -= 1
    if axis == start:
        return a[...]
    axes = list(range(0, n))
    axes.remove(axis)
    axes.insert(start, axis)
    return a.transpose(axes)

参考资料

[1] numpy.rollaxis API reference. https://numpy.org/doc/stable/reference/generated/numpy.rollaxis.html

posted @ 2021-02-28 18:41  klchang  阅读(1466)  评论(0编辑  收藏  举报