机器学习基础_单变量线性回归
首先我们需要了解什么是机器学习,我们可以把机器学习理解为函数的变化,例如下面的是受教育年限和工资之间的关系
Unnamed: 0 Education Income
0 1 10.000000 26.658839
1 2 10.401338 27.306435
2 3 10.842809 22.132410
3 4 11.244147 21.169841
4 5 11.645485 15.192634
5 6 12.086957 26.398951
6 7 12.488294 17.435307
7 8 12.889632 25.507885
8 9 13.290970 36.884595
9 10 13.732441 39.666109
10 11 14.133779 34.396281
11 12 14.535117 41.497994
12 13 14.976589 44.981575
13 14 15.377926 47.039595
14 15 15.779264 48.252578
15 16 16.220736 57.034251
16 17 16.622074 51.490919
17 18 17.023411 61.336621
18 19 17.464883 57.581988
19 20 17.866221 68.553714
20 21 18.267559 64.310925
21 22 18.709030 68.959009
22 23 19.110368 74.614639
23 24 19.511706 71.867195
24 25 19.913043 76.098135
25 26 20.354515 75.775218
26 27 20.755853 72.486055
27 28 21.157191 77.355021
28 29 21.598662 72.118790
29 30 22.000000 80.260571
他们的散点图如下
我们可以清楚看到受教育年限和工资之间是线性关系的,可以表示为一个单变量的函数
F(x) = ax + b
x表示学历,F(x)表示工资
只要我们求出了a和b那么我们就可以通过受教育年限来预测工资,这就是机器学习
如何求解a和b?
我们的目标是求出a和b使得预测目标和真实值之间的整体误差最小。
那么怎么定义最小误差?
损失函数用来评价模型的预测值和真实值不一样的程度,损失函数越好,通常模型的性能越好。不同的模型用的损失函数一般也不一样。
我们通过均方差来定义我们的损失函数
优化目标
找到合适的a和b,使得(F(x)-y)^2/n越小越好
使用梯队下降算法
代码如下使用jupyter运行,Income1.csv内容是上面的受教育年限和工资的数据
import tensorflow as tf
import pandas as pd
data = pd.read_csv('Income1.csv')
#查看data的数据
data
import matplotlib.pyplot as plt
#绘制散点图
plt.scatter(data.Education,data.Income)
x = data.Education
y = data.Income
model = tf.keras.Sequential()
#顺序模型,一个输入一个输出,一层层搭建模型
model = tf.keras.Sequential()
#Dense(a,b) a是输出参数的维度,b是输入参数的维度,(1,)表示的是元组形式
#Dense相当于建立模型f(x)=ax+b
model.add(tf.keras.layers.Dense(1,input_shape=(1,)))
model.summary()
#output shape下的None表示的是样本的个数(不需要考虑),1表示的是样本的维度
#param下的2表示的是两个参数,这是Dense层初始化的2个参数,生成ax+b形式然后得到一个输出值,a表示的是权重,b表示的是偏置
#编译 ——配置的过程
model.compile(optimizer='adam',loss='mse')
#optimizer='adam' 表示选择优化方法为'adam'
#loss='mse' 表示选择损失函数为'mse',均方差
#进行模型训练,这是一个循环过程epochs=5000表示循环的次数
history = model.fit(x,y,epochs=5000)
#进行模型的预测,x为现有的数据
model.predict(x)
#查看x的数据类型
type(x)
#输入的形式为pd.Series,所以需要用pd.Series格式来进行预测
model.predict(pd.Series([19]))