TensorFlow Practice Series -- Linear Regression (Batch and Stochastic Gradient Descent)

This post walks through fitting a simple linear model in TensorFlow and compares batch gradient descent with stochastic gradient descent. Putting the two side by side, the stochastic training process oscillates noticeably more.
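Concretely, the script fits y_hat = w*x + b to inputs x_i ~ N(1, 0.1) whose targets are all 10, by minimizing the mean squared error; the only difference between the two training modes is how many samples each gradient step sees:

L(w, b) = \frac{1}{N} \sum_{i=1}^{N} \left(w x_i + b - y_i\right)^2

batch step:      (w, b) \leftarrow (w, b) - \eta \, \nabla_{w,b} \, L(w, b)                  (gradient over all N = 100 samples)
stochastic step: (w, b) \leftarrow (w, b) - \eta \, \nabla_{w,b} \left(w x_k + b - y_k\right)^2   (k drawn at random each step)

with learning rate \eta = 0.02 in the code below; the single-sample gradient is much noisier, which is what shows up in the loss curves.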


import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

sess = tf.Session()
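# TensorFlow 1.x style: the graph is built first, then executed through this session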
#training data: 100 samples of x ~ N(1, 0.1); every target y is 10
x_vals = np.random.normal(1,0.1,100).reshape([100,1])
y_vals = np.repeat(10.0,100).reshape([100,1])

#model placeholders and variables
x_data = tf.placeholder(dtype=tf.float32,shape=[None,1])
y_target = tf.placeholder(dtype=tf.float32,shape=[None,1])
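# shape=[None,1] lets the same placeholders take the full 100-sample batch or a single sample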
w = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

#cost function: mean squared error of the linear model y_hat = x*w + b
y_hat = tf.add(tf.matmul(x_data,w),b)

loss = tf.reduce_mean(tf.square(y_hat - y_target))

#optimizer: plain gradient descent, learning rate 0.02;
#each sess.run(train_step, ...) applies one update using whatever batch is fed in
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

#init variables
init = tf.global_variables_initializer()
sess.run(init)

#batch gradient descent: every step uses all 100 samples
batch_cache = []
for i in range(100):
    sess.run(train_step,feed_dict={x_data:x_vals,y_target:y_vals})
    if (i+1) % 5 == 0:
        print('w: '+str(sess.run(w)))
        print('b: '+str(sess.run(b)))
        tmploss = sess.run(loss, feed_dict={x_data: x_vals, y_target: y_vals})
        print('loss: ' + str(tmploss))
        batch_cache.append(tmploss)
plt.plot(range(0, 100, 5), batch_cache, 'b-', label='Batch Loss')

#stochastic gradient descent: every step uses a single randomly chosen sample
#(note: w and b still hold the values learned by the batch run above;
# add sess.run(init) here if you want this run to start from scratch)
stochastic_cache = []
for i in range(100):
    rand_index = np.random.choice(100)
    rand_x = [x_vals[rand_index]]
    rand_y = [y_vals[rand_index]]
    sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
    if i % 5 == 0:
        print('w: ' + str(sess.run(w)))
        print('b: ' + str(sess.run(b)))
        tmploss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
        print('loss: ' + str(tmploss))
        stochastic_cache.append(tmploss)
print("===over===")
plt.plot(range(0,100,5),stochastic_cache,'r--',label='Stochastic Loss')
plt.legend(loc='upper right')
plt.show()
print(1.05*sess.run(w) + sess.run(b))   # prediction for x = 1.05; should end up close to the target 10

[Figure: loss curve under batch gradient descent]

[Figure: loss curve under stochastic gradient descent]

As the plots show, stochastic gradient descent oscillates much more than batch gradient descent: each update (and each logged loss value) is computed from a single randomly drawn sample, so it varies far more from step to step than the full-batch average.
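A common middle ground is mini-batch gradient descent: each step feeds a small random subset of samples, which smooths out much of the oscillation while still being cheaper per step than the full batch. Because the placeholders above were declared with shape [None, 1], the same graph accepts this directly; a minimal sketch (batch size 20 chosen arbitrarily, reusing the session, placeholders and train_step defined above):

#mini-batch gradient descent: each step uses a random subset of 20 samples
minibatch_cache = []
batch_size = 20
for i in range(100):
    rand_index = np.random.choice(100, size=batch_size)
    rand_x = x_vals[rand_index]   # shape (20, 1)
    rand_y = y_vals[rand_index]
    sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
    if i % 5 == 0:
        minibatch_cache.append(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y}))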

