Tensorflow更新一层中的部分参数方法

本文探讨了在TensorFlow中如何部分更新模型参数,特别是针对embedding层的情况。介绍了`tf.stop_gradients()`的限制及其在损失直接依赖目标张量时的效果。当embedding作为中间层时,建议使用`compute_gradients`和`apply_gradients`结合的方式,通过mask操作来设定要更新的参数部分。通过示例展示了如何验证参数是否成功进行了部分更新。

很多博客都分享了如何不更新某一层的参数,但是当遇到类似于embedding部分更新的情况,就必须局部更新某一层的参数。自己实践之后来分享一下。

tf.stop_gradients()

这个api是可以实现部分更新参数的,参考博客https://blog.youkuaiyun.com/hustchenze/article/details/84672430
但是只有在loss和你的target tensor直接发生关系的时候才能实现,否则会直接阻挡整个层的bp,也就是说如果你想部分更新最后一层的参数,那么用这个api是ok的(这是我个人实际使用的时候发现的,各位可以先做尝试看看)

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
    mask_h = tf.abs(mask-1)
    return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)

#这里的loss和emb直接相关
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
with tf.Session() as sess:
    
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值