1.线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。
直观地说,在二维情况下,已知一些点的X,Y坐标,统计条件X与结果Y的关系,画一条直线,让直线离所有点都尽量地近(距离之和最小),用直线抽象地表达这些点,然后对新的X预测新的Y。具体实现一般使用最小二乘法。
2.什么是最小二乘法?我参考了博主 xieyan0811 关于最小二乘法的文章,附上链接:http://blog.youkuaiyun.com/xieyan0811/article/details/78562610
最主要的是看懂线性代数矩阵求解方程的公式,图片来自上面博主
#-*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
# Figure and single axes used by main() below for the scatter plot and
# the fitted regression line.
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
# --- synthetic training data -------------------------------------------
# Random 2x12 ndarray of inputs drawn uniformly from [0, 10).
x = np.random.random((2, 12)) * 10
# Additive noise in [0, 5); a single vectorized draw replaces the original
# element-by-element Python loop over range(x.size).
r = np.random.rand(2, 12) * 5
# Intended linear model: y = 2*x + 2, perturbed by the noise above.
# We then assume y = a*x + b and solve W x = Y by least squares, where W
# holds the parameters (a, b) and Y the targets — 24 training pairs total.
y1 = 2 * x + 2 + r
def train(x, y):
    """Fit y = a*x + b to the data by ordinary least squares.

    x, y: array-likes of identical 2-D shape (m, n); every element of x is
    one training input and the matching element of y its target.

    Returns a 2x1 np.matrix w = [[b], [a]] (intercept first), obtained from
    the normal equations w = (X^T X)^-1 X^T Y.
    """
    m, n = np.shape(x)
    k = m * n  # total number of training samples
    # Design matrix, one column per sample: row 0 is the all-ones bias row,
    # row 1 the raw inputs.  The original hard-coded both the row count (m)
    # and the sample count (24); this works for any m x n input.
    design = np.mat(np.ones((2, k)))
    design[1, :] = np.reshape(x, (1, k))
    design = design.T                       # k x 2
    targets = np.mat(np.reshape(y, (k, 1))) # k x 1
    xTx = design.T * design
    # NOTE(review): .I raises LinAlgError when xTx is singular (e.g. all
    # inputs identical); np.linalg.lstsq would be the more robust choice.
    return xTx.I * (design.T * targets)
#预测输入数据
def predicts(x, w):
    """Predict y-hat = w[1]*x + w[0] for every element of x.

    x: any array-like of inputs (the original required exactly 2 rows
    because it sized the bias matrix as (m, m*n); this version accepts any
    shape).
    w: 2x1 parameter matrix [[b], [a]] as returned by train().

    Returns a (k x 1) np.matrix of predictions, k = number of elements in x.
    """
    k = np.size(x)
    # Columns: row 0 = bias ones, row 1 = flattened inputs.
    design = np.mat(np.ones((2, k)))
    design[1, :] = np.mat(np.reshape(x, (1, k)))
    return design.T * w
#预测新的数据,此时输入的是一个单个的数据,比如我输入的是[[5]],即1x1维的一个数据
def predicts2(x, w):
    """Predict for new data, e.g. a single sample passed as [[5]].

    The original `m+1` sizing trick only worked when x was 1x1 (anything
    larger built a k x 3 design matrix and crashed multiplying by the 2x1
    w).  This version handles any input shape while returning exactly the
    same 1x1 result for [[5]].

    Returns a (k x 1) np.matrix of predictions, k = number of elements in x.
    """
    k = np.size(x)
    design = np.mat(np.ones((2, k)))  # row 0: bias, row 1: inputs
    design[1, :] = np.mat(np.reshape(x, (1, k)))
    return design.T * w
def main():
global x,y1
w = train(x,y1)
test = predicts(x,w)
result = predicts2([[5]],w)
print 'w',w
ax.scatter(x.tolist(),y1)
ax.scatter(5,result.tolist(),c='red')
ax.plot(x.reshape((1,24)).tolist()[0][:],test.reshape((1,24)).tolist()[0][:] ,color='red')
plt.show()
print 'x_resource*****',x.tolist()
print 'y1_resource*****',y1.tolist()
print 'y1_testResult*****',test.tolist()
print 'test [[5]]*****',result.tolist()
print '误差*****:',np.sqrt((np.array(y1)-np.array(test.reshape((2,12))))**2) #计算误差的损失函数,我用的和平方差
if __name__ == '__main__':
main()
输出(示例运行结果;训练数据是随机生成的,每次运行的数值会不同):
w [[ 5.19725373]
[ 1.83472901]]
x_resource***** [[5.025534604406413, 7.171500015259127, 6.259056500365995, 5.458404244214018, 8.147462884592366, 5.905533426325871, 0.20805685707873778, 2.216424081055975, 6.652866876373937, 4.504519471871173, 1.774855334422124, 7.894633227253828], [8.43255686345111, 8.235615654342839, 9.97308681226296, 7.778263642814868, 4.500385331396345, 4.3763375000402345, 4.297061388188819, 0.31458505158080574, 8.582915571044293, 3.900393754870249, 7.985011146607281, 1.9651864681500597]]
y1_resource***** [[13.390906690094853, 20.10176698958281, 18.12663547495091, 15.49984206421861, 19.412400375679752, 14.57024827084178, 4.60654342729701, 10.52094006927134, 16.565320692568108, 15.506522542731101, 6.729930387113277, 20.710030381156702], [19.30649860262777, 22.047868639854645, 22.989980520577078, 18.970255879765386, 13.225206048172211, 15.148022170299393, 12.842386998364857, 7.032259262934227, 20.666747891259547, 14.090279585069181, 18.07200055650239, 5.97889695138205]]
y1_testResult***** [[14.417747850404082], [18.355012838637553], [16.680926254221497], [15.211946334994982], [20.145640225178113], [16.032307215484167], [5.57898168473221], [9.263791288135433], [17.40346157451029], [13.461826273431557], [8.453632299911684], [19.681766318920445], [20.66871041803396], [20.307376668924626], [23.495165401934116], [19.46825966676768], [13.454241245981596], [13.226647091483724], [13.081196909473647], [5.774432053294327], [20.9445779003749], [12.353419296315355], [19.84758530918942], [8.802838351979416]]
test [[5]]***** [[14.37089877100638]]
误差*****: [[ 1.02684116 1.74675415 1.44570922 0.28789573 0.73323985 1.46205894
0.97243826 1.25714878 0.83814088 2.04469627 1.72370191 1.02826406]
[ 1.36221182 1.74049197 0.50518488 0.49800379 0.2290352 1.92137508
0.23880991 1.25782721 0.27783001 1.73686029 1.77558475 2.8239414 ]]
[Finished in 3.3s]