Installing mmdetection on Win10 and training on a custom dataset
1. Install mmdetection
Reference: https://www.huaweicloud.com/articles/b0247a0d742f451efbf435acfc79a40d.html
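I assume Anaconda, CUDA, cuDNN, PyTorch and the other prerequisites are already installed. mmdetection comes in two parts, mmcv and mmdet. mmcv can be installed directly with pip. Note that mmdet 2.x needs mmcv-full (the build with the compiled CUDA ops), not the lightweight mmcv package, and the mmcv-full build has to match your CUDA and PyTorch versions. mmdet is a bit more involved. I believe I installed it from a git clone (search online for how to install from source); if that doesn't work, try pip install mmdet. Sorry, I honestly don't remember the details, it was mostly trial and error. After installing you may get an error, which as far as I remember was fixed by upgrading mmcv or mmdet. This is most likely the check mmdet runs on import that the installed mmcv version lies within its supported range. (Next time I'll definitely take notes.) Anyway, mmdetection can be installed on Win10, even though the official site gives no tutorial for it.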
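Because the installation above was mostly trial and error, a quick way to confirm the pieces fit together is to import each package and print its version. This is a minimal sketch, assuming the standard import names torch, mmcv and mmdet (mmcv-full is also imported as mmcv). The helper report_versions is hypothetical. Importing mmdet runs its mmcv version check, so an incompatible pair shows up as "import failed" together with mmdet's own message.

```python
import importlib


def report_versions(packages=("torch", "mmcv", "mmdet")):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
        except AssertionError as err:
            # mmdet asserts on import that the installed mmcv version is compatible
            versions[name] = f"import failed: {err}"
        else:
            versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version if version else 'NOT INSTALLED'}")
    try:
        import torch
        print("CUDA available:", torch.cuda.is_available())
    except ImportError:
        pass
```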
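2. Convert your own dataset to COCO format
I'm not great at explaining things, so here is the code directly. Reference: https://blog.youkuaiyun.com/qq_15969343/article/details/80848175, with a few small changes.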
import json
import os
import shutil

import cv2

dataset = {'categories': [], 'images': [], 'annotations': []}
# Root path. It contains images (image folder), annos.txt (bbox annotations) and
# classes.txt (class labels). Output goes to root_path/coco: train2017, val2017 and
# annotations (created if missing), the same layout as mmdetection's data/coco folder.
root_path = r'D:\PycharmProjects\GitHubProjects\yolov5-master\PigDetection\dataprocess\新建文件夹'
# File names are listed from root_path/images, but the image files themselves are
# read and copied from img_path, so both folders must hold the same file names
img_path = r'D:\PycharmProjects\GitHubProjects\yolov5-master\PigDetection\train_img_bbox'
# Which split to build: 'instances_train2017' or 'instances_val2017'
phase = 'instances_val2017'
# Train/val boundary: images with index <= split go to train, the rest to val
split = 399

# Read the class labels (whitespace-separated names)
with open(os.path.join(root_path, 'classes.txt'), encoding='utf-8') as f:
    classes = f.read().strip().split()
# Map class names to 1-based ids; class ids in annos.txt must use the same numbering
for i, cls in enumerate(classes, 1):
    dataset['categories'].append({'supercategory': 'mark', 'id': i, 'name': cls})
# dataset['categories'].append({'supercategory': 'mark', 'id': 1, 'name': 'pig'})

# List the image names; sorted so the train/val split is identical on every run
indexes = sorted(os.listdir(os.path.join(root_path, 'images')))
# Keep the images of this split and copy them to coco/train2017 or coco/val2017
if phase == 'instances_train2017':
    indexes = [name for i, name in enumerate(indexes) if i <= split]
    out_img_dir = os.path.join(root_path, 'coco', 'train2017')
elif phase == 'instances_val2017':
    indexes = [name for i, name in enumerate(indexes) if i > split]
    out_img_dir = os.path.join(root_path, 'coco', 'val2017')
else:
    raise ValueError('unknown phase: {}'.format(phase))
os.makedirs(out_img_dir, exist_ok=True)
for name in indexes:
    shutil.copy(os.path.join(img_path, name), out_img_dir)

# Read the bbox annotations, one box per line:
#   <file_name> <class_id> <x_min> <y_min> <x_max> <y_max>
# grouped by file name so each image only looks at its own boxes
annos_by_image = {}
with open(os.path.join(root_path, 'annos.txt'), encoding='utf-8') as tr:
    for line in tr:
        parts = line.strip().split()
        if parts:
            annos_by_image.setdefault(parts[0], []).append(parts)

# Ids start at 1: pycocotools stores the matched gt id and treats 0 as "unmatched"
ann_id = 1
for k, index in enumerate(indexes, 1):
    # Read the image with OpenCV to get its width and height
    img = cv2.imread(os.path.join(img_path, index))
    if img is None:
        raise FileNotFoundError(os.path.join(img_path, index))
    height, width = img.shape[:2]
    # Add the image info to the dataset
    dataset['images'].append({'file_name': index,
                              'id': k,
                              'width': width,
                              'height': height})
    for parts in annos_by_image.get(index, []):
        cls_id = int(parts[1])
        x1, y1, x2, y2 = map(float, parts[2:6])  # x_min, y_min, x_max, y_max
        # COCO boxes are [x_min, y_min, box_width, box_height]
        w = max(0.0, x2 - x1)
        h = max(0.0, y2 - y1)
        dataset['annotations'].append({
            'area': w * h,
            'bbox': [x1, y1, w, h],
            'category_id': cls_id,
            'id': ann_id,
            'image_id': k,
            'iscrowd': 0,
            # Mask: the rectangle's four corners, clockwise from the top-left
            'segmentation': [[x1, y1, x2, y1, x2, y2, x1, y2]]
        })
        ann_id += 1

# Save the json to coco/annotations
folder = os.path.join(root_path, 'coco', 'annotations')
os.makedirs(folder, exist_ok=True)
json_name = os.path.join(folder, '{}.json'.format(phase))
with open(json_name, 'w', encoding='utf-8') as f:
    json.dump(dataset, f)
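Before training, it is worth checking the generated json. Below is a minimal sketch of a consistency check using only the standard library. check_coco_json is a hypothetical helper, and the checks it runs (unique ids, known image and category ids, non-empty boxes that stay inside the image, no annotation id 0) are my own choice of what to look for, not a list mmdetection enforces in this exact form.

```python
import json


def check_coco_json(data):
    """Return a list of problems found in a COCO-style detection dict (empty = OK)."""
    problems = []
    images = {img["id"]: img for img in data["images"]}
    cat_ids = {cat["id"] for cat in data["categories"]}
    if len(images) != len(data["images"]):
        problems.append("duplicate image ids")
    ann_ids = [ann["id"] for ann in data["annotations"]]
    if len(set(ann_ids)) != len(ann_ids):
        problems.append("duplicate annotation ids")
    if 0 in ann_ids:
        problems.append("annotation id 0 (pycocotools treats it as 'unmatched')")
    for ann in data["annotations"]:
        img = images.get(ann["image_id"])
        if img is None:
            problems.append(f"annotation {ann['id']}: unknown image_id {ann['image_id']}")
            continue
        if ann["category_id"] not in cat_ids:
            problems.append(f"annotation {ann['id']}: unknown category_id {ann['category_id']}")
        x, y, w, h = ann["bbox"]  # COCO format: [x_min, y_min, width, height]
        if w <= 0 or h <= 0:
            problems.append(f"annotation {ann['id']}: empty box")
        elif x < 0 or y < 0 or x + w > img["width"] or y + h > img["height"]:
            problems.append(f"annotation {ann['id']}: box outside the image")
    return problems
```

For example, with open(json_name, encoding='utf-8') as f: print(check_coco_json(json.load(f)) or 'OK'), where json_name is the file written by the script above.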
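3. Train on your own dataset
First download the mmdetection code from the official repository: https://github.com/open-mmlab/mmdetection. The official tutorial seems fairly detailed, so here I only cover the pitfalls in what has to be changed, using configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py as the example.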
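1. Modify the class names. File: C:\ProgramData\Anaconda3\Lib\site-packages\mmdet-2.14.0-py3.8.egg\mmdet\core\evaluation\class_names.py (I first edited the mmdet package inside the downloaded mmdetection repo and kept getting errors, presumably because Python imports the installed copy in site-packages, not the copy in the repo)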
def coco_classes():
    # return [
    #     'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    #     'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
    #     'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
    #     'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
    #     'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    #     'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
    #     'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
    #     'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    #     'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
    #     'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
    #     'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
    #     'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
    #     'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
    # ]
    return ['pig']
2. Modify coco.py. File: C:\ProgramData\Anaconda3\Lib\site-packages\mmdet-2.14.0-py3.8.egg\mmdet\datasets\coco.py
@DATASETS.register_module()
class CocoDataset(CustomDataset):
    # CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #            'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #            'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
    CLASSES = ('pig',)
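If there is only one class you must add the trailing comma: without it, ('pig') is just the string 'pig' rather than a one-element tuple, and you get an error telling you to add the comma.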
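Why the comma matters can be seen in plain Python, independent of mmdetection. In the snippet below the variable names are only for illustration.

```python
# Parentheses alone do not make a tuple; the trailing comma does.
classes_wrong = ('pig')    # just the string 'pig'
classes_right = ('pig',)   # a tuple with one element

print(type(classes_wrong).__name__, len(classes_wrong))   # str 3
print(type(classes_right).__name__, len(classes_right))   # tuple 1

# Iterating the string yields characters, which would look like 3 classes
print(list(classes_wrong))   # ['p', 'i', 'g']
print(list(classes_right))   # ['pig']
```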
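3. Modify faster_rcnn_r50_fpn.py. My config is faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py; if you open it under configs you can see the path of the model config it inherits from. Open that faster_rcnn_r50_fpn.py and change every num_classes to your number of classes; in this file there is only one to change (in the bbox_head of roi_head). In mmdetection 2.x num_classes does not count the background, so for the single class pig it is 1. Path: D:\PycharmProjects\GitHubProjects\mmdetection-master\mmdetection-master\tools\configs\_base_\models\faster_rcnn_r50_fpn.py. A note on why configs sits inside the tools folder: on Win10 I got errors unless it was there, so I simply copied the whole configs folder into tools. The likely cause is that the config path passed to train.py is relative and is resolved against the current working directory, which was tools when I ran train.py from there.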
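Editing site-packages and the _base_ files works, but the change then affects every config that uses them. A commonly used alternative in mmdetection 2.x is a small config of your own that inherits the COCO config and overrides only the class names, num_classes and data paths. The sketch below assumes the file is saved next to the original as configs/faster_rcnn/faster_rcnn_r50_fpn_1x_pig.py and that the coco folder produced in step 2 has been copied to data/pig/; both names are hypothetical. Passing classes to the dataset this way should make the edits to class_names.py and coco.py unnecessary, but I have not verified that against mmdet 2.14 specifically.

```python
# configs/faster_rcnn/faster_rcnn_r50_fpn_1x_pig.py  (hypothetical file name)
# Inherit everything from the COCO Faster R-CNN config (path relative to this file)
_base_ = './faster_rcnn_r50_fpn_1x_coco.py'

# One foreground class; in mmdetection 2.x num_classes does not count background
model = dict(roi_head=dict(bbox_head=dict(num_classes=1)))

# Class names must match the 'name' fields of the categories in the json
classes = ('pig',)
data_root = 'data/pig/'  # hypothetical: where the generated coco folder was copied
data = dict(
    train=dict(
        classes=classes,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/'),
    val=dict(
        classes=classes,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/'),
    test=dict(
        classes=classes,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/'))
```

Only the keys written here are overridden; everything else (pipelines, optimizer, schedule) is merged in from the base config when mmcv loads the file.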
4. Run train.py in the tools folder with the config argument configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py. Success! (This write-up is a bit rough; I'll fill in more details over time.)
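An alternative to copying configs into tools is to launch train.py with the repository root as the working directory, so that relative paths (the config path, and data paths such as data/coco/ inside the config) resolve there. Below is a minimal sketch of that idea. build_train_cmd is a hypothetical helper. The repository path is the one used above and will differ on your machine. --work-dir is the option of tools/train.py that sets where logs and checkpoints are written.

```python
import os
import subprocess
import sys


def build_train_cmd(repo_root, config, work_dir=None):
    """Build the command line for mmdetection's tools/train.py.

    `config` is taken relative to repo_root, so the result does not depend on
    which directory the IDE happens to use as the working directory.
    """
    cmd = [sys.executable,
           os.path.join(repo_root, 'tools', 'train.py'),
           os.path.join(repo_root, config)]
    if work_dir is not None:
        cmd += ['--work-dir', work_dir]  # where logs and checkpoints are written
    return cmd


if __name__ == '__main__':
    # Repository location from the paths above; adjust to your own checkout
    repo = r'D:\PycharmProjects\GitHubProjects\mmdetection-master\mmdetection-master'
    config = os.path.join('configs', 'faster_rcnn', 'faster_rcnn_r50_fpn_1x_coco.py')
    if os.path.isdir(repo):
        # Run from the repo root so relative paths inside the config resolve there
        subprocess.run(build_train_cmd(repo, config), cwd=repo, check=True)
    else:
        print('repository not found:', repo)
```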