tf.gradients
tf.gradients(
ys,
xs,
grad_ys=None,
name=’gradients’,
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None
)
功能
ys
, xs
均是包含tensor的list
, ys中的每个tensor和xs中的tensor相关
计算步骤:(求和可以理解为batch内的累加)
Σlen(ys)i∂ys[i]∂xs[j],∀j∈[0,len(xs)−1]Σilen(ys)∂ys[i]∂xs[j],∀j∈[0,len(xs)−1]
stop_gradients
用于指定偏导停止链式法则的节点。
import tensorflow as tf
a = tf.constant([0.0])
b = a * 2
pg = tf.gradients(a+b, [a, b])
with tf.Session() as sess:
res = sess.run(pg)
print(res)
[array([3.], dtype=float32), array([1.], dtype=float32)]
import tensorflow as tf
a = tf.constant([0.0])
b = a * 2
pg = tf.gradients(a+b, [a, b], stop_gradients=[a,b])
with tf.Session() as sess:
res = sess.run(pg)
print(res)
[array([1.], dtype=float32), array([1.], dtype=float32)]
和tf.stop_gradients区别
tf.stop_gradients
是在构建图的过程中使用,指定停止链式法则的节点
stop_gradients
是在构建图之后使用
当程序运行时,碰到以上两种方式定义的stop_gradients, 均会停止链式法则
,进而求得部分偏导。