TensorFlow控制语句
tf.cond()
tf.cond(
pred,
true_fn=None,
false_fn=None,
strict=False,
name=None,
fn1=None,
fn2=None
)
官网描述
当pred=True的时候返回true_fn(),pred=False的时候返回false_fn().
true_fn和false_fn都需要返回相同大小和类型的list of output tensors。
Warning: 任何在true_fn和false_fn之外创造的tensor和operation都会被执行无论运行哪一个判断分支
state=tf.Variable(0)
one=tf.constant(1)
new_value=tf.add(state,one)
update=tf.assign(state,new_value)
init=tf.initialize_all_variables()
cond=tf.cond(tf.greater_equal(new_value,2),lambda:update,lambda:tf.constant(-1))
with tf.Session() as sess:
sess.run(init)
for _ in range(3):
print("new_value",new_value.eval())
print(sess.run(cond))
#sess.run(update)
state=tf.Variable(0)
one=tf.constant(1)
init=tf.initialize_all_variables()
def fn_1():
new_value=tf.add(state,one)
update=tf.assign(state,new_value)
return update;
def fn_2(): return tf.constant(-1)
cond=tf.cond(tf.greater_equal(new_value,2),fn_1,lambda:tf.constant(-1))
with tf.Session() as sess:
sess.run(init)
for _ in range(3):
print("new_value",new_value.eval())
print(sess.run(cond))
#sess.run(update)
通过比较发现,如果new_value和update在fn_1()中定义的话,那么只会当条件判断分支选择运行fn_1()时候才会运行;如果new_value,update在fn_1()之外定义,那么每次判断都会执行一次
tf.Where()
tf.where(
condition,
x=None,
y=None,
name=None
)
根据conditioin的值要么返回x要么返回y中元素。
如果x,y都为None那么返回condition中true元素的坐标。这些坐标按照2-D张量的形式被返回,第一维(row)反对true element的个数,第二维(column)返回true element对应坐标。注意返回的shape取决于input中有多少true element。 索引是output按照row-major的顺序。
如果x,y都为非空,那么x,y必须同样的形状大小。condition要么和x的第一位匹配,要么和x同样大小
arr=tf.constant([[-1,0,1,2],
[2,0,-3,-5]])
pred=tf.less(arr,0)
ans=tf.where(pred)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print("pred is")
print(sess.run(pred))
print("tf.where")
print(sess.run(ans))
arr=tf.constant([[-1,0,1,2],
[2,0,-3,-5]],tf.float32)
pred=tf.less(arr,0)
ans=tf.where(pred,arr,tf.zeros([2,4],tf.float32))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print("pred is")
print(sess.run(pred))
print("tf.where")
print(sess.run(ans))
tf.gather()
tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
按照indice从params整合slice