理解 numpy.rollaxis() 函数
函数声明
先看看 numpy.rollaxis() 函数的定义形式,如下:
rollaxis(a, axis, start=0)
参数 a 通常为 numpy.ndarray 类型,则 a.ndim表示 numpy 数组的维数;
参数 axis 通常为 int 类型,范围为 [0, a.ndim);
参数 start 为 int 类型,默认值为 0,取值范围为 [-a.ndim, a.ndim],如果超过这个范围,则会 raise AxisError。
函数功能
numpy.rollaxis() 函数用于滚动(roll)指定轴(axis)到某位置。这个函数可以用更易理解的函数 numpy.moveaxis(a, source, destination) 代替。但由于 numpy.moveaxis() 函数是在 numpy v1.11 版本新增的,为了与之前的版本兼容,这个函数依旧保留。
具体来说,需要根据 axis 和 normalized start 的比较结果,选择将 axis 滚动到哪个位置上,而其他轴的位置顺序不变。如果 axis 参数值大于或等于 normalized start,则 axis 从后向前滚动,直到 start 位置;如果 axis 参数值小于 normalized start,则 axis 轴从前往后滚动,直到 start 的前一个位置,即 start-1 位置。其中 start 和 normalized start 的对应关系,如下表所示:
start |
Normalized start |
-(a.ndim+1) |
raise AxisError |
-a.ndim |
0 |
⋮ |
⋮ |
-1 |
a.ndim-1 |
0 |
0 |
⋮ |
⋮ |
a.ndim |
a.ndim |
a.ndim+1 |
raise AxisError |
从表中,可以看出 normalized start 是在 -a.ndim <= start < 0 时, start + a.ndim 的值;在 0 <= start <= a.ndim 时,start 值。
具体的示例及解释,如下所示
import numpy as np
a = np.ones((3,4,5,6))
axis, start = 3, 1
# 因为 3 > 1,所以 axis index 3 移动到 axis index 1(start位置),而其他维度位置不变
print(np.rollaxis(a, axis=axis, start=start).shape) # (3,6,4,5)
# np.moveaxis 的等价调用
print(np.moveaxis(a, source=axis, destination=start).shape)
axis, start = 2, 0
# 因为 2 > 0,所以 axis index 2 移动到 axis index 0(start位置),而其他维度位置不变
print(np.rollaxis(a, axis, start).shape) # (5,3,4,6)
# np.moveaxis 的等价调用
print(np.moveaxis(a, axis, start).shape)
axis, start = 1, 4
# 因为 1 < 4,所以 axis index 1 移动到 axis index 3(start-1位置),而其他维度位置不变
print(np.rollaxis(a, axis=axis, start=start).shape) # (3,5,6,4)
# np.moveaxis 的等价调用
print(np.moveaxis(a, source=axis, destination=start-1).shape)
为了更好理解这个过程,最后看看该函数在 numpy 中实现的核心代码,如下所示:
def rollaxis(a, axis, start=0):
"""
Roll the specified axis backwards, until it lies in a given position.
Parameters
----------
a : ndarray
Input array.
axis : int
The axis to be rolled. The positions of the other axes do not
change relative to one another.
start : int, optional
When ``start <= axis``, the axis is rolled back until it lies in
this position. When ``start > axis``, the axis is rolled until it
lies before this position. The default, 0, results in a "complete"
roll.
Returns
-------
res : ndarray
For NumPy >= 1.10.0 a view of `a` is always returned. For earlier
NumPy versions a view of `a` is returned only if the order of the
axes is changed, otherwise the input array is returned.
"""
n = a.ndim
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
if not (0 <= start < n + 1):
raise AxisError(msg % ('start', -n, 'start', n + 1, start))
if axis < start:
start -= 1
if axis == start:
return a[...]
axes = list(range(0, n))
axes.remove(axis)
axes.insert(start, axis)
return a.transpose(axes)
参考资料
[1] numpy.rollaxis API reference. https://numpy.org/doc/stable/reference/generated/numpy.rollaxis.html.