TensorFlow入门:线性回归

#coding=utf-8
import tensorflow as tf
import numpy as np
import pandas as pd
import cv2 as cv
import os
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt

def draw_line(x_data, y_data, k, c):
	#这是离散的点图,同时含有高斯噪声
    plt.plot(x_data, y_data, "ro", label="points")
    #X和Y的范围在0-15,因为之前构造的X范围是从0-14
    plt.axis([0, 15, 0, 15])
    #准备一个列表
    curr_y = []
    #从我们通过训练模型得到的斜率和截距画出一个线性回归的曲线
    for i in range(len(x_data)):
        curr_y.append(x_data[i]*k+c)
    plt.plot(x_data, curr_y, label="fit-line")
    #设置显示label
    plt.legend()
    #显示
    plt.show()



def line_regression():
    # 定义线性方程y=wx+b的斜率和截距的变量,后面我们会通过训练得到这2个值
    w = tf.Variable(0.1, dtype=tf.float32)
    b = tf.Variable(0.1, dtype=tf.float32)
    # 产生30个随机数,数值从0-14,类型是浮点数
    X = tf.random_uniform([30], minval=0, maxval=14, dtype=tf.float32)
    """
    1.
    这一步会生成一个含有30对坐标(x, y), 同时所有的坐标都满足线性方程
    Y = 1.2X + 0.8
    2.
    我们的目的就是要将w逼近1.2,b逼近0.8
    3.
    最后line_model是一个包含30个元素的一维矩阵
    """
    line_model = tf.add(tf.multiply(X, 1.2), 0.8)
    # 生成30个满足高斯分布的随机噪声
    noise = tf.random_normal([30], mean=0, stddev=0.5, dtype=tf.float32)
    # 将30个随机噪声叠加在line_model(Y分量)上
    y_labels = line_model + noise
    """
    1.
    以上所有操作的目的就是为了构造用于测试线性回归的输入数据参数
    2.
    现在我们得到了30组能够拟合到线性方程的输入数据参数
    """

    # 通过输入参数和训练模型(训练w和b)得到输出
    y_ = tf.add(tf.multiply(w, X), b)
    # 计算平方
    diff = tf.square(y_ - y_labels)
    # 计算均值,目标是均值最小
    loss = tf.reduce_mean(diff)

    # 设置学习力
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    # 设置优化器的优化方向
    step = optimizer.minimize(loss)
    # 初始化tensorflow的变量
    init = tf.global_variables_initializer()
    # 为tensorflow的会话设置别名
    with tf.Session() as sess:
        # run变量
        sess.run(init)
        # 训练次数设置
        for i in range(5000):
            sess.run(step)
            # 每隔多少次打印一次均值,这个值应当逐步收敛
            if (i + 1) % 100 == 0:
                curr_loss = sess.run(loss)
                print("current loss : ", curr_loss)
        # 训练结束后,得到训练后的变量w和b
        k, c, y = sess.run([w, b, y_])
        # 得到训练的输入参数,用于作图
        x_input, y_input = sess.run([X, y_labels])
        # 打印拟合后线性方程的斜率和截距
        print(k, c)
        draw_line(x_input, y_input, k, c)

line_regression()

结果如下:

...
('current loss : ', 0.46891606)
('current loss : ', 0.27399457)
('current loss : ', 0.23788354)
('current loss : ', 0.28854463)
('current loss : ', 0.27047512)
('current loss : ', 0.19254531)
('current loss : ', 0.17524074)
('current loss : ', 0.3638595)
('current loss : ', 0.12371392)
('current loss : ', 0.42371178)
('current loss : ', 0.22886035)
('current loss : ', 0.26392117)
('current loss : ', 0.29713854)
('current loss : ', 0.29254925)
('current loss : ', 0.32334834)
('current loss : ', 0.17560533)
('current loss : ', 0.20040913)
('current loss : ', 0.40767363)
('current loss : ', 0.29662272)
(1.220254, 0.8045528)

图片展示:
在这里插入图片描述
很明显,k逼近到了1.2,b已经逼近到了0.8

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值