1.创建conda环境
推荐通过conda创建虚拟环境,具体操作可见linux系统下创建anaconda新环境及问题解决
2.clone代码并安装依赖库
git clone https://github.com/facebookresearch/detr.git
conda install -c pytorch pytorch torchvision
conda install cython scipy
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
pip install git+https://github.com/cocodataset/panopticapi.git
3.准备自己的数据集
- 使用COCO格式数据集,其文件目录如下:
- 其中,annotations包含训练集和验证集对应的json文件
- train2017包含训练集图片;val2017包含验证集图片
※如果事先准备好了VOC格式的数据集,则可通过脚本进行转换,详见VOC格式数据集转为COCO格式数据集脚本
4.下载预训练模型并修改类别参数
- 创建一个python文件,根据自己的目标类别数目对原始用于coco数据集的预训练模型进行转换
import torch
pretrained_weights = torch.load("./detr-r50-e632da11.pth")
num_class = 2 + 1
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights,'detr_r50_%d.pth'%num_class)
- 更改detr.py中的目标类别数目(这里干脆都改成一样的了)
5.训练
- 运行main.py并传递相应的参数进行训练
python main.py --dataset_file "coco" --coco_path /path/to/coco/ --resume="detr_r50_3.pth"
6.推理
- 同样运行main.py 需要
--eval
及其它相关参数
7.plot
- 借助plot_utils.py,在文件末尾添加下方代码,更改路径并运行即可。
if __name__ == '__main__':
files = list(Path('../outputs/eval').glob('*.pth'))
plot_precision_recall(files)
plt.show()
plot_logs(logs=Path('../outputs/log/'),fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt')
plt.show()
- 参考链接:
https://blog.youkuaiyun.com/w1520039381/article/details/118905718