
title: 手把手教你如何用SOLOV2训练自己的数据集
date: 2020-07-24 14:59:11
category: 默认分类
本文介绍 手把手教你如何用SOLOV2训练自己的数据集 <!-- more -->
手把手教你如何用SOLOV2训练自己的数据集
本文由林大佬原创,转载请注明出处,来自腾讯、阿里等一线AI算法工程师组成的QQ交流群欢迎你的加入: 1037662480
最近后台很多小伙伴跟我说能不能出一些实例分割训练的教程, 因为网上很多都是关于加速/部署的, 为了满足大家的愿望, 今天特意给大家带来了现在比较火的SOLO系列算法的训练教程. 确实现在关于实力分割的教程都比较复杂, 这篇文章可以让大家轻松地入门SOTA的实例分割方案, 感兴趣的同学也可以给本文点个赞, 转发一下, 你的支持是我们创作的原始动力!
这篇教程不需要任何神力会员权限, 直接从github clone代码, 先将代码准备好, 就可以开始了:
git clone https://github.com/WXinlong/SOLO
现在网上有好几个不同版本的SOLO开源算法, 但是原作者的这个应该是比较权威的吧, 大家可以用这个版本, 笔者用下来, 这个版本具有几个优点:
- 它基于mmdetection, 模块化, 代码看起来也比较通熟易懂;
- 训练起来没什么坑, 对于没有8卡GPU的同学,用单卡或者两卡也是可以train的, 我们这篇文章会给出大家的具体指导;
但是也有一些缺点:
- 代码pytorch1.5跑不起来,更别说现在最新的pytorch1.7了, 需要我们修改过的代码 (兼容pytorch1.5和mmdetection2.0) 可以移步神力平台获取现成的代码;
- 代码注册新的dataset有点麻烦, 而且我发现(没有确认) 原始的dataset有bug, 相信很多同学在训练自己的数据集的时候会遇到第一个类别被自动忽略的bug, 当然这个bug已经被我们修复了, 详情也可以移步神力平台, 文末会放出我们的代码链接.
当然, 如果你只是训练我们今天的数据集, 那是足够了, 因为今天的数据集的主角很小很小很小, 但是麻雀虽小五脏俱全. 先来看看SOLOv2的分割效果:

这个数据集的名字叫做 坚果数据集.
因为它很小, 所以经常被我用来检测一个算法是不是work, 基本上两分钟就可以出结果. 我也强烈建议大家用起来, 关于数据集的下载, 推荐大家看这篇文章, 这篇文章的博主其实将的比较完全了:
https://www.jianshu.com/p/a94b1629f827www.jianshu.com这里也贴一下下载:
wget https://github.com/Tony607/detectron2_instance_segmentation_demo/releases/download/V0.1/data.zip
数据集的版权credit@Tony607 , 感谢这位作者的工作.
材料都准备好了, 接下来按照步骤来教授大家如何训练吧.
01. SOLO注册自定义数据集
首先, 我们需要注册一个自己的自定义数据集, 在原始的SOLO项目里面, 具体的注册方式为:
a). 在 mmdet/dataset
文件下, 创建一个 coco_toy.py
的文件, 文件中就是我们要注册的数据类.
b). 给数据类添加代码:
import numpy as np
from pycocotools.coco import COCO
from .custom import CustomDataset
from .registry import DATASETS
@DATASETS.register_module
class CocoToyDataset(CustomDataset):
CLASSES = ('date', 'fig', 'hazelnut')
def load_annotations(self, ann_file):
self.coco = COCO(ann_file)
self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
self.cat2label = {
cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.img_ids = self.coco.get_img_ids()
data_infos = []
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
return data_infos
def get_ann_info(self, idx):
img_id = self.data_infos[idx]['id']
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return self._parse_ann_info(self.data_infos[idx], ann_info)
def get_cat_ids(self, idx):
img_id = self.data_infos[idx]['id']
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return [ann['category_id'] for ann in ann_info]
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
for i, img_info in enumerate(self.data_infos):
if self.filter_empty_gt and self.img_ids[i] not in ids_with_ann:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
return valid_inds
def get_subset_by_classes(self):
"""Get img ids that contain any category in class_ids.
Different from the coco.getImgIds(), this function returns the id if
the img contains one of the categories rather than all.
Args:
class_ids (list[int]): list of category ids
Return:
ids (list[int]): integer list of img ids
"""
ids = set()
for i, class_id in enumerate(self.cat_ids):
ids |= set(self.coco.cat_img_map[class_id])
self.img_ids = list(ids)
data_infos = []
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
data_infos.append(info)
return data_infos
def _parse_ann_info(self, img_info, ann_info):
"""Parse bbox and mask annotation.
Args:
ann_info (list[dict]): Annotation info of an image.
with_mask (bool): Whether to parse mask annotations.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,
labels, masks, seg_map. "masks" are raw annotations and not
decoded into binary masks.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_masks_ann = []
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
if ann['area'] <= 0 or w < 1 or h < 1:
continue