关于Tensorflow
说实话这到底是个啥我还没明白(明白了再删这句话)。
也并不准备了解的多么详细,还是老办法,只看看用过的用到的,其他那些慢慢来,又不是准备考试,也没必要一口吃成大胖子。
不废话了。
因为是在学强化学习的过程中使用到了tensorflow,所以找到的代码是movan几年前提供的,他用的tensorflow版本是1.0大概。虽然现在2020年已经更新到了2.0版本,但是为了方便,又考虑到自己是个菜鸡, 就先看1.0版本吧,2.0里做的那些改动等日后若是需要进阶学习 tf 再了解也不迟。
所以导入的时候使用
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
这样,就可以继续使用1.0中的一些函数。第二句的作用是屏蔽一些紧急操作而出现的错误,比如在调用placeholder的时候,如果没有这一句,依然会因为eager execution而出错。(感谢大佬们的细致研究,佩服佩服,啥都会处理)
附一个大佬链接 https://blog.youkuaiyun.com/qq_39777550/article/details/104224296
一些用过的函数
Session
小助教说,可以把session看成是一张白纸,可以在上面写东西。
在tf中,必须先定义变量,并且必须进行 initialize变量,通过下面给出的方法(直接run(initializer) 或者 定义init函数之后run(init))
tf.assign(a, b) 是将b的值赋给a,然后返回的值是a赋值之后的值。
在tf的变量中,要想进行print或者赋值assign等一系列操作,在定义这些操作的时候是并没有执行的,只有在sess.run()这些操作才是真正的执行。也就是说,所有的操作都要通过sess.run()来执行。
import tensorflow as tf
state = tf.Variable(0) #这时变量state赋予初始值0
one = tf.constant(1) #one赋值常量1
update = tf.assign(state, one)
init = tf.global_variables_initializer() #变量初始化的函数
with tf.Session() as sess:
sess.run(init) #必须先初始化所有变量,并且要以这种方式
# 也可以
# sess.run(tf.global_variables_initializer())
print(sess.run(update)) # 通过这种方式打印结果
还有,小助教给的创建变量的方法是这样的,使用的是tf.get_variable() 函数:
var_init_value = [[2.0, 4.0, 6.0]]
var = tf.get_variavle( name = 'myvar',
shape = [1, 3],
dtype = tf.float32,
initializer = tf.constant_initializer(var_init_value)
)
print(var)
别忘了变量需要 initializer 激活!
placeholder
这个函数在新版本中被移除了,但是movan的代码里就是靠这个函数在运作。且不管这么多了。
placeholder是tf中的占位符,暂时储存变量。Tensorflow如果想要从外部传入data,就要用到 tf.placeholder() ,然后进行下面操作:
#需要定义placeholder的type,一般是float32
input1 = tf.placeholder(dtype = tf.float32, shape = [None], name = "a_placeholder")
input2 = tf.placeholder(dtype = tf.float32, shape = [None], name = "b_placeholder")
output = tf.multiply(input1, input2)
# 需要传入的值放在feed_dict中,并且一一对应
with tf.Session() as sess:
print(sess.run(output, feed_dict = {input1 : [7.], input2 : [2.]}))
# 结果是14
还要注意的是,tf的运算非常高级,当你进行矩阵的运算,就算两个相运算的矩阵维度不一样,他也会自动匹配成一样的矩阵。虽然可以利用这种性质(叫做broadcasting),但是在当你无意识犯错的时候,他是不会报错的……一定小心。
Tips
- 可以用 tf.global_variables() 打印出所有的变量,以检查自己在优化的变量是哪些,有没有错误。
- 可以用help(tf.reduce_mean) 来看看这个函数的用法。