- 参考
https://github.com/bubbliiiing/yolo3-pytorch
train.py 流程：加载数据（dataloader.py）→ 正向传播（tiny.py）→ 反向传播（loss.py）
- dataloader.py
import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
class YoloDataset(Dataset):
def __init__(self, annotation_lines, input_shape, train):
super(YoloDataset, self).__init__()
self.annotation_lines = annotation_lines
self.input_shape = input_shape
self.length = len(self.annotation_lines)
self.train = train
def __len__(self):
return self.length
def __getitem__(self, index):
    """Load one sample and convert it to network-ready arrays.

    Args:
        index: sample index; wrapped modulo ``self.length`` before lookup.

    Returns:
        ``(image, box)``:

        ``image`` is a float32 array divided by 255 with its last axis moved
        to the front. This assumes ``get_random_data`` yields an (H, W, C)
        array — TODO confirm for the augmented branch.

        ``box`` is a float32 array. When non-empty, its first four columns
        are converted from pixel corners (x1, y1, x2, y2), as produced by
        ``get_random_data``, to (cx, cy, w, h). The x values are divided by
        ``input_shape[1]`` and the y values by ``input_shape[0]``. Any
        further column (``t[0]`` from the label file, presumably the class
        id) is left untouched.
    """
    line = self.annotation_lines[index % self.length]
    raw_image, raw_box = self.get_random_data(line, self.input_shape[0:2], random=self.train)

    image = np.array(raw_image, dtype=np.float32) / 255.0
    image = np.transpose(image, (2, 0, 1))  # channels-last -> channels-first

    box = np.array(raw_box, dtype=np.float32)
    if len(box) == 0:
        return image, box

    in_h, in_w = self.input_shape[0], self.input_shape[1]
    # Pixel corner coordinates -> fractions of the network input size.
    box[:, [0, 2]] = box[:, [0, 2]] / in_w
    box[:, [1, 3]] = box[:, [1, 3]] / in_h
    # Corner form -> center/size form.
    size = box[:, 2:4] - box[:, 0:2]
    box[:, 0:2] = box[:, 0:2] + size / 2
    box[:, 2:4] = size
    return image, box
def rand(self, a=0, b=1):
    """Draw one value from numpy's global RNG, mapped linearly onto [a, b).

    ``np.random.rand()`` yields a float in [0, 1), which is scaled by
    ``b - a`` and shifted by ``a``. Instance state is not read.
    """
    span = b - a
    return a + span * np.random.rand()
def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.02, sat=1.5, val=1.5, random=True):
line = annotation_line.split()
label_line = line[0][:-4]+'.txt'
boxes = []
for lin in open(label_line):
t = lin.split()
boxes.append([t[1],t[2],t[3],t[4],t[0]])
box = np.array(boxes, dtype=np.float32)
image = Image.open(line[0])
iw, ih = image.size
h, w = input_shape
if len(box) > 0:
box[:, [0,2]] = box[:, [0,2]] * iw
box[:, [1,3]] = box[:, [1,3]] * ih
box[:, 0:2] = box[:, 0:2] - box[:, 2:4] / 2
box[:, 2:4] = box[:, 0:2] + box[:, 2:4]
if not random:
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
dx = (w-nw)//2
dy = (h-nh)//2
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', (w,h), (128,128,128))
new_image.paste(image, (dx, dy))
image_data = np.array(new_image, np.float32)
if len(box)>0:
np.random.shuffle(box)
box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
box[:, 0:2][box[:, 0:2]<0] = 0
box[:, 2][box[:, 2]>w] = w
box[:, 3][box[:, 3]>h] = h
box_w = box[:, 2] - box[:, 0]
box_h = box[:, 3] - box[