1. 背景
这里的任务是要在TF框架下运用现在已经较为成熟的剪枝算法,由于TF中对于可训练参数的管理与Pytorch、Caffe并不相同,这篇文章对其TF实现卷积参数剪枝进行了探究。
2. TF中可训练参数的存储
TF中是将网络构建为一个计算图,在一个计算图中可以通过集合(collection)来管理不同类别的资源。比如通过tf.add_to_collection函数可以将资源加入一个或多个集合中,然后通过tf.get_collection获取一个集合里面的所有资源(如张量,变量,或者运行TensorFlow程序所需的队列资源等等)。下表就是TF中的一些集合及其释义:
| 集合名称 | 集合内容 | 使用场景 |
|---|---|---|
tf.GraphKeys.VARIABLES | 所有变量 | 持久化 TensorFlow 模型 |
tf.GraphKeys.TRAINABLE_VARIABLES | 可学习的变量(一般指神经网络中的参数) | 模型训练、生成模型可视化内容 |
tf.GraphKeys.SUMMARIES | 日志生成相关的张量 | TensorFlow 计算可视化 |
tf.GraphKeys.QUEUE_RUNNERS | 处理输入的 | QueueRunner |
tf.GraphKeys.MOVING_AVERAGE_VARIABLES | 所有计算了滑动平均值的变量 | 计算变量的滑动平均值 |
- 1)TensorFlow中的所有变量都会被自动加入
tf.GraphKeys.VARIABLES集合中,通过tf.global_variables()函数可以拿到当前计算图上的所有变量(包含参与训练与不参与的)。拿到计算图上的所有变量有助于持久化整个计算图的运行状态。 - 2)当构建机器学习模型时,比如神经网络,可以通过变量声明函数中的trainable参数来区分需要优化的参数(比如神经网络的参数)和其他参数(比如迭代的轮数,即超参数),若
trainable = True,则此变量会被加入tf.GraphKeys.TRAINABLE_VARIABLES集合。然后通过tf.trainable_variables()函数便可得到所有需要优化的参数(参与剪枝的参数都在这个集合里面)。TensorFlow中提供的优化算法会将tf.GraphKeys.TRAINABLE_VARIABLES集合中的变量作为 默认的优化对象。
参考文章:TensorFlow 常用的函数
3. 获取与修改可训练参数
可训练参数的获取:
在上面通过tf.trainable_variables可以获取到残友训练的参数名称,拿到这个参数名称之后可以通过get_tensor_by_name函数获取可训练参数Tensor,之后就可以统计里面参数的分布,按照剪枝算法来剪除冗余的参数。
with tf.Session() as sess:
varvar = sess.graph.get_tensor_by_name('A/var:0')
print(sess.run(varvar))
在做完有效参数选择之后,就需要将修改之后的参数更新到原有的计算图中,这里使用tf.assign()来实现:
with tf.Session() as sess:
sess.run(tf.assign(varvar, new_varvar))
本文探讨了在TensorFlow框架下实现卷积参数剪枝的方法。首先介绍了背景,说明了TF中可训练参数的存储方式,包括如何通过集合管理资源。接着详细讲解了如何获取和修改可训练参数,特别是如何利用trainable属性区分优化参数,并通过函数进行参数的更新操作。
1542

被折叠的 条评论
为什么被折叠?



