最小二乘法
import numpy as np
import scipy as sp
from scipy. optimize import leastsq
import matplotlib. pyplot as plt
% matplotlib inline
def real_func(x):
    """Evaluate the noise-free target function sin(2*pi*x).

    Args:
        x: Scalar or NumPy array of sample positions.

    Returns:
        ``np.sin(2 * np.pi * x)``, computed element-wise for array input.
    """
    return np.sin(2 * np.pi * x)
def fit_func(p, x):
    """Evaluate the polynomial with coefficients ``p`` at ``x``.

    Args:
        p: Polynomial coefficients ordered from the highest power down to
            the constant term, as ``np.poly1d`` expects them.
        x: Point or array of points at which to evaluate the polynomial.

    Returns:
        The polynomial value(s) at ``x``.
    """
    polynomial = np.poly1d(p)
    return polynomial(x)
def residuals_func(p, x, y):
    """Return the residuals between the fitted polynomial and the targets.

    Args:
        p: Polynomial coefficients, passed on to ``fit_func``.
        x: Sample positions.
        y: Observed target values at ``x``.

    Returns:
        ``fit_func(p, x) - y``, one residual per sample.
    """
    return fit_func(p, x) - y
# Ten evenly spaced training inputs on [0, 1].
x = np.linspace(0, 1, 10)
# Dense grid over the same interval, used when drawing the curves.
x_points = np.linspace(0, 1, 1000)
# Noise-free targets at the training inputs.
y_ = real_func(x)
# Observed targets: add Gaussian noise (mean 0, std 0.1) in one vectorized draw.
y = y_ + np.random.normal(0, 0.1, size=y_.shape)
def fitting(M=0):
    """Fit a degree-``M`` polynomial to the noisy samples and plot the result.

    Reads the module-level training data ``x`` and ``y`` and the dense grid
    ``x_points``. Prints the fitted coefficients and draws the true function,
    the fitted curve and the noisy samples with ``plt.plot``.

    Args:
        M: Degree of the polynomial; ``M + 1`` coefficients are fitted.

    Returns:
        The value returned by ``scipy.optimize.leastsq``; its first element
        holds the fitted coefficients.
    """
    # Random starting point for the optimizer, one value per coefficient.
    p_init = np.random.rand(M + 1)
    # Least-squares fit of the coefficients to the training data.
    p_lsq = leastsq(residuals_func, p_init, args=(x, y))
    print('Fitting Parameters:', p_lsq[0])

    # Visualize the true function, the fitted polynomial and the samples.
    plt.plot(x_points, real_func(x_points), label='real')
    plt.plot(x_points, fit_func(p_lsq[0], x_points), label='fitted curve')
    plt.plot(x, y, 'bo', label='noise')
    plt.legend()
    return p_lsq
# Fit polynomials of degree 0, 1, 3 and 9 in that order; every call prints
# its coefficients and draws its curves.
p_lsq_0, p_lsq_1, p_lsq_3, p_lsq_9 = (
    fitting(M=degree) for degree in (0, 1, 3, 9)
)
# Weight of the L2 penalty applied to the polynomial coefficients.
regularization = 0.0001


def residuals_func_regularization(p, x, y):
    """Return the fit residuals extended with L2 penalty entries.

    One extra entry ``sqrt(0.5 * regularization * p_i**2)`` is appended per
    coefficient. A solver that minimizes the sum of squared entries (such as
    ``scipy.optimize.leastsq``) therefore adds
    ``0.5 * regularization * sum(p**2)`` to its objective.

    Args:
        p: Polynomial coefficients, passed on to ``fit_func``.
        x: Sample positions.
        y: Observed target values at ``x``.

    Returns:
        1-D array holding the data residuals followed by the penalty entries.
    """
    data_residuals = fit_func(p, x) - y
    penalty = np.sqrt(0.5 * regularization * np.square(p))
    return np.append(data_residuals, penalty)
# Degree-9 polynomial: ten coefficients with a random starting point.
p_init = np.random.rand(10)
p_lsq_regularization = leastsq(residuals_func_regularization, p_init, args=(x, y))

# Draw the true function, the plain degree-9 fit, the regularized fit and
# the noisy samples together.
plt.plot(x_points, real_func(x_points), label='real')
plt.plot(x_points, fit_func(p_lsq_9[0], x_points), label='fitted curve')
plt.plot(x_points, fit_func(p_lsq_regularization[0], x_points), label='regularization')
plt.plot(x, y, 'bo', label='noise')
plt.legend()
预备知识
# sin(2*pi*k) for integer k; the result is 0 up to floating-point rounding.
np.sin(2 * np.pi * 3)

# poly1d takes its coefficients from the highest power down to the constant.
p = np.poly1d([1, 2, 3])
print(p)
# Output:
#    2
# 1 x + 2 x + 3

# Evenly spaced values over a closed interval.
np.linspace(2.0, 3.0, num=5)
# Output:
# array([2.  , 2.25, 2.5 , 2.75, 3.  ])

# Random samples with the given shape; the values change from run to run.
np.random.rand(3, 2)
# Example output:
# array([[0.14022471, 0.96360618],
#        [0.37601032, 0.25528411],
#        [0.49313049, 0.94909878]])

# Gaussian samples with mean 3 and standard deviation 2.5, shape (2, 4).
np.random.normal(3, 2.5, size=(2, 4))
# Example output:
# array([[-4.49401501,  4.00950034, -1.81814867,  7.29718677],
#        [ 0.39924804,  4.68456316,  4.99394529,  4.84057254]])

np.random.rand(3, 2)
# Example output:
# array([[0.14022471, 0.96360618],
#        [0.37601032, 0.25528411],
#        [0.49313049, 0.94909878]])
参考来源
https://github.com/fengdu78/lihang-code/blob/master/%E7%AC%AC01%E7%AB%A0%20%E7%BB%9F%E8%AE%A1%E5%AD%A6%E4%B9%A0%E6%96%B9%E6%B3%95%E6%A6%82%E8%AE%BA/1.Introduction_to_statistical_learning_methods.ipynb .