TensorFlow 线性回归

本文通过Python和TensorFlow实现了一个简单的线性回归模型,介绍了如何生成模拟数据集、定义模型参数、设置成本函数及优化器,并展示了训练过程及最终结果。
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

'''
确定2种变量之间的关系
'''
'''
获取线性回归训练参数
'''


def getTrainForLineNormal(size=100, W_PARAM=0.1, W_B=0.03):
    arr_x = np.random.random(100)
    noise = np.random.normal(0, 0.01, arr_x.shape)
    arr_y = arr_x * W_PARAM + W_B + noise
    return arr_x, arr_y

'''
显示坐标
'''


def showPoint(x, y):
    plt.scatter(x, y)
    plt.show();


# y = w*x +b
def train():
    train_x, train_y = getTrainForLineNormal(1000)
    # 均匀分布
    W = tf.Variable(tf.random_uniform([1]))
    b = tf.Variable(tf.zeros([1]))
    # 将训练的X代入函数
    y = W * train_x + b

    # 用预测的Y值和真实的Y值进行平方求平均数
    cost = tf.reduce_mean(tf.square(y - train_y))
    # 梯度下降 0.08 下降的幅度 尽可能小一点
    optimizer = tf.train.GradientDescentOptimizer(0.08)

    train = optimizer.minimize(cost)

    with tf.Session() as sess:
        init = tf.global_variables_initializer();
        sess.run(init)
        print('cost:', sess.run(cost), 'W=', sess.run(W), 'b=', sess.run(b))
        for k in range(1000):
            sess.run(train)
            print('cost:', sess.run(cost), 'W=', sess.run(W), 'b=', sess.run(b))
        print("执行完成")

        plt.plot(train_x, train_y)

        plt.plot(train_x, sess.run(y))

        plt.legend()

        plt.show()


def tfLearn():
    e = tf.square(2);
    f = tf.reduce_mean([1, 4])
    g = tf.random_uniform([2, 10])
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print("e是2的平方?", sess.run(e))
        print("f是求平均值?", sess.run(f))
        print("g是均匀分布的随机数?", sess.run(g))


if __name__ == "__main__":
    method_name = input("Enter your input:")
    if method_name == 'tfLearn':
        tfLearn()
    elif method_name == 'train':
        train();

转载于:https://my.oschina.net/payzheng/blog/1633083

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值