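Project motivation: Many NAS search algorithms are in use today, each with its own variants, search spaces, and datasets. We want to consolidate them into a single algorithm library so that their performance can be verified along multiple dimensions and so that the group can explore and improve the algorithms more easily.
Project progress: The evaluation of SinglePath on CIFAR-100 is complete, and the results are recorded in section 4 (algorithm run results). The network architectures found by DARTS are currently being retrained.
How to run: Follow the commands listed for each algorithm in the training columns of the tables in section 4 (algorithm run results).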
1. Search spaces
ShuffleNet, MobileNet, DARTS, NAS-Bench-101
2. Datasets
CIFAR, ImageNet
3. Search algorithms
4. Algorithm run results
1).SinglePath
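Search algorithm | Search space | Dataset | Training procedure | Searched architecture | train loss | train acc | val loss | val acc | total GPU hours |
---|---|---|---|---|---|---|---|---|---|
SinglePath | ShuffleNet | CIFAR-100 | 1. Train the SuperNet: bash run_train.sh | - | 1.19 | top1: 72.66%, top5: 85.94% | 1.23 | top1: 64.84%, top5: 89.06% | 1.51 |
SinglePath | ShuffleNet | CIFAR-100 | 2. Evolutionary search: bash run_server.sh and bash run_search.sh | - | - | - | - | - | 5.27 |
SinglePath | ShuffleNet | CIFAR-100 | 3. Retrain and evaluate the searched architectures (the top 3 are selected): python3 eval.py generates eval.sh, a shell script that trains several architectures at the same time; then run bash eval.sh | 1-3-1-3-2-2-3-0-2-3-0-2 | 1.21 | top1: 88.85%, top5: 98.82% | 1.14 | top1: 71.11%, top5: 91.38% | 0.088 |
SinglePath | ShuffleNet | CIFAR-100 | 3. (same step as above) | 1-3-2-2-2-1-3-0-2-3-0-2 | 1.21 | top1: 88.95%, top5: 98.79% | 1.15 | top1: 70.98%, top5: 91.21% | 1.16 |
SinglePath | ShuffleNet | CIFAR-100 | 3. (same step as above) | 1-3-2-2-2-1-3-1-2-3-1-2 | 1.20 | top1: 89.27%, top5: 98.87% | 1.12 | top1: 71.84%, top5: 91.47% | 0.61 |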
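
The architecture strings in the table can be compared programmatically. The sketch below assumes that each dash-separated integer is the index of the candidate block chosen for one layer of the ShuffleNet SuperNet; the encoding is not spelled out in this README, so treat that reading as an assumption. The names `parse_arch`, `rank_by_top1`, `hamming`, and `RESULTS` are hypothetical helpers for illustration and are not part of the repository.

```python
from typing import List, Tuple

# Hypothetical helpers, not part of the repository.
# Assumption: each integer is the index of the block chosen for one layer.


def parse_arch(arch: str) -> List[int]:
    """Split a dash-separated architecture string into per-layer choices."""
    return [int(choice) for choice in arch.split("-")]


# (architecture, val top-1 %, val top-5 %), copied from the SinglePath table.
RESULTS: List[Tuple[str, float, float]] = [
    ("1-3-1-3-2-2-3-0-2-3-0-2", 71.11, 91.38),
    ("1-3-2-2-2-1-3-0-2-3-0-2", 70.98, 91.21),
    ("1-3-2-2-2-1-3-1-2-3-1-2", 71.84, 91.47),
]


def rank_by_top1(results):
    """Sort results by validation top-1 accuracy, best first."""
    return sorted(results, key=lambda row: row[1], reverse=True)


def hamming(arch_a: str, arch_b: str) -> int:
    """Count the layers where two architectures pick a different block."""
    return sum(x != y for x, y in zip(parse_arch(arch_a), parse_arch(arch_b)))


if __name__ == "__main__":
    best_arch, best_top1, _ = rank_by_top1(RESULTS)[0]
    print(best_arch, best_top1, len(parse_arch(best_arch)))
    print(hamming(RESULTS[0][0], RESULTS[1][0]))
```

Under this reading, the three retrained architectures differ in only two or three of their twelve layers, which may explain why their validation accuracies are so close.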
2).DARTS
Algorithm | Dataset | Optimization algorithm | Training stage | Searched architecture | Params(M) | train acc | val acc | GPU time |
---|---|---|---|---|---|---|---|---|
DARTS(report) | CIFAR-10 | second-order | - | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5]) | 3.3 | - | 2.76 ± 0.09 (test error, %) | 1 GPU day |
DARTS(round 1) | CIFAR-10 | second-order | Search stage (50 epochs): bash run_search.sh | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 2), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('skip_connect', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 3), ('skip_connect', 3), ('skip_connect', 2)], reduce_concat=range(2, 6)) | - | 99.71% | 88.95% | 17 GPU hours |
DARTS(round 1) | CIFAR-10 | second-order | Evaluation stage (600 epochs): bash run_retrain.sh | - | 2.84 | 99.18% | 96.99% (error 3.01%) | 20 GPU hours |
DARTS(round 2) | CIFAR-10 | second-order | Search stage (50 epochs): bash run_search.sh | Genotype(normal=[('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('avg_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), ('avg_pool_3x3', 0), ('skip_connect', 2)], reduce_concat=range(2, 6)) | - | 99.68% | 89.17% | 19 GPU hours |
DARTS(round 2) | CIFAR-10 | second-order | Evaluation stage (600 epochs): bash run_retrain.sh | - | 2.09 | 99.10% | 96.82% | 14.5 GPU hours |
DARTS(round 3) | CIFAR-10 | second-order | Search stage (50 epochs): bash run_search.sh | Genotype(normal=[('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_5x5', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('skip_connect', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 3), ('skip_connect', 2)], reduce_concat=range(2, 6)) | - | 99.62% | 88.96% | 16 GPU hours |
DARTS(round 3) | CIFAR-10 | second-order | Evaluation stage (600 epochs): bash run_retrain.sh | - | 2.50 | 99.1% | 96.81% | 14.5 GPU hours |
DARTS(round 4) | CIFAR-10 | second-order | Search stage (50 epochs): bash run_search.sh | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('skip_connect', 2), ('skip_connect', 3), ('max_pool_3x3', 1), ('skip_connect', 2)], reduce_concat=range(2, 6)) | - | 99.76% | 89.32% | 20 GPU hours |
DARTS(round 4) | CIFAR-10 | second-order | Evaluation stage (600 epochs): bash run_retrain.sh | - | 3.89 | 98.81% | 97.25% | 14 GPU hours |
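
The reported DARTS figure (2.76 ± 0.09) is an error rate in percent, while the reproduced rounds list accuracy, so the two are easier to compare once the accuracies are converted. The sketch below does that conversion for the four 600-epoch retraining results in the table and summarises them as mean and sample standard deviation. It assumes the "val acc" column is measured on a split comparable to the one behind the reported number, which this README does not confirm; the reported spread is also not necessarily computed the same way. `VAL_ACC` and `to_error` are hypothetical names.

```python
from statistics import mean, stdev

# Validation accuracy (%) after the 600-epoch retraining, rounds 1 to 4,
# copied from the DARTS table.
VAL_ACC = [96.99, 96.82, 96.81, 97.25]


def to_error(acc_percent: float) -> float:
    """Convert an accuracy in percent to an error rate in percent."""
    return round(100.0 - acc_percent, 2)


errors = [to_error(acc) for acc in VAL_ACC]

# Mean and sample standard deviation over the four search rounds.
summary = (mean(errors), stdev(errors))

if __name__ == "__main__":
    print(errors)
    print(f"{summary[0]:.2f} ± {summary[1]:.2f}")
```

On these four rounds the summary comes out at roughly 3.03 ± 0.21, which is above the reported 2.76 ± 0.09; only round 4 (2.75% error) reaches the reported level.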
3).PDARTS
Algorithm | Dataset | Optimization algorithm | Training stage | Searched architecture | Params(M) | train acc | val acc | GPU time |
---|---|---|---|---|---|---|---|---|
PDARTS(report) | CIFAR-10 | | - | Genotype(normal=[('skip_connect', 0), ('dil_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3), ('sep_conv_3x3', 0), ('dil_conv_5x5', 4)], normal_concat=range(2, 6), reduce=[('avg_pool_3x3', 0), ('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('dil_conv_5x5', 2), ('max_pool_3x3', 0), ('dil_conv_3x3', 1), ('dil_conv_3x3', 1), ('dil_conv_5x5', 3)], reduce_concat=range(2, 6)) | 3.4 | - | 97.5% | 0.3 GPU days |
PDARTS | CIFAR-10 | | Search stage (25 epochs): bash run_search.sh | Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], normal_concat=range(2, 6), reduce=[('skip_connect', 0), ('avg_pool_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_3x3', 1), ('max_pool_3x3', 0), ('avg_pool_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_5x5', 2)], reduce_concat=range(2, 6)) | 4.38 | 90.92% | 83.93% | 6.76 GPU hours |
PDARTS | CIFAR-10 | | Evaluation stage (600 epochs): bash run_retrain.sh | - | - | 96.92% | 97.27% | 51 GPU hours |
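
The Genotype entries in the DARTS and PDARTS tables are written as Python literals, so they can be loaded and inspected directly. The sketch below assumes the four-field namedtuple layout that the printed literals suggest (`normal`, `normal_concat`, `reduce`, `reduce_concat`) and counts how often each operation appears in a cell. `op_counts`, `PDARTS_SEARCHED`, and `DARTS_ROUND2` are hypothetical names; the genotype values are copied from the tables.

```python
from collections import Counter, namedtuple

# Field layout inferred from the Genotype literals printed in the tables.
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

# Architecture found by the reproduced PDARTS search (25 epochs).
PDARTS_SEARCHED = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('sep_conv_3x3', 0), ('sep_conv_5x5', 1)],
    normal_concat=range(2, 6),
    reduce=[('skip_connect', 0), ('avg_pool_3x3', 1), ('max_pool_3x3', 0),
            ('dil_conv_3x3', 1), ('max_pool_3x3', 0), ('avg_pool_3x3', 1),
            ('max_pool_3x3', 0), ('dil_conv_5x5', 2)],
    reduce_concat=range(2, 6),
)

# Architecture found in DARTS round 2 (50 epochs).
DARTS_ROUND2 = Genotype(
    normal=[('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0),
            ('skip_connect', 1), ('skip_connect', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('skip_connect', 1)],
    normal_concat=range(2, 6),
    reduce=[('max_pool_3x3', 0), ('avg_pool_3x3', 1), ('max_pool_3x3', 0),
            ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2),
            ('avg_pool_3x3', 0), ('skip_connect', 2)],
    reduce_concat=range(2, 6),
)


def op_counts(cell):
    """Count how often each operation name appears in one cell."""
    return Counter(op for op, _input_index in cell)


if __name__ == "__main__":
    print("PDARTS normal:", dict(op_counts(PDARTS_SEARCHED.normal)))
    print("DARTS round 2 normal:", dict(op_counts(DARTS_ROUND2.normal)))
```

Counting operations this way makes the difference between runs visible at a glance: the DARTS round 2 normal cell is dominated by skip connections, while the reproduced PDARTS normal cell consists entirely of separable convolutions, which is consistent with their Params(M) values of 2.09 and 4.38.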