数组过滤在机器学习中的应用举例(python代码)
机器学习过程中,计算误差(矩阵)时经常会遇到要在矩阵中进行标签值与预测值之间的比较,可以使用数据过滤方法来解决:
error = mat(ones((5, 1)))
print("比较之前:\n", error.T)
label = mat(([1, -1, 1, -1, 1])).T
pred = mat(([1, -1, -1, 1, 1])).T
error[pred == label] = 0 # 如果预测值与标签值一致时,误差为0值
print("比较之后:\n",error.T)
结果显示为:
比较之前:
[[1. 1. 1. 1. 1.]]
比较之后:
[[0. 0. 1. 1. 0.]]
标签值(label)中的第2和第3个值与预测值(pred)是不一致的,因此error矩阵中只有这两个值是"1", 其余值均为"0"。