用SSD-PyTorch训练自己的数据集-完整教程
本机系统环境
- Ubuntu 18.04
- Miniconda3 4.6.14
- python 3.7.3
- PyTorch 1.0.1
SSD-PyTorch安装与测试
复制源码
GitHub: SSD: Single Shot MultiBox Object Detector, in PyTorch
git clone https://github.com/amdegroot/ssd.pytorch.git
下载数据集
COCO
sh data/scripts/COCO2014.sh
VOC 2007 & 2012
sh data/scripts/VOC2007.sh
sh data/scripts/VOC2012.sh
运行脚本后会把数据集存放在~/data中
修改源码
由于PyTorch版本差异,需要对源码几处地方进行修改(行数可能略有差异):
附上一个比较详细的帖子:目标检测——SSD源码解读
- 将所有 “.data[0]” 修改为 “.item()”
- multibox_loss.py:
- 调换第97,98行:
loss_c = loss_c.view(num, -1)
loss_c[pos] = 0 # filter out pos boxes for now
- 修改第114行为:
N = num_pos.data.sum().double()
loss_l = loss_l.double()
loss_c = loss_c.double()
- 修改第90行为:
# loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
- 修改第112行为:
# loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)
loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
下载权重文件
# 预训练权重
mkdir weights
cd weights
wget https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
# 训练好的VOC权重
wget https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
测试
- 检测:
用jupyter notebook运行demo/demo.ipynb - 训练:
python train.py
默认训练数据集是VOC,需提前下载好数据集和预训练权重 - eval:
python eval.py
训练自己的数据集
准备数据集
数据集建议按照VOC或COCO格式进行打包整理,本人使用的是VOC格式,VOC数据集框架如下图所示(图源见水印):
只需要Annotations,ImageSets/Main,JPEGImages这三个文件夹即可,train.txt, val.txt, trainval.txt中分别是训练集,验证集,所有图片的文件名(不需要后缀)。
源码修改
- 修改数据集相关源码
- 复制一份data/voc0712.py,将VOC_CLASSES和VOC_ROOT分别修改为自己的类别名和数据集路径即可。
- 修改数据配置相关源码
- 修改data/config.py,复制一份voc相关设置,修改其中的类别数,steps等数据
- 注意:num_classes应设置为实际类别数+1
训练与测试
- 训练
python train.py
- 测试
python eval.py