最近在做机器学习的作业,要求是实现IRLS算法用于logistic regression, 并且画出来scatter。
先回顾一下IRLS算法,IRLS是iterative reweighted least squares,和OSL相比起来,多了两个单词iterative 和 reweighted。先说一下iterative
为什么要iterative呢?书上(Pattern Recognition and Machine learning by Bishop)的原话是"For logistic regression, there is no longer a closed-form solution, due to the nonlinearity of the logistic sigmoid function", 也就是说这里不像OLS那样一步到位,而是一种online renew的方式去求w。具体的公式如下(略大。。):
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
R是什么呢?R是一个n by n 的matrix(n is the size of your dataset), what is more,R is a diagonal matrix,对角元素Rii = yi*(1-yi).
注意这里的y是每次更新出来的output,y=W*X+W0。这就是我要强调的第二点,reweighted。之所以是reweighted,是因为对角元素在每次更新的时候都是要变化的。
好了, 废话不多说了,下面开始码了:
首先定义一个irls的方法:
maxiter是最大的循环次数,当然如果前后两次的W差值低于我们的threshold,就会break这个loop
def IRLS(y, X, maxiter, w_init&