安装
1. 安装python,cuda,torch
按照官方文档使用cuda按照 https://mmclassification.readthedocs.io/en/latest/install.html
(1)python3.7
(2)目前cuda支持版本为9.2,10.1,10.2,11.0,11.1
!!!不支持cuda10,为了同时使用tf1和tf2装的cuda10不可
Windows下可以安装多个版本的cuda:https://blog.youkuaiyun.com/qq_17783559/article/details/112916708
(3)conda命令安装,目前支持版本为1.3-1.8
2. 安装mmvc
按照官方文档进行安装 https://github.com/open-mmlab/mmcv
pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html
其中mmvc_version写1.3.0
3. 安装mmclassification
按照官方文档 https://mmclassification.readthedocs.io/en/latest/install.html
git clone https://github.com/open-mmlab/mmclassification.git
cd mmclassification
pip install -e .
训练
使用自定义数据集
1. 准备好数据集,按照meta, train, val的格式组织
2. 生成文件写入其类别
3. mmcls/datasets目录下创建py文件,定义类
4. 修改configs以及一系列文件
训练
python tools/train.py configs/resnet/resnet_mydataset.py
python tools/train.py [配置文件]
Tips:
- 训练过程中生成的checkpoint保存在work_dirs文件夹中
测试
python tools/test.py configs/resnet/resnet_mydataset.py checkpoints/resnet_mydataset.pth --out result_mydataset.pickle
python tools/test.py [配置文件] [断点文件] --out [输出文件名]
Tips:
-
checkpoint文件要放在checkpoints文件夹中才可以用
-
测试的时候不指定metric的话输出文件中写入的是对每个样本的分类情况,指定metric的话写入的则是整体的情况