很多博客都分享了如何不更新某一层的参数,但是当遇到类似于embedding部分更新的情况,就必须局部更新某一层的参数。自己实践之后来分享一下。
tf.stop_gradients()
这个api是可以实现部分更新参数的,参考博客https://blog.youkuaiyun.com/hustchenze/article/details/84672430
但是只有在loss和你的target tensor直接发生关系的时候才能实现,否则会直接阻挡整个层的bp,也就是说如果你想部分更新最后一层的参数,那么用这个api是ok的(这是我个人实际使用的时候发现的,各位可以先做尝试看看)
import tensorflow as tf
import numpy as np
def entry_stop_gradients(target, mask):
mask_h = tf.abs(mask-1)
return tf.stop_gradient(mask_h * target) + mask * target
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
emb = tf.constant(np.ones([10, 5]))
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
#这里的loss和emb直接相关
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
with tf.Session() as sess: