Numpy中的对于数组间的算术运算采用“元素一 一对应”的计算机制,因而一般要求两个数组的形状相同才能进行数组间的算术运算,但是在某些情况中,Numpy中允许符合一定规则的不同形状的数组进行算术运算。广播的核心规则如下:
- 两数组的后缘维度轴长度相同(满足低维数组和高维数组的某一内层元素形状相同,低维数组才能被拓展成和高维数组一样的形状)
- 两数组的后缘维度中有任意个长度是1或缺失
- 输出数组的形状是输入数组形状的各个维度上的最大值
上面的规则可能看起来比较难懂,简单来说就是低维数组要想被拓展成和高维数组一样的形状,从后往前比较时两数组时,低维数组的所有维度要么等于高维数组的对应维度,要么该维度大小值为1或缺失,数组广播将会在维度为1或缺失的维度方向上进行。举例说明如下:
下面是一些广播的例子:
举例1:
In:
a = np.array([1,2,3,4])
b = 3
print(a*b)
Out:
[ 3 6 9 12]
标量和数组的乘积运算即是数组广播的最简单的一个例子,上诉过程中标量b被“虚拟”地扩展为数组[3,3,3,3]然后才和数组a[1,2,3,4]进行乘法运算,以满足Numpy中两数组间“元素一 一对应”的运算规则。
举例2:
In:
a = np.array([[1,2,3,4],
[5,6,7,8],
[9,10,11,12],
[13,14,15,16]])
b = np.array([1,2,3,4])
print(a.shape)
print(b.shape)
print(a+b)
print((a+b).shape)
Out:
a.shape: (4, 4)
b.shape: (4,)
运算结果为:
[[ 2 4 6 8]
[ 6 8 10 12]
[10 12 14 16]
[14 16 18 20]]
(a+b).shape: (4, 4)
上诉一维数组和二维数组的加法运算中,b数组其实和a数组在第0维上的一个内层元素形状相同,因而b数组可以被"虚拟"地拓展为一个4x4的二维数组,相当于被"纵向拉长了",然后再和a数组进行加法运算。
更多例子:
In:
a = np.ones((3,4,2))
b = np.ones((4,2))
print("a:\n",a)
print("b:\n",b)
print("a.shape: ",a.shape)
print('b.shape: ',b.shape)
print("a+b\n",a+b)
print("(a+b).shape: ",(a+b).shape)
Out:
a:
[[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]]
b:
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
a.shape: (3, 4, 2)
b.shape: (4, 2)
a+b
[[[2. 2.]
[2. 2.]
[2. 2.]
[2. 2.]]
[[2. 2.]
[2. 2.]
[2. 2.]
[2. 2.]]
[[2. 2.]
[2. 2.]
[2. 2.]
[2. 2.]]]
(a+b).shape: (3, 4, 2)
上诉例子中,数组b(4,2)实际上是数组a(3,4,2)中的一个第0维度上的一个元素,满足"两数组的后缘维度轴长度相同(满足低维数组和高维数组的某一内层元素形状相同,低维数组才能被拓展成和高维数组一样的形状)"规则,因而b数组能被“虚拟”地拓展成为(3,4,2)的形状,完成和数组a的加法运算。
In:
a = np.ones((3,2,1,4,2))
b = np.ones((1,3,4,1))
print("a.shape: ",a.shape)
print('b.shape: ',b.shape)
print("(a+b).shape: ",(a+b).shape)
Out:
a.shape: (3, 2, 1, 4, 2)
b.shape: (1, 3, 4, 1)
(a+b).shape: (3, 2, 3, 4, 2)
上述例子中,从后往前依次比较a、b数组的维度发现,要么对应相等,要么为1,满足广播的条件因而能进行算术运算。结果数组的各维度取决于输入数组的各维度的最大值。
不满足广播规则的一些例子:
In:
a = np.ones((2,3,4))
b = np.ones((4,4))
print(a)
print(b)
print("a.shape: ",a.shape)
print('b.shape: ',b.shape)
print("(a+b).shape: ",(a+b).shape)
Out:
[[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]
[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]]
[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]
a.shape: (2, 3, 4)
b.shape: (4, 4)
Traceback (most recent call last):
ValueError: operands could not be broadcast together with shapes (2,3,4) (4,4)
上述例子中,b数组的形状为(4,4)不是a数组的任意维度上的一个元素,且从后往前比较时,对应维度大小不同且不等于1,因而不符合广播规则,不能进行广播后完成运算。
总结:Numpy中数组的广播规则其实就是一句话,要么低维数组和高维数组的某一内层元素形状相同,要么从后往前依次比较两数组的维度时,维度大小对应相等或为1。