#coding=utf-8
import tensorflow as tf
import numpy as np
import pandas as pd
import cv2 as cv
import os
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
def draw_line(x_data, y_data, k, c):
#这是离散的点图,同时含有高斯噪声
plt.plot(x_data, y_data, "ro", label="points")
#X和Y的范围在0-15,因为之前构造的X范围是从0-14
plt.axis([0, 15, 0, 15])
#准备一个列表
curr_y = []
#从我们通过训练模型得到的斜率和截距画出一个线性回归的曲线
for i in range(len(x_data)):
curr_y.append(x_data[i]*k+c)
plt.plot(x_data, curr_y, label="fit-line")
#设置显示label
plt.legend()
#显示
plt.show()
def line_regression():
# 定义线性方程y=wx+b的斜率和截距的变量,后面我们会通过训练得到这2个值
w = tf.Variable(0.1, dtype=tf.float32)
b = tf.Variable(0.1, dtype=tf.float32)
# 产生30个随机数,数值从0-14,类型是浮点数
X = tf.random_uniform([30], minval=0, maxval=14, dtype=tf.float32)
"""
1.
这一步会生成一个含有30对坐标(x, y), 同时所有的坐标都满足线性方程
Y = 1.2X + 0.8
2.
我们的目的就是要将w逼近1.2,b逼近0.8
3.
最后line_model是一个包含30个元素的一维矩阵
"""
line_model = tf.add(tf.multiply(X, 1.2), 0.8)
# 生成30个满足高斯分布的随机噪声
noise = tf.random_normal([30], mean=0, stddev=0.5, dtype=tf.float32)
# 将30个随机噪声叠加在line_model(Y分量)上
y_labels = line_model + noise
"""
1.
以上所有操作的目的就是为了构造用于测试线性回归的输入数据参数
2.
现在我们得到了30组能够拟合到线性方程的输入数据参数
"""
# 通过输入参数和训练模型(训练w和b)得到输出
y_ = tf.add(tf.multiply(w, X), b)
# 计算平方
diff = tf.square(y_ - y_labels)
# 计算均值,目标是均值最小
loss = tf.reduce_mean(diff)
# 设置学习力
optimizer = tf.train.GradientDescentOptimizer(0.01)
# 设置优化器的优化方向
step = optimizer.minimize(loss)
# 初始化tensorflow的变量
init = tf.global_variables_initializer()
# 为tensorflow的会话设置别名
with tf.Session() as sess:
# run变量
sess.run(init)
# 训练次数设置
for i in range(5000):
sess.run(step)
# 每隔多少次打印一次均值,这个值应当逐步收敛
if (i + 1) % 100 == 0:
curr_loss = sess.run(loss)
print("current loss : ", curr_loss)
# 训练结束后,得到训练后的变量w和b
k, c, y = sess.run([w, b, y_])
# 得到训练的输入参数,用于作图
x_input, y_input = sess.run([X, y_labels])
# 打印拟合后线性方程的斜率和截距
print(k, c)
draw_line(x_input, y_input, k, c)
line_regression()
结果如下:
...
('current loss : ', 0.46891606)
('current loss : ', 0.27399457)
('current loss : ', 0.23788354)
('current loss : ', 0.28854463)
('current loss : ', 0.27047512)
('current loss : ', 0.19254531)
('current loss : ', 0.17524074)
('current loss : ', 0.3638595)
('current loss : ', 0.12371392)
('current loss : ', 0.42371178)
('current loss : ', 0.22886035)
('current loss : ', 0.26392117)
('current loss : ', 0.29713854)
('current loss : ', 0.29254925)
('current loss : ', 0.32334834)
('current loss : ', 0.17560533)
('current loss : ', 0.20040913)
('current loss : ', 0.40767363)
('current loss : ', 0.29662272)
(1.220254, 0.8045528)
图片展示:
很明显,k逼近到了1.2,b已经逼近到了0.8