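Principle: see Andrew Ng's machine learning course notes on linear regression with gradient descent (原理见ng教程).
Code: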
"""
Created on Tue Sep 01 17:08:36 2015
@author: young
"""
import os
import numpy as np
import matplotlib.pyplot as plt
def getData(filename):
    """Read a tab-separated training file into a list of string rows.

    Each non-blank line must hold at least three tab-separated fields:
    ``x0``, ``living_area`` and ``price``. ``x0`` is presumably the constant
    1 used for the intercept term; confirm this against the data file. Only
    the first three fields are kept. They are returned as the raw strings
    read from the file, so converting them to numbers is left to the caller.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list of ``[x0, living_area, price]`` lists of strings, one per
        non-blank line, in file order.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
    """
    data = []
    with open(filename) as f:
        alldata = f.readlines()
    # Echo the raw lines that were read (debug output kept as-is).
    print(alldata)
    for line in alldata:
        # rstrip('\r\n') also removes a CR left by Windows line endings.
        # strip('\n') alone would leave that CR attached to the price field.
        line = line.rstrip('\r\n')
        # Skip blank lines, e.g. an extra empty line at the end of the file.
        # Otherwise the missing fields would raise an IndexError.
        if not line.strip():
            continue
        # Split once and reuse the pieces instead of re-splitting per column.
        fields = line.split('\t')
        data.append([fields[0], fields[1], fields[2]])
    return data
def _plot_fit(X, Y, theta, cost):
    """Draw the cost curve (figure 1) and the samples with the fitted line (figure 2)."""
    plt.figure(1)
    plt.plot(range(len(cost)), cost)
    plt.figure(2)
    plt.plot(X[:, 1], Y, 'b.')
    # Span the observed feature range so the line actually covers the data.
    # The old fixed range [0, 2) only did so if the data was scaled into it.
    if len(X):
        x = np.linspace(X[:, 1].min(), X[:, 1].max(), 20)
    else:
        x = np.arange(0, 2, 0.1)
    # This treats theta[0] as the intercept.
    # NOTE(review): that assumes column 0 of the data is the constant 1 -- confirm.
    y = x * theta[1] + theta[0]
    plt.plot(x, y)
    plt.margins(0.2)
    plt.show()


def genData(data, alpha=0.000005, iteration_num=10000, plot=True):
    """Fit ``y ~ theta[0] * x0 + theta[1] * x1`` by batch gradient descent.

    Despite its name, this function does not generate data. It does three
    things:

    1. Builds the design matrix from ``data``.
    2. Runs ``iteration_num`` steps of batch gradient descent, starting from
       ``theta = [1, 1]``.
    3. Optionally draws the cost per iteration and the samples with the
       fitted line.

    Args:
        data: Sequence of rows. ``row[0]`` and ``row[1]`` are the two feature
            columns and ``row[2]`` is the target. Values may be numbers or
            numeric strings, such as those read from a text file.
        alpha: Learning rate (step size) applied to the gradient.
        iteration_num: Number of gradient-descent steps.
        plot: If true (the default), draw the figures and call
            ``plt.show()``. Pass False to compute the fit without plotting.

    Returns:
        A tuple ``(theta, cost)``:

        - ``theta``: the fitted length-2 parameter array.
        - ``cost``: a list holding, for each step, the sum of squared
          residuals before that step's update.

    Raises:
        IndexError: If a row has fewer than three values.
        ValueError: If a value cannot be converted to float.
    """
    # Convert each row to floats once. Each row contributes exactly three
    # values, so reshape(-1, 3) is safe. It also yields shape (0, 3) for
    # empty input, which keeps the column slices below valid.
    table = np.array(
        [[float(row[0]), float(row[1]), float(row[2])] for row in data],
        dtype=float,
    ).reshape(-1, 3)
    X = table[:, :2]
    Y = table[:, 2]

    theta = np.ones(2)
    xTrans = X.transpose()  # hoisted: X does not change inside the loop
    cost = []
    for _ in range(iteration_num):
        loss = np.dot(X, theta) - Y
        # Record the sum of squared residuals. It is not divided by 2m, so it
        # is twice the objective whose gradient is applied below.
        cost.append(np.sum(loss ** 2))
        # X^T (X theta - Y) is the gradient of 0.5 * sum(loss ** 2).
        theta = theta - alpha * np.dot(xTrans, loss)

    if plot:
        _plot_fit(X, Y, theta, cost)
    return theta, cost
if __name__ == '__main__':
    # Script entry point: load the tab-separated training file 'ex1.txt'
    # (resolved relative to the current working directory), then fit and
    # plot it.
    data = getData('ex1.txt')
    genData(data)