目标读者:假设读者是已经熟悉python,并且已经看了一些tensorflow示例程序,希望能了解tensorflow的内在编码规则、特点和高效的编码方式。如果是这样,本文会适合你。
本文切入点是介绍tensorflow与python/numpy的不同以及语法习惯/编程思路的转换,然后介绍TF语言的重要特征和推荐的编程习惯及方法;
如果读者还不熟悉numpy或python,请先学习相应教程。如果你还未看过tensorflow代码实例,可先学习tensorflow最新官方教程 https://www.tensorflow.org/tutorials, 或这个历史版本的中文教程:http://www.tensorfly.cn/tfdoc/get_started/introduction.html;看两三个示例,看懂就够了。
本文总结&修改自:https://github.com/vahidk/EffectiveTensorflow,大量素材也来源于此。
喜欢阅读英文资料的同学也可以直接阅读原文,目录相同,但部分实例和讲解角度不同。
============
目录
1. numpy 与 tensorflow的最大不同,符号型语言
2. 区分static shape和dynamic shape
============
1. numpy 与 tensorflow的最大不同,符号型语言
numpy是解释性语言,tensorflow是符号性语言。如果你是从c++(编译性语言) -> java -> python 的路径学习的编程语言,那么应该还没有接触过符号语言。类似的符号语言,还有 Hadoop(Map Reduce),Spark (scala, pyspark),如果了解它们则可以在学习过程中类比它们与tensorflow的异同。
符号语言(symbolic)最大的特点就是暗含了一个框架规则,不需要把程序的每一步都清清楚楚的写出来,但你也必须按框架要求的接口编程。Tensorflow中只需要把数据流图定义好,然后指定数据,系统执行时会使用大量框架约束好的,自动执行的行为,这部分不需要显示定义。在tensorflow中,和numpy最显著地区别就是你不需要再显示去计算梯度了,只需要调用.minimize(),tensorflow就会自己去算loss所依赖的所有w的梯度,并更新这些w。这会大大减少编程成本,当然同时也会增加学习成本。(类似于Map Reduce,你只需要实现mapper和reducer就好,剩下的就配置超参指定数据源,剩下的系统自己搞定)
因此,同样的操作,可以看看下图numpy和tensor编程时的异同:
通过代码可以看到:
- tensorflow中直接print(z) 会报错,所有的节点在sess.run()中执行前都是虚拟的,不包含值。print时会得到:
-
Tensor("MatMul:0", shape=(10, 10), dtype=float32)
- 其原因是tensorflow中,代码分为graph定义 和 session执行两部分,且必须先定义graph再定义session。前者定义数据流图,但不会立刻执行,后者才会触发执行。它还有以下特点:
- Graph部分代码:是特殊的代码块,一般都是在GPU中并行执行的,这里应该定义且仅定义矩阵计算代码。部分python代码将无法被正常执行(但可能不报错),比如print();另外一定要杜绝使用for, while 语句,并尽量少使用 if 语句,已提升运行性能。
- Session部分代码:位于普通python代码函数中,一般都是在CPU中串行执行的,可以认为是普通的python代码。这个函数可以自由调用各种python函数,比如print(),for, np.xxx等。
- 交互规则:在session执行之前,graph中节点是没有值的。session可以循环执行图中节点,后者除了varaible类型节点不会被重复初始化外,图中的所有节点值均会存储一份实例在内存,并在上次更新结果的基础上继续更新。如果是多GPU并行,它们还会在多卡之间通过参数服务器同步。
- 归属区分:除了像tf.session()这种流程接口函数,所有你能遇到矩阵计算相关的 tf.xxx 函数(区别于np.xxx函数)都是graph函数,它们都只能在graph定义时出现,并在sess.run()中通过运行图节点被真正执行,不能在普通函数中正常执行(如上面直接print(z))。
- 兼容性:Tensorflow支持你把graph和session代码混着写,即你可以把它们写进同一个函数。这方便了我们在交互式环境中调试代码,但不是好的编程习惯。正式代码一定要把graph和session代码分开写。另外在Graph中是可以调用 np.xxx 矩阵计算函数的,np.xxx得到的数据不具备Graph节点的特点,即其值不会被记忆,也不能被自动更新,一般只用于常量的声明,但是并不是好习惯,提倡用tf.xxx创建常量。
- 注意:tf实例代码中,这里是tf.random_norm()生成tensor,所以可以直接执行。
- 如果是tf.get_variable("somename", shape=[10,10]) 生成的是variable,则需要先sess.run(tf.global_vairables_initializer()) 后才能计算后续节点。
2. 区分static shape和dynamic shape
tensorflow的调试过程中经常要打印tensor的shape出来,来确定我们的各项操作正确性(因为一般tensor都太大,你不太可能打印其中的每个元素,会刷屏)。tensorflow中定义了两种shape类型:
- 静态shape: 其类型是 <class 'tenso