How to train eriklindernoren's YOLOv3 on your own dataset
The YOLOv3 repo does not explain clearly how to train on your own data, so I am writing this post.
Of course, you need to provide your own images and XML annotation files.
References:
1. PyTorch version of YOLOv3: https://github.com/eriklindernoren/PyTorch-YOLOv3
2. GiantPandaCV's e-book: Learning YOLOv3 from Scratch
Link: https://pan.baidu.com/s/14y4tYW_5szm8Y_8XTvLgpg  (extraction code: xflx)
1. Folder layout under data
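The layout was originally shown as a screenshot; reconstructed from the paths that the script in step 3 reads and writes (so the folder names here are assumptions taken from that script), it looks roughly like this:

data/custom/
    images/           your jpg images
    annotation/       the matching VOC-style xml files
    Main/             train/val/test/trainval id lists generated by the script below
    labels/           normalized label txt files generated by the script below
    generate_path/    per-split image path lists generated by the script below
    classes.names     one class name per line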
2. Modify the parameters
1. Put the images into image and the xml files into annotation
2. Change classes.names to your own classes, one name per line; I am using the VOC2007 data
3. Modify the values inside the yolov3_custom.cfg file you are using
Mainly classes in the three yolo layers, and filters in the convolutional layer right before each of them (taking VOC2007 as an example:
20 classes, so filters = 3*(20+5) = 75). See the cfg snippet after this list.
4. Modify custom.data under the config folder
One small gotcha here: change valid.txt to val.txt (an example custom.data is shown after this list).
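For reference, here is the relevant part of yolov3_custom.cfg for 20 classes: in each of the three [yolo] blocks set classes, and in the convolutional layer immediately above each of them set filters = 3*(classes+5) = 75. This is only a sketch; the anchors and mask values shown are the stock yolov3.cfg defaults and may differ in your file.

[convolutional]
size=1
stride=1
pad=1
filters=75
activation=linear

[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=20
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

And config/custom.data, using the key names from the repo; point train/valid at wherever your generated train.txt and val.txt end up (here I assume they are copied to data/custom/):

classes=20
train=data/custom/train.txt
valid=data/custom/val.txt
names=data/custom/classes.names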
3. Use the generate script below to produce the various txt files (place the script in the custom directory)
It is a bit slow to run, so be patient!
from __future__ import division
import os
import random
import xml.etree.ElementTree as ET
from os import getcwd
from os.path import join

save_base_path = "Main"
xmlfilepath = r"annotation"

# Generate the split files in Main; each line is the id (file name without extension) of an xml file
# trainval_percent: trainval takes 90% of all data
# train_percent: train takes 90% of trainval
trainval_percent = 0.9
train_percent = 0.9
os.makedirs(save_base_path, exist_ok=True)
total_xml = os.listdir(xmlfilepath)
num = len(total_xml)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)
print("train and val size", tv)
print("train size", tr)
ftrainval = open(join(save_base_path, "trainval.txt"), "w")
ftest = open(join(save_base_path, "test.txt"), "w")
ftrain = open(join(save_base_path, "train.txt"), "w")
fval = open(join(save_base_path, "val.txt"), "w")
for i in indices:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
# Generate the per-split image path lists and, under labels/, the normalized
# YOLO annotations (class index, x, y, w, h)
sets = ["train", "test", "val"]
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

# Convert a VOC box (xmin, xmax, ymin, ymax) to normalized (x_center, y_center, w, h)
def convert(size, box):
    dw = 1.0 / size[0]
    dh = 1.0 / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)

def convert_annotation(image_id):
    # The dataset is placed in the current directory
    in_file = open("./annotation/%s.xml" % (image_id))
    outfile = open("./labels/%s.txt" % (image_id), "w")
    tree = ET.parse(in_file)
    # Get the root element
    root = tree.getroot()
    size = root.find("size")
    w = int(size.find("width").text)
    h = int(size.find("height").text)
    for obj in root.iter("object"):
        difficult = obj.find("difficult").text
        cls = obj.find("name").text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find("bndbox")
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find("xmax").text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        outfile.write(str(cls_id) + " " + " ".join(str(i) for i in bb) + "\n")
    in_file.close()
    outfile.close()

wd = getcwd()
os.makedirs("labels", exist_ok=True)
os.makedirs("generate_path", exist_ok=True)
for image_set in sets:
    image_ids = open("./Main/%s.txt" % (image_set)).read().strip().split()
    list_file = open("./generate_path/%s.txt" % (image_set), "w")
    for image_id in image_ids:
        list_file.write("data/custom/images/%s.jpg\n" % (image_id))
        convert_annotation(image_id)
    list_file.close()
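After the script finishes, Main/ holds the id splits, labels/ holds one txt per image whose lines have the form class_id x_center y_center width height (all normalized to 0-1; a line might read 14 0.5 0.43 0.2 0.62, values illustrative only), and generate_path/ holds the per-split lists of image paths written as data/custom/images/<id>.jpg, so make sure your image folder name matches that path.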
4. Modify your own default parameters in train.py
(details omitted)
It is best to keep the batch size small; 1-4 is fine.
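For reference, a launch command along the lines of the repo's README (flag names should be double-checked against the argparse block in train.py, and the cfg/weights paths adjusted to your own files) would be:

python3 train.py --model_def config/yolov3_custom.cfg --data_config config/custom.data --pretrained_weights weights/darknet53.conv.74 --batch_size 2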