最近在做模型压缩,看到这篇文章,分享一下。
模型压缩:主要是为了降低运算量和节省功耗。
模型计算量是衡量深度学习是否适合在移动或嵌入式端计算的最重要指标,通常用 GOP 单位来表示。例如,流行的 ResNet-18 的计算量大约是 4 GOP,而 VGG-16 则为大约 31 GOP。深度学习模型必须储存在内存里面,而内存其实还分为片上内存和片外内存两种。片上内存就是 SRAM cache,是处理器集成在芯片上用来快速存取重要数据的内存模块。片上内存会占据宝贵的芯片面积,因此处理器中集成的片上内存大小通常在 1-10 MB 这个数量级。片外内存则是主板上的 DDR 内存,这种内存可以做到容量很大(>1 GB),但是其访问速度较慢。片外内存,离处理器很远,一次访问需要消耗很大能量。更关键的是,访问片外内存所需要的能量是巨大的。根据 Song Han 在论文中的估计,一次片外内存访问消耗的能量是一次乘-加法运算的 200 倍,同时也是一次访问片上内存所需能量的 128 倍。换句话说,一次片外内存访问相当于做 200 次乘法运算!因此,我们为了减小能量消耗,必须减少片外内存访问,或者说我们需要尽可能把模型的权重数据和每层的中间运算结果存储在片上内存而非片外内存。这也是为什么 Google TPU 使用了高达 28MB 片上内存的原因,这样的话我们就必须从深度学习模型大小方面想办法,尽量减小模型尺寸,让模型尽可能地能存储在片上内存,或者至少一层网络的权重数据可以存在片上内存。
常见的模型压缩几种方法:
模型压缩可以分为两大类:第一类是大幅调整模型结构(包括网络拓扑连接,运算等等),直接训练出一个结构比较苗条的模型;第二类是在已有模型的基础上小幅修改,通常不涉及重新训练(模型压缩)。
Bengio 的 Binarized Neural Network 可谓是第一类模型的先驱者,将神经元 activation 限制为-1 或 1,从而极大地降低了运算量。Google 也于一个多月前发表的 MobileNet,使用了 depth-wise convolution 来降低运算量以及模型大小。Depthwise convolution 能大幅降低运算量,但是同时不同特征之间的权重参数变成线性相关。理论上减小了自由度,但是由于深度学习网络本身就存在冗余,因此实际测试中性能并没有降低很多。MobileNet 的计算量仅为 1GOP 上下,而模型大小只有 4MB 多一些,但能在 ImageNet 上实现 90% 左右的 top-5 准确率。在这条路上努力的人也很多,前不久 Face++也发表了 ShuffleNet,作为 MobileNet 的进一步进化形式也取得了更小尺寸的模型。未来我们预期会有更多此类网络诞生。
第二种方法则是保持原有模型的大体架构,但是通过种种方法进行压缩而不用重新训练,即模型压缩。一种思路就是在数据编码上想办法。大家都知道数据在计算机系统中以二进制形式表示,传统的全精度 32-bit 浮点数可以覆盖非常大的数字范围,但是也很占内存,同时运算时硬件资源开销也大。实际上在深度学习运算中可能用不上这么高的精度,所以最简单直接的方法就是降低精度,把原来 32-bit 浮点数计算换成 16-bit 浮点数甚至 8-bit 定点数。一方面,把数据的位长减小可以大大减少模型所需的存储空间(1KB 可以存储 256 个 32-bit 浮点数,但可以存储 1024 个 8-bit 定点数),另一方面低精度的运算单元硬件实现更简单,也能跑得更快。当然,随着数据精度下降模型准确率也会随之下降,所以随之也产生了许多优化策略,比如说优化编码(原本的定点数是线性编码数字之间的间距相等,但是可以使用非线性编码在数字集中的地方使数字间的间距变小增加精度,而在数字较稀疏的地方使数字间距较大。非线性编码的方法在数字通讯重要已经有数十年的应用,8-bit 非线性编码在合适的场合可以达到接近 16-bit 线性编码的精度)等等。业界的大部分人都已经开始使用降低精度的方案,Nvidia 带头推广 16-bit 浮点数以及 8-bit 定点数计算,还推出了 Tensor RT 帮助优化精度。
除了编码优化之外,另一个方法是网络修剪(network pruning)。大家知道在深度学习网络中的神经元往往是有冗余的,不少神经元即使拿掉对精度影响也不大。网络修建就是这样的技术,在原有模型的基础上通过观察神经元的活跃程度,把不活跃的神经元删除,从而达到降低模型大小减小运算量的效果。当然,网络修剪和编码优化可以结合起来。Song Han 发表在 2016 年 ICLR 上的 Deep Compression 就同时采用了修剪以及编码优化的方法,从而实现 35 倍的模型大小压缩。
原文链接: http://www.zhiding.cn/techwalker/documents/J9UpWRDfVYHE5ToOGy30k4fU9v9ep3gPUOb3TSAsig