Tensorflow:tf.gradient()用法以及参数stop_gradient理解

本文深入解析了TensorFlow中tf.gradients函数的使用方法,详细介绍了如何通过该函数进行微分计算,包括参数ys、xs及stop_gradients的具体作用。通过实例演示了不同参数设置下梯度计算的变化,帮助读者理解反向传播过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.gradient()

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None
)

ys : 类型是张量或者张量列表,类似于目标函数,需要被微分的函数
xs:类型是张量或者张量列表,需要求微分的对象。(上述即为:dys/dxs)
stop_gradients: 可选参数,类型是张量或者张量列表,不需要通过微分的对象(比较抽象,看完下面的例子)

用一个例子来帮助理解

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b])
with tf.Session() as sess:
    print(sess.run(g))
结果:[3.0, 1.0]

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a])
with tf.Session() as sess:
    print(sess.run(g))
结果:[3.0, 1.0]

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[b])
with tf.Session() as sess:
    print(sess.run(g))
结果:[1.0, 1.0]  

可以看出,第一个参数ys是准备被微分的函数,第二个参数即xs填的是反向传播是需要求导的参数,第三个参数即stop_gradient,在反向传播时,如果填了参数b,那么a + b中a,b都是独立的,否则a + b= 3a(因为在本例中b = 2a)


如果觉得我有地方讲的不好的或者有错误的欢迎给我留言,如果对您有帮助,帮我点个赞哦~,感谢大家阅读

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值