GitHub地址:https://github.com/TuSimple/mx-maskrcnn
首先表达一下对凯明大神和RBG大神的膜拜!
20171124 一定要在终端上使用命令行下载源代码,因为mxnet官网上会建议下载0.11.0版本,但是这个版本有问题,在后续训练中会报错,所以直接clone,这样会下载最新的mxnet
训练准备
1.首先下载Cityscapes数据(gtFine_trainvaltest.zip和leftImg8bit_trainvaltest.zip)下载地址是https://www.cityscapes-dataset.com/,它需要注册,并等待管理员进行确认,大家可以先试着下载,下载速度还可以。
2.下载Resnet-50预训练模型
bash scripts/download_res50.sh 得到resnet-50-0000.params
3.使用ROIAlign运算符构建MXNet
cp rcnn/CXX_OP/* incubator-mxnet/src/operator/
将rcnn/CXX_OP/下的文件夹复制到了 incubator-mxnet/src/operator/
(我一开始是先安装incubator-mxnet,后拷贝进去,一直有错,一定要先拷贝进去,将align也编译进去,不然一直会报AttribuuteError: 'module' object has no attribute 'ROIAlign')
从源代码建立MXNet请参考https://mxnet.incubator.apache.org/get_started/build_from_source.html(如果是从GitHub上下载的源代码,则目录中有incubator-mxnet,这边就不需要下载了)
我是在Ubuntu14.04下安装mxnet
3.1首先需要Ubuntu>=13.10,debian >=8
3.2需要BLAS库,可以安装ATLAS、OpenBLAS、MKL,我安装的是atlas
3.3需要opencv库
3.4需要cuda和cudnn用于加速,cuda7.5和8.0都需要cudnn5
3.5build 如果你是直接clone的源代码,那你的文件夹下就会有一个incubator-mxnet文件夹,直接cd进去
首先复制如下代码(因为0.11.0有问题,所以我下载的是0.12.0的版本)
(上面将mxnet装在了mxnet文件夹下,0.11.0的版本有问题)当不需要opencv时 当需要GPU和opencv同时支持时 这里需要注意一下,上面blas有几个选择,你用的哪个,就把use_blas改为哪个,你像我用的是altas,所以我是上面那种写法use_cuda_path改为自己的cuda地址
build C++package
想buildC++package时,在上面那句话后面加上USE_CPP_PACKAGE=1
3.6我并没有安装Scala、Julia、Perl这三个包,除了python包是必须的,其他的都不是必须的
安装Python包
cd python;
python setup.py install
有些时候需要安装setuptools和numpy(sudo apt-get install python-numpy)
如果在Python中可以import mxnet,则安装成功
mxnet.__version__可以查看menet的版本号
安装R包,
R-package时mxnet文件夹下的一个文件夹 这里说句我的错误,make rpkg的过程中,一直在lib/local/......找不到libmxnet.so(文件夹的具体路径我忘记了),这个文件是在mxnet/lib下,我的解决方案是将这个文件复制到上面的文件夹下,重新执行上述过程就解决了。安装需要一段时间,安装完成后,可以进行测试。
4.构建相关cython代码
make
5.开始训练
bash scripts/train_alternate.sh
如果没有报错,会进入这个界面
我的电脑配置不太好,使用一个GPU,源代码是使用了4个GPU,需要在train_alternate.sh中的最后一句话改为
这个大家视情况而定
这是官方给的可能出现的错误,我一开始就有第一个问题,后来发现是我忘了第三步中的复制了,马虎啊
现在正在训练中,可以发现RPNLogLoss一直在变小,不知道会训练到什么时候,先写这些,等训练好了再继续
大概一天以后,我的报错了
When generating RPN detection, after training RPN1, the processing turned down. The error message is shown as below:
Traceback (most recent call last):
File "train_alternate_mask_fpn.py", line 116, in <module>
main()
File "train_alternate_mask_fpn.py", line 113, in main
args.rcnn_epoch, args.rcnn_lr, args.rcnn_lr_step)
File "train_alternate_mask_fpn.py", line 39, in alternate_train
vis=False, shuffle=False, thresh=0)
File "/home/jiawenhe/workspace/mx-maskrcnn/rcnn/tools/test_rpn.py", line 60, in test_rpn
arg_params=arg_params, aux_params=aux_params)
File "/home/jiawenhe/workspace/mx-maskrcnn/rcnn/core/tester.py", line 22, in __init__
self._mod.bind(provide_data, provide_label, for_training=False)
File "/home/jiawenhe/workspace/mx-maskrcnn/rcnn/core/module.py", line 141, in bind
force_rebind=False, shared_module=None)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.12.0-py2.7.egg/mxnet/module/module.py", line 417, in bind
state_names=self._state_names)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.12.0-py2.7.egg/mxnet/module/executor_group.py", line 231, in __init__
self.bind_exec(data_shapes, label_shapes, shared_group)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.12.0-py2.7.egg/mxnet/module/executor_group.py", line 327, in bind_exec
shared_group))
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.12.0-py2.7.egg/mxnet/module/executor_group.py", line 603, in _bind_ith_exec
shared_buffer=shared_data_arrays, **input_shapes)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.12.0-py2.7.egg/mxnet/symbol/symbol.py", line 1491, in simple_bind
raise RuntimeError(error_msg)
RuntimeError: simple_bind error. Arguments:
data: (1, 3, 1024, 2048)
im_info: (1, 3L)
[21:01:05] src/storage/./pooled_storage_manager.h:102: cudaMalloc failed: out of memory
原作者在恢复别人问题时,让他从这一步开始恢复训练,本来还弄不懂怎么个恢复训练法,后来想通了,这是训练rpn,将这句话屏蔽掉就可以了
1208 第三步的时候会出现资源的错误,错误地址https://github.com/TuSimple/mx-maskrcnn/issues/39,大家可以参考一下
1208 我是前天又重新安装的mxnet,换了一个新的环境跑的程序,第一步每次迭代3个小时,跑到现在,第三步还没跑完,每次迭代8个半小时
1218 终于跑完了,进行了24次迭代,圆满了,,这是生成的文件
评估准备 bash scripts/download_cityscapescripts.sh bash scripts/eval.sh
模型从https://pan.baidu.com/s/1o8n4VMU中下载
测试图片 bash scripts/demo_single_image.sh
我的因为是链接的服务器,所以在imshow出错了,改为使用opencv保存即可