1) 巧用 where函数
where函数是numpy的内置,也是一个非常有用的函数,提供了快速并且灵活的计算功能。
def f_norm_1(data, estimate):
residule = 0
for row_index in range(data.shape[0]):
for column_index in range(data.shape[1]):
if data[row_index][column_index] != 0:
residule += (data[row_index][column_index] - estimate[row_index][column_index]) ** 2
return residule
def f_norm_2(data, estimate)
return sum(where(data != 0, (data-estimate) **2, 0))
这两段代码完成同样的功能,计算两个矩阵的差,然后将残差进行平方,注意,因为我需要的是考虑矩阵稀疏性,所以不能用内置的norm,函数1是我用普通的python写的,不太复杂,对于规模10*10的矩阵,计算200次耗时0.15s,函数2使用了where函数和sum函数,这两个函数都是为向量计算优化过的,不仅简介,而且耗时仅0.03s, 快了有五倍,不仅如此,有同学将NumPy和matlab做过比较,NumPy稍快一些,这已经是很让人兴奋的结果。
本文对比了两种计算矩阵残差平方的方法:一种是传统的 Python 循环方式,另一种是利用 NumPy 的 where 和 sum 函数。通过实验发现,后者不仅代码更简洁,而且运行速度也更快,尤其在处理大规模稀疏矩阵时优势明显。
2761

被折叠的 条评论
为什么被折叠?



