原帖:向日葵智能
经过第四节和第五节的总结,对 tensorflow 的认识越来越深了。现在觉得它有点像一门特殊的编程语言,如果想使用它,就得先了解它的语法(规则)。虽然说第二节被称为 tensorflow 界的 “hello world”,我还是希望能够利用 tensorflow 做些
自认为简单
的事情。所以,本节先通过 python 的 numpy 模块生成一个线性函数,并且用 tensorflow 逼近它。
生成数据
先生成 100 个随机数,随机数类型为 float32,然后,做一个斜率为 0.5,偏置为 0.8 的直线,python 代码如下
可以参照python基础,安装并使用matplotlib库画图小节,把该直线画出来,有个直观感受。python 代码如下:
执行,直线图如下:
描述 tensorflow 模型
模型依然是简单的
y = w*x + b
权值矩阵 w 默认给了 -1.0 到 1.0 范围内的随机数,偏置 b 则默认全是 0。损失函数定义为计算值 y
和实际值 ydata
的平方和的平均值,训练 train 则采用学习因子为 0.3 的梯度下降法,最后初始化所有变量。以下是 python 代码:
建立 session 会话,训练模型
之前几节提到,tensorflow 是以计算图为元素的,本节的代码都是建立在默认图 tf
上的,loss
train
init
都是图上的节点(op),计算图都是在 session
里进行的,所以咱们要先建立 session,然后就可以 run(init) 了。
代码共训练 301 次,每 30 次输出一次结果。这里需要说明的是,w
和 b
是图 tf
的节点,想要观察其表示的值,需要通过 session 的 run 方法。全部 python 代码如下:
运行之,得到结果如下:
可以看出,w 和 b 分别非常接近 0.5
和 0.8
,成功了。