TensorFlow中卷积参数剪枝的实现基础探究

本文探讨了在TensorFlow框架下实现卷积参数剪枝的方法。首先介绍了背景,说明了TF中可训练参数的存储方式,包括如何通过集合管理资源。接着详细讲解了如何获取和修改可训练参数,特别是如何利用trainable属性区分优化参数,并通过函数进行参数的更新操作。
部署运行你感兴趣的模型镜像

1. 背景

这里的任务是要在TF框架下运用现在已经较为成熟的剪枝算法,由于TF中对于可训练参数的管理与Pytorch、Caffe并不相同,这篇文章对其TF实现卷积参数剪枝进行了探究。

2. TF中可训练参数的存储

TF中是将网络构建为一个计算图,在一个计算图中可以通过集合(collection)来管理不同类别的资源。比如通过tf.add_to_collection函数可以将资源加入一个或多个集合中,然后通过tf.get_collection获取一个集合里面的所有资源(如张量,变量,或者运行TensorFlow程序所需的队列资源等等)。下表就是TF中的一些集合及其释义:

集合名称集合内容使用场景
tf.GraphKeys.VARIABLES所有变量持久化 TensorFlow 模型
tf.GraphKeys.TRAINABLE_VARIABLES可学习的变量(一般指神经网络中的参数)模型训练、生成模型可视化内容
tf.GraphKeys.SUMMARIES日志生成相关的张量TensorFlow 计算可视化
tf.GraphKeys.QUEUE_RUNNERS处理输入的QueueRunner
tf.GraphKeys.MOVING_AVERAGE_VARIABLES所有计算了滑动平均值的变量计算变量的滑动平均值
  • 1)TensorFlow中的所有变量都会被自动加入tf.GraphKeys.VARIABLES集合中,通过tf.global_variables()函数可以拿到当前计算图上的所有变量(包含参与训练与不参与的)。拿到计算图上的所有变量有助于持久化整个计算图的运行状态。
  • 2)当构建机器学习模型时,比如神经网络,可以通过变量声明函数中的trainable参数来区分需要优化的参数(比如神经网络的参数)和其他参数(比如迭代的轮数,即超参数),若trainable = True,则此变量会被加入tf.GraphKeys.TRAINABLE_VARIABLES集合。然后通过tf.trainable_variables()函数便可得到所有需要优化的参数(参与剪枝的参数都在这个集合里面)。TensorFlow中提供的优化算法会将tf.GraphKeys.TRAINABLE_VARIABLES集合中的变量作为 默认的优化对象。

参考文章:TensorFlow 常用的函数

3. 获取与修改可训练参数

可训练参数的获取:
在上面通过tf.trainable_variables可以获取到残友训练的参数名称,拿到这个参数名称之后可以通过get_tensor_by_name函数获取可训练参数Tensor,之后就可以统计里面参数的分布,按照剪枝算法来剪除冗余的参数。

with tf.Session() as sess:
	varvar = sess.graph.get_tensor_by_name('A/var:0')
	print(sess.run(varvar))

在做完有效参数选择之后,就需要将修改之后的参数更新到原有的计算图中,这里使用tf.assign()来实现:

with tf.Session() as sess:
	sess.run(tf.assign(varvar, new_varvar))

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值