从coco数据集中提取子集-yolov5使用coco数据集子集训练测试
本文使用coco2017,暂未验证其他数据集
1. python提取子集
文件名 get_coco_datasets.py
#注意这里测试的是coco2017
import os
import shutil
src_dir = '../datasets/coco' #使用相对路径,需要根据自己的路径做修改
dst_dir = '../datasets/cat_dog_coco' #指定子集保存位置
#src_dir : coco数据集的路径,只需要指定到根目录
#dst_dir : 子集保存位置
#data_type : 挑选训练集的数据还是验证集,这里可选train2017 val2017
#sub_list : 将自己需要的子集的names数据新建一个list传入,main有使用例子
def get_coco_dataset(src_dir,dst_dir,data_type = 'train2017',sub_list = []):
if not os.path.exists(dst_dir): os.makedirs(dst_dir)
if not os.path.exists('{}/images/train2017'.format(dst_dir)):os.makedirs('{}/images/train2017'.format(dst_dir))
if not os.path.exists('{}/images/val2017'.format(dst_dir)): os.makedirs('{}/images/val2017'.format(dst_dir))
if not os.path.exists('{}/labels/train2017'.format(dst_dir)): os.makedirs('{}/labels/train2017'.format(dst_dir))
if not os.path.exists('{}/labels/val2017'.format(dst_dir)): os.makedirs('{}/labels/val2017'.format(dst_dir))
file_name = os.listdir(os.path.join(src_dir, 'images', data_type))
name_list = list()
label_list = list()
for name in file_name:
img_path = os.path.join(src_dir, 'images', data_type, name)
label_path = os.path.join(src_dir, 'labels', data_type, name)
try:
with open(label_path.replace('jpg', 'txt')) as f:
content = f.readlines()
iscatdog = 0
for line in content:
old_line = line
old_line = old_line.split()
line = line.strip()
data = line.split()
if int(data[0]) in sub_list:
iscatdog = 1
print(name)
shutil.copy(img_path, os.path.join(dst_dir,'images',data_type, name))
name_list.append(os.path.join('./images',data_type, name) + "\n")
old_line[0] = str(sub_list.index(int(old_line[0])))
new_line = ' '.join(old_line)
label_list.append(new_line)
if iscatdog == 1:
try:
with open(os.path.join(os.path.join(dst_dir,'labels',data_type), name.replace('jpg', 'txt')), 'x') as sf:
sf.writelines(label_list)
label_list.clear()
except FileExistsError:
pass
except FileNotFoundError:
pass
try:
with open(os.path.join(dst_dir, data_type+'.txt'), 'x') as va:
va.writelines(name_list)
name_list.clear()
except FileExistsError:
pass
print('success! save path to '+data_type+'.txt')
if __name__ == '__main__':
cat_dog_list = [15,16] #通过查看coco训练集的yaml文件得到15代表cat 16代表dog,新建一个list传入get_coco_dataset函数即可
get_coco_dataset(src_dir,dst_dir,'train2017',cat_dog_list )
get_coco_dataset(src_dir,dst_dir,'val2017',cat_dog_list )
2. 指定data位置yaml
在yolov5源码中,找到data文件夹,新建一个yaml文件,文件名cat_dog.yaml,填入如下内容
# ├── yolov5
# └── datasets
# └── coco ← downloads here (20.1 GB)
# └── cat_dog_coco
# └── train2017.txt
# └── val2017.txt
path: ../datasets/cat_dog_coco# dataset root dir
train: train2017.txt # train images
val: val2017.txt # val images
# Classes
names:
0: cat
1: dog
3. 使用yolov5训练挑选出来的子集
文件位置说明
yolov5–>get_coco_datasets.py
yolov5–>data–>cat_dog.yaml
修改模型训练数量
这里选用yolov5x作为测试模型,需要找到yolov5–>models–>yolov5x.yaml文件
将nc修改为你子集的class数量,这里为2
训练代码
python .\train.py --data cat_dog.yaml --weight yolov5x.pt --img 640 --batch-size 32 --cfg .\models\yolov5x.yaml