完整代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
#获取一个文本的数据(即本例中的数据集)
path = r'D:\Project\Pycharm Project\py2022\ExData\ex1data1.txt'#数据集的存储地址
data = pd.read_csv(path, header = None, names = ['Population', 'Profit'])#定义获取数据的形式,一列为人口,一列为经济利润
data.head()#head方法用来显示表格前五行的数据
#显示并检查数据
data.plot(kind='scatter', x='Population', y='Profit', figsize=(12,8))
plt.show()
# 接下来是定义计算代价函数
def cost_func(x, y, w, b):
#获得每一点数据的误差(成本)
cost_matrix = np.power(((x * w + b) - y), 2) # 这个就是矩阵相乘得到一个n*1维的矩阵与n*1维的y矩阵进行相减的平方
return np.sum(cost_matrix) / (2 * len(x)) # 利用sum函数将cost矩阵所有元素加起来,再除以2m(m为总共的数据数量)
# 计算迭代之后的w和b参数,w为斜率,b为截距,alpha为学习率
def compute_new_wb(x, y, w, b, alpha):
m = len(x)
dj_dw = 0
dj_db = 0 #用于计算梯度
for i in range(m):
#老参数w和b所计算的所有y值(y0,y1,y2 ……)
f_wb_i = w * x[i] + b
#偏导之后的第i个和项(具体为什么这样可以看吴恩达老师讲解线性回归的视频)
dj_dw_i = (f_wb_i - y[i]) * x[i]
dj_db_i = (f_wb_i - y[i])
##进行累加
dj_dw = dj_dw + dj_dw_i
dj_db = dj_db + dj_db_i
dj_dw = (1 / m) * dj_dw
dj_db = (1 / m) * dj_db
#此时的dj_dw,dj_db就为老参数情况下的偏导
return w - alpha * dj_dw, b - alpha * dj_db #返回新的w和b
#获取X矩阵和Y矩阵,为了方便计算成本函数,进行转置变成了列向量
X = np.matrix(data['Population'].values).T
Y = np.matrix(data['Profit'].values).T
print(compute_new_wb(X, Y, 0, 0, 0.001)) #测试一下w=0,b=0迭代之后的w和b
#初始将w和b设为1,学习率设为0.01
w = 1
b = 1
alpha = 0.01
#经过1500次迭代
for i in range(1500):
w, b = compute_new_wb(X, Y, w, b, alpha)
print('最后计算出的w和b:', w, ',', b)
print('最后的成本函数', cost_func(X, Y, w, b))
# 接下来就是画图
# 在人口的最小值和最大值之间取100个点,返回一个矩阵
x = np.linspace(data.Population.min(), data.Population.max(), 100)
#w[0, 0]为数值,w[0][0]则为列表
y = w[0, 0] * x + b[0, 0] # 获得y矩阵
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(x, y, 'r', label='Prediction')#使用plot方法绘制预测直线
ax.scatter(data.Population, data.Profit, label='Training Data')
ax.legend(loc = 2)
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Predicted Profit vs. Population Size')
plt.show()
#%%预测人口3500 和 7000时的经济收益
prediction1 = w[0, 0] * 3.5 + b[0, 0]
print("人口为3500时,小吃摊经济规模:", prediction1)
prediction2 = w[0, 0] * 7.0 + b[0, 0]
print("人口为7000时,小吃摊经济规模:", prediction2)
代码输出
打印出来的数据集
整个数据集分布图(注意:数据集必须下载到自己的本地,可以搜索其他博主https://blog.youkuaiyun.com/qq_39435411/article/details/109763239)
你也可以一开始改变w和b的值,例如w = 0,b = 0,不过最后预测出来的会有很小的误差