TensorFlow学习02-tf.cond() tf.where()

本文深入探讨了TensorFlow中的控制语句,包括tf.cond()、tf.where()和tf.gather()的功能与用法。通过实例解析,展示了如何利用这些控制语句进行条件判断和数据处理,特别关注了在不同场景下它们的行为差异。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow控制语句

tf.cond()

tf.cond(
pred,
true_fn=None,
false_fn=None,
strict=False,
name=None,
fn1=None,
fn2=None
)
官网描述
当pred=True的时候返回true_fn(),pred=False的时候返回false_fn().
true_fn和false_fn都需要返回相同大小和类型的list of output tensors。
Warning: 任何在true_fn和false_fn之外创造的tensor和operation都会被执行无论运行哪一个判断分支

state=tf.Variable(0)
one=tf.constant(1)
new_value=tf.add(state,one)
update=tf.assign(state,new_value)
init=tf.initialize_all_variables()
cond=tf.cond(tf.greater_equal(new_value,2),lambda:update,lambda:tf.constant(-1))
with tf.Session() as sess:
    sess.run(init)
    for _ in range(3):
        print("new_value",new_value.eval())
        print(sess.run(cond))
        #sess.run(update)

运行结果

state=tf.Variable(0)
one=tf.constant(1)
init=tf.initialize_all_variables()
def fn_1():
    new_value=tf.add(state,one)
    update=tf.assign(state,new_value)
    return update;
def fn_2(): return tf.constant(-1)
cond=tf.cond(tf.greater_equal(new_value,2),fn_1,lambda:tf.constant(-1))
with tf.Session() as sess:
    sess.run(init)
    for _ in range(3):
        print("new_value",new_value.eval())
        print(sess.run(cond))
        #sess.run(update)

运行结果
通过比较发现,如果new_value和update在fn_1()中定义的话,那么只会当条件判断分支选择运行fn_1()时候才会运行;如果new_value,update在fn_1()之外定义,那么每次判断都会执行一次

tf.Where()

tf.where(
condition,
x=None,
y=None,
name=None
)

根据conditioin的值要么返回x要么返回y中元素。
如果x,y都为None那么返回condition中true元素的坐标。这些坐标按照2-D张量的形式被返回,第一维(row)反对true element的个数,第二维(column)返回true element对应坐标。注意返回的shape取决于input中有多少true element。 索引是output按照row-major的顺序。
如果x,y都为非空,那么x,y必须同样的形状大小。condition要么和x的第一位匹配,要么和x同样大小

arr=tf.constant([[-1,0,1,2],
                [2,0,-3,-5]])
pred=tf.less(arr,0)
ans=tf.where(pred)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print("pred is")
    print(sess.run(pred))
    print("tf.where")
    print(sess.run(ans))

在这里插入图片描述

arr=tf.constant([[-1,0,1,2],
                [2,0,-3,-5]],tf.float32)
pred=tf.less(arr,0)
ans=tf.where(pred,arr,tf.zeros([2,4],tf.float32))
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print("pred is")
    print(sess.run(pred))
    print("tf.where")
    print(sess.run(ans))

在这里插入图片描述

tf.gather()

tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)

按照indice从params整合slice

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值