广播
广播是一个强大的机制,它允许Numpy在运算时可以处理不同种类型的数组。我们通常会使用规模较小的数组和较大的数组,并且希望能够使用小数组在大数组上直接进行操作。
举个栗子,假设我们想要在一个矩阵的每一行加上一个常向量,我们可以这么做:
import numpy as np
# 我们欲在矩阵x的每一行加上向量v
# 将结果储存在矩阵y里
x = np.array([[1,2,3], [4,5,6], [7,8,9], [10, 11, 12]])
v = np.array([1, 0, 1])
y = np.empty_like(x) # 建立一个空矩阵,它具有和x同样的形状(shape)
# 使用一个显式循环来将向量v加在矩阵x的每一行上
for i in range(4):
y[i, :] = x[i, :] + v
# y的结果如下
# [[ 2 2 4]
# [ 5 5 7]
# [ 8 8 10]
# [11 11 13]]
print(y)
这确实好使。然而当矩阵x特别大时,在python中计算一个显式循环可能会花费大量的时间。注意到将v加在x的每一行上等同于先通过垂直堆砌v形成一个矩阵vv,然后将x与vv采取元素与元素对应的相加。我们可以如下实现上述操作:
import numpy as np
# 我们欲在矩阵x的每一行加上向量v
# 将结果储存在矩阵y里
x = np.array([[1,2,3], [4,5,6], [7,8,9], [10, 11, 12]])
v = np.array([1, 0, 1])
vv = np.tile(v, (4, 1)) # 将4个v竖直堆砌在一起
print(vv) # 输出 "[[1 0 1]
# [1 0 1]
# [1 0 1]
# [1 0 1]]"
y = x + vv # 将x与vv元素与元素对应相加
print(y) # 输出 "[[ 2 2 4]
# [ 5 5 7]
# [ 8 8 10]
# [11 11 13]]"
在Numpy的广播机制下,我们可以无需真的将v复制多次,就能完成计算。考虑下面这个使用了广播的版本:
import numpy as np
# 我们欲在矩阵x的每一行加上向量v
# 将结果储存在矩阵y里
= np.array([[1,2,3], [4,5,6], [7,8,9], [10, 11, 12]])
v = np.array([1, 0, 1])
y = x + v # 使用广播机制将v加在x的每一行上
print(y) # 输出 "[[ 2 2 4]
# [ 5 5 7]
# [ 8 8 10]
# [11 11 13]]"
y = x + v 这行代码能够正常运行,即使x的形状是(4,3),v的形状是(3,),广播机制保证了它的正常运行;在这个机制的作用下,v就好像是一个4行3列的矩阵,其中每一行都是一样的v,然后使用元素对元素的求和。
对两个数组进行广播操作遵循以下规则:
如果两个数组的阶数不一样,将低阶数组在缺少的维度上的长度定义为1,直到二者各维度具有相同的长度。
当两个数组在一个维度上长度相同,或是其中一个数组在那个维度上的长度是1,它们被称为在那个维度上匹配。
当两个数组在所有维度都匹配时,它们可以使用广播操作。
广播操作后,每个数组的形状就如同输入数组的形状的最小公倍数。
在任意维度上,如果一个数组的长度是1,而其他数组的长度比1大,则第一个数组的表现就如同它沿着那个维度复制后的数组一样。
如果这个解释让你一头雾水的话,可以尝试阅读说明文档以及这篇解读。
支持广播功能的函数被称为通用函数。你可以在说明文档中找到所有函数的列表。
下面是广播功能的几个例子:
import numpy as np
# 计算向量的外积:
v = np.array([1,2,3]) # v的形状是(3,)
w = np.array([4,5]) # w的形状是(2,)
# 为了计算外积,我们需要先将v改成列向量,其形状变成(3, 1)
# 我们接下来可以使用广播功能,其结果是v和w的外积,形状为(3, 2)
# [[ 4 5]
# [ 8 10]
# [12 15]]
print(np.reshape(v, (3, 1)) * w)
# 为矩阵的每一行加上一个向量:
x = np.array([[1,2,3], [4,5,6]])
# x的形状是(2, 3),v的形状是(3,) ,因此它们广播后结果的形状是(2, 3)
# 结果是
# [[2 4 6]
# [5 7 9]]
print(x + v)
# 为矩阵的每一列加上一个向量:
# x的形状是(2, 3),w的形状是(2,)
# 如果我们将x的形状转换为(3, 2),那么它可以和w生成形状为(3, 2)的结果。将这个结果转换为(2, 3),就是矩阵x在每一列加上w的结果。
# [[ 5 6 7]
# [ 9 10 11]]
print((x.T + w).T)
# 另一种解法是将w重构,形成形状为(2, 1)的列向量。
# 我们可以直接对二者进行广播操作,生成同样的结果。
print(x + np.reshape(w, (2, 1)))
# 矩阵乘一个常数:
# x的形状是(2, 3)。Numpy将标量视为形状为()的数组;
# 它们通过广播操作,得到(2, 3)的数组:
# [[ 2 4 6]
# [ 8 10 12]]
print(x * 2)
广播可以显著地提高你的代码的精简性和运算速度,你应该尽可能多地使用它。
Numpy 说明文档
这个简单的概述触及到很多关于Numpy你需要了解的知识,但这些远远不是全部。查阅Numpy的参考文档可以得到更多关于Numpy的知识。
译者注:CS231N关于Numpy的部分到此结束了,正如作者所说,我们看到的也只是Numpy强大功能的冰山一角,想要了解更多的操作可以参考以上链接。由于是第一次进行翻译,难免会有很多的疏漏,希望看到这篇文章的大家能够在评论区多给予指导。
谢谢~( •̀ ω •́ )✧