看了诸多关于Numpy广播原则的讲解,少数一针间血,还有不部分则越讲越绕;
如果要我用一句话讲Numpy的广播原则,我会说;
"从右往左,要么相同,要么为1。"
这是什么意思呢?从右往左,这是特别指代在看两个不同的Numpy.array的shape时的观察原则;
我们知道,在Numpy中, 对于一个数组的shape, 例如shape=(1, 2, 3), 从右往左维度是依次递增的。
所以,从右往左看,实际上就是在看两个Numpy数组的最低维度。
"要么相同, 要么为1", 就是指从低维到高维,两个数组的维度大小必须要么相同,要么有一方为1;
例如,对于两个不同的Numpy数组: A.shape=(3, 1,5), B.shape=(8, 5)
从右往左看,两个数组最低维度的大小均为5;
再看高一级维度,发现两个数组虽然这一维度大小不同,A为1, B为8,但是有一方的大小为1;
此时,我们发现不能再看B的更高维度了,因为B只有两个维度。于是我们就说,这两个Numpy数组是可广播的,因为完全符合"从右往左,要么相同,要么为1"。
那么这样两个数组广播后进行运算,运算后新数组的shape是什么呢?
还是一句话:
"从右往左,有则取大,无则取余";
例如还是上面的A和B数组,从右往左看他们的shape都有元素,故取最大值: max(5, 5)得到5;
继续往左,A,B双方依然有元素, 故取最大值: max(1, 8)得到8;
继续往左,此时发现B已经没有元素了,故取"余", 即A中的3;
最后将这些元素从右往左组合起来,就变成了: (3, 8, 5)
下面用一段简明扼要的用一段Python代码模拟这一运作过程:
# 使用python模拟广播机制的判断
def broadcast_shape(shape1: tuple, shape2: tuple) -> tuple:
result_shape: list = list(shape1 if len(shape1) > len(shape2) else shape2)
# 从最低维度开始,从右往左判断各个维度是否可以广播
dim: int = len(result_shape) - 1
for dim1, dim2 in zip(shape1[::-1], shape2[::-1]):
if dim1 != dim2 and dim1 != 1 and dim2 != 1:
raise ValueError("Shape mismatch: objects cannot be broadcast together")
result_shape[dim] = max(dim1, dim2)
dim -= 1
return tuple(result_shape)
if __name__ == "__main__":
shape1 = (3, 1, 5)
shape2 = (8, 5)
print(broadcast_shape(shape1, shape2))
那么,为什么Numpy广播要满足上面的原则呢?
我们可以用多重嵌套列表来模拟多维数组,最低维度就是最内层的列表;
想象一下,唯有当最内层列表的元素相同时,进一步扩展到外层列表,他们的元素也才有机会对应上;
而当最内层列表一方只有一个元素时,则可以很方便的将这一个元素通过复制达成最内层列表元素相同;
还是基于上面的例子,为了方便表示将B的维度改成(2, 5), 不改变A、B可以广播的特性。
笔者认为上述图片已经十分明了的解释了广播的过程。读者也可以自行尝试一下让两个可以广播的数组,从两个矩阵大小无法匹配的最低维度开始广播,保证他们这一维度的大小相同,从而体会广播机制的运作过程。