NumPy 广播
NumPy中,有时两个不同形状的数组之间需要进行运算。
考虑下面的例子。
示例
两个形状相同的数组相乘:
import numpy as np
a = np.array([1,2,3,4,5,6,7])
b = np.array([2,4,6,8,10,12,14])
c = a*b;
print(c)
输出
[ 2 8 18 32 50 72 98]
示例
两个形状不相同的数组相乘:
import numpy as np
a = np.array([1,2,3,4,5,6,7])
b = np.array([2,4,6,8,10,12,14,19])
c = a*b;
print(c)
输出
ValueError: operands could not be broadcast together with shapes (7,) (8,)
上面的例子中,两个数组的形状不同,不能相乘。
NumPy中,要在两个不同形状数组之间进行运算,Python会尝试使用广播机制,广播需要符合一定条件,如果不能广播,则报错。
广播的作用是,把较小数组的形状扩展成与较大数组一致,以便进行运算。
广播的规则
广播机制遵循下列规则:
- 如果两个数的维度数不同,那么小维度数组的形状将会在最左边补1。
- 如果两个数组的形状在任何一个维度都不匹配,那么数组的形状会沿着维度为1的维度扩展,以匹配另外一个数组的形状。
- 如果两个数组的形状在任何一个维度上都不匹配并且没有任何一个维度等于1,那么会引发异常。
让我们看两个个广播的例子。
示例
import numpy as np
a = np.arange(3)
b = np.ones((2,3))
#两个数组的形状为 a.shape=(3,), b.shape=(2,3)
#可以看到,根据规则1, 数组a的维度更小,所以在其左边补1, 变为 a.shape -> (1,3)
#根据规则2, 第一个维度不匹配,因此扩展这个维度以匹配数组:a.shape -> (2,3)
#现在两个数组的形状都是(2,3),可以进行运算了:
print(a + b)
输出
[[1. 2. 3.]
[1. 2. 3.]]
示例
两个数组均需要广播的示例:
import numpy as np
a = np.arange(3).reshape((3,1))
b = np.arange(3)
#两个数组的形状为:a.shape=(3,1), b.shape=(3,)
#规则1告诉我们,需要用1将b的形状补全:b.shape -> (1,3)
#规则2告诉我们,需要更新这两个数组的维度来相互匹配:a.shape -> (3,3), b.shape -> (3,3)
#因为结果匹配,所以这两个形状是兼容的:
print(a + b)
输出
[[0 1 2]
[1 2 3]
[2 3 4]]