X = np.arange(12).reshape((3, 4))
row = np.array([0, 1, 2])
mask = np.array([1, 0, 1, 0], dtype=bool)
X[row[:, np.newaxis], mask]
解释说明:
row[:, np.newaxis]*mask
步骤:
row[:, np.newaxis] shape: (3,1)
mask shape (4)
根据广播的第一条规则,mask在最最边扩充维度1 shap变成(1,4)
根据广播的第二条规则,当维度为1时扩充至另一个操作数与之对应的维数,则row,mask都扩充维度至(3,4)
row为:
array([[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]])
mask为:
array([[1, 0, 1, 0],
[1, 0, 0, 1],
[1, 0, 1, 0]])
此时,将两个操作数对应位置的数值相乘,得:
array([[0, 0, 0, 0],
[1, 0, 1, 0],
[2, 0, 2, 0]])
综上所述:
X[row[:, np.newaxis], mask]的结果为:
array([[ 0, 2],
[ 4, 6],
[ 8, 10]])