参考代码:pdarts
1. 概述
导读:这篇文章在DARTS的工作的基础上,指出直接使用DARTS方法得到的模型搜索结果直接运用到大数据集(如ImageNet)的之后会存在性能下降的问题,文章经过分析得到这是由于在网络搜索和验证(另外的数据集)中是存在模型结构深度上的deep gap的。直接使用对应数据集(如ImageNet)进行搜索会带来很大的计算开销,这也是先有很多网络搜索都去小数据集上去提取网络结构,之后针对大的数据集对网络再进行放大的原因。对此一种简单的解决办法便是采用多级递进的方式进行搜索,边搜索边排除一些无效的搜索空间,从而可以在有限的计算资源下去缩小这个深度上的deep gap。对于网络搜索的稳定性问题,文章提出了一些正则化的策略,约束网络搜索过程中的skip-connection。文章的方法在CIFAR-10数据集上测试错误率为2.5%,ImageNet上top-1/5错误率为24.4%和7.4%。
再DARTS系列的网络搜索方法中,首先在小数据集上进行网络搜索,之后将其当大到目标数据集上(图a),而文章所使用的是多级递进的方式(图b),具体见下图所示:
具体的,文章对网络搜索的优化集中在两个方面:
- 1)多级逼近:上图中使用多级递进方式进行改进是为了缩小两个域之间的gap,从而防止了在小数据集上搜索的到的结构迁移到大的大数据上不是适配的问题(小网络和大网络它们对应的搜索空间是不一样的)。同时通过排除无效搜索空间的方式减少资源消耗,从而使得可以搜索得更深;
- 2)搜索稳定性:随着网络搜索的进行搜索算法会渐渐偏向于选择更多的skip connection,这就会导致网络的collapse(浅层的网络在梯度优化过程中变化更为迅速,而不是像传统认知中趋向更深的网络结构),文章为了改善这个问题,提出两点改进:使用操作(作用在shortcut OP上)级别的dropout去减少skip connection的作用;控制skip connection的产生;
2. 方法设计
2.1 搜索空间的近似
这里所说的搜索空间的近似,是使用多层递进的形式去逼近大网络的搜索空间,在不同层的搜索过程中逐渐增加搜索网络的深度,同时排除网络结构中概率较低的部分,用于减少显存占用,其过程描述为下图所示:
在整个分级搜索的过程中,在实现网络深度增加的同时,很好控制了显存的增加,下表展示了几个阶段进行搜索显存的占用情况:
2.2 搜索空间的正则化
由于在网络搜索的过程中,网络搜索会趋向于去使用skip connection(model collapse),而不是诸如卷积/池化这类操作。文章对此的解释是skip connection拥有更快的梯度下降属性,这就使得在诸如CIFAR-10这样的数据集上很容易拟合,从而使得搜索算法会给这个操作较大的选中概率,这就导致了网络不能抽取到较深的语义信息,从而导致在别的数据集下性能较低,其实这便是一种过拟合现象。
对此,文章从网络正则化的角度对skip connection进行约束,具体为如下的两个方法:
- 1)skip connection之后添加dropout:这里文章是在skip connection操作的后面使用一个概率去将其丢弃,这个丢弃的概率是随着搜索的阶段衰减的(固定的丢弃概率会妨碍网络性能的提升,skip connection对于网络也是相当重要的),这样在较为初始的阶段使用较大的概率,而在后面的阶段使用较小的概率,从而不仅约束了skip connection的选择,保证了搜索空间其他操作参数的有效学习,又有效保证了skip connection的作用,兼顾了网络的性能提升;
- 2)每个搜索单元skip connection固定为 M M M个:上文中提到skip connection对于网络性能也是很重要的,因而不能将其减少过多,将其设置为一个合理的范围就好了(文中为2~4),这是在最后一级之后的结果。其实就是在网络搜索完成之后,按照这个超参数选择其中概率最大的几个skip connection。其中需要注意的是这里的正则化需要在第一步正则化之后进行,从而减少低质量skip connection被选中的概率;
使用上述的正则化约束确实可以对skip connection进行约束,但是引入了较多的超参数,如drop rate/ M M M,这就相当依赖人为的设置了,个人觉得是后续可以优化的点。
3. 实验结果
CIFAR10 and CIFAR100数据集:
ImageNet数据集: