图像加载
# Custom dataset: subclasses Dataset (the overridden accessor methods are not part of this excerpt)
class LoadImagesAndLabels(Dataset):
    """YOLOv5 train_loader/val_loader dataset: loads images and labels for training and validation.

    Only the first part of ``__init__`` (option storage and image-path discovery) is shown here.
    """
    cache_version = 0.6  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 path,
                 img_size=640,
                 batch_size=16,
                 augment=False,
                 hyp=None,
                 rect=False,
                 image_weights=False,
                 cache_images=False,
                 single_cls=False,
                 stride=32,
                 pad=0.0,
                 min_items=0,
                 prefix=''):
        """Store the loader options and collect the image file paths.

        Args:
            path: A single path or a list of paths. Each entry is either a directory (searched
                recursively) or a text file listing one image path per line.
            img_size: Input image resolution.
            augment: Enables augmentation (and the Albumentations helper).
            hyp: Hyperparameters, stored as-is.
            rect: Rectangular training; forced to False when ``image_weights`` is set.
            image_weights: Image sampling-weights option.
            stride: Model downsampling stride.
            prefix: String prepended to error messages.
            batch_size, cache_images, single_cls, pad, min_items: Accepted but not used in the
                portion of ``__init__`` shown here.

        Raises:
            Exception: If a path does not exist or no image files are found; the original error
                is chained as the cause.
        """
        self.img_size = img_size  # input image resolution
        self.augment = augment  # data augmentation switch
        self.hyp = hyp  # hyperparameters
        self.image_weights = image_weights  # image sampling-weights option
        self.rect = False if image_weights else rect  # rectangular training (off when image_weights is set)
        # Mosaic augmentation: only when augmenting and not in rectangular mode
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        # Border values used by the mosaic augmentation
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride  # model downsampling stride
        self.path = path
        self.albumentations = Albumentations(size=img_size) if augment else None

        # Collect image paths
        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                # Each entry is either a txt file listing image paths or a directory containing images;
                # pathlib.Path keeps the handling independent of the operating system
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    # '**' recurses into every sub-directory, '*.*' matches any file name with an extension
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    # The file is expected to hold one image path per line
                    with open(p) as fh:
                        lines = fh.read().strip().splitlines()
                    # Parent directory of the list file, terminated by the OS path separator
                    parent = str(p.parent) + os.sep
                    # A leading './' is replaced by the list file's parent directory; other lines are used as-is
                    f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in lines]  # to global path
                    # f += [p.parent / x.lstrip(os.sep) for x in t]  # to global path (pathlib)
                else:
                    # p is neither a file nor a directory
                    raise FileNotFoundError(f'{prefix}{p} does not exist')
            # Keep only files whose extension is in IMG_FORMATS, normalise separators, and sort
            self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            # Explicit raise rather than `assert`, so the check is not stripped under `python -O`
            if not self.im_files:
                raise FileNotFoundError(f'{prefix}No images found')
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
预处理
缩放与填充:使用 letterbox 函数将图像等比例缩放到指定尺寸(保持原始宽高比),剩余的边界用固定颜色(默认灰色 114)填充。
颜色通道转换:将图像从 BGR 格式转换为 RGB 格式,以适应模型的输入要求。
归一化:将像素值缩放到 [0, 1] 范围,通常通过除以 255 实现。
维度变换:将图像维度从 HWC(高度、宽度、通道)转换为 CHW(通道、高度、宽度),并确保数据的连续性。
数据增强(可选):如果提供了自定义的转换函数,可以应用数据增强技术,如随机裁剪、翻转等,增强训练数据的多样性。
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resize and pad an image while meeting stride-multiple constraints.

    Args:
        im: Image array; ``im.shape[:2]`` is read as (height, width).
        new_shape: Target size, either an int (used for both sides) or (height, width).
        color: Border value handed to ``cv2.copyMakeBorder``.
        auto: If True, pad only by the remainder modulo ``stride`` (minimum rectangle).
        scaleFill: If True (and ``auto`` is False), stretch to ``new_shape`` with no padding.
        scaleup: If False, the scale ratio is capped at 1.0 (only scale down).
        stride: Modulus used for the ``auto`` padding.

    Returns:
        Tuple ``(im, ratio, (dw, dh))``: the padded image, the (width, height) scale ratios,
        and the per-side padding (total padding divided by two).
    """
    src_hw = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Scale ratio (new / old): the smaller of the two per-axis ratios
    gain = min(dst_h / src_hw[0], dst_w / src_hw[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        gain = min(gain, 1.0)

    # Size after scaling (width, height) and the padding left over
    ratio = (gain, gain)  # width, height ratios
    resized_wh = (int(round(src_hw[1] * gain)), int(round(src_hw[0] * gain)))
    dw = dst_w - resized_wh[0]  # width padding
    dh = dst_h - resized_wh[1]  # height padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        resized_wh = (dst_w, dst_h)
        ratio = (dst_w / src_hw[1], dst_h / src_hw[0])  # width, height ratios

    # Divide the padding between the two sides
    dw, dh = dw / 2, dh / 2

    if src_hw[::-1] != resized_wh:  # resize only when the size actually changes
        im = cv2.resize(im, resized_wh, interpolation=cv2.INTER_LINEAR)

    # The -0.1 / +0.1 offsets split an odd total padding as floor / ceil between the two sides
    pad_top = int(round(dh - 0.1))
    pad_bottom = int(round(dh + 0.1))
    pad_left = int(round(dw - 0.1))
    pad_right = int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT,
                            value=color)  # add border
    return im, ratio, (dw, dh)
训练
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
callbacks.run('on_train_epoch_start')
model.train()
# Update image weights (optional, single-GPU only)
if opt.image_weights:
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
# Update mosaic border (optional)
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
mloss = torch.zeros(3, device=device) # mean losses
if RANK != -1:
train_loader.sampler.set_epoch(epoch)
pbar = enumerate(train_loader)
LOGGER.info(('\n' + '%11s' * 7) % ('Epoch', 'GPU_mem', 'box_loss', 'obj_loss', 'cls_loss', 'Instances', 'Size'))
if RANK in {-1, 0}:
pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT) # progress bar
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
callbacks.run('on_train_batch_start')
ni = i + nb * epoch # number integrated batches (since train start)
imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
# Warmup
if ni <= nw:
xi = [0, nw] # x interp
# compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
for j, x in enumerate(optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
if 'momentum' in x:
x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
# Multi-scale
if opt.multi_scale:
sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
# Forward
with torch.cuda.amp.autocast(amp):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1:
loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
if opt.quad:
loss *= 4.
# Backward
scaler.scale(loss).backward()
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= accumulate:
scaler.unscale_(optimizer) # unscale gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
if ema:
ema.update(model)
last_opt_step = ni
# Log
if RANK in {-1, 0}:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
(f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths, list(mloss))
if callbacks.stop_training:
return
# end batch ------------------------------------------------------------------------------------------------
# Scheduler
lr = [x['lr'] for x in optimizer.param_groups] # for loggers
scheduler.step()
if RANK in {-1, 0}:
# mAP
callbacks.run('on_train_epoch_end', epoch=epoch)
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
if not noval or final_epoch: # Calculate mAP
results, maps, _ = validate.run(data_dict,
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz,
half=amp,
model=ema.ema,
single_cls=single_cls,
dataloader=val_loader,
save_dir=save_dir,
plots=False,
callbacks=callbacks,
compute_loss=compute_loss)
# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
stop = stopper(epoch=epoch, fitness=fi) # early stop check
if fi > best_fitness:
best_fitness = fi
log_vals = list(mloss) + list(results) + lr
callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
# Save model
if (not nosave) or (final_epoch and not evolve): # if save
ckpt = {
'epoch': epoch,
'best_fitness': best_fitness,
'model': deepcopy(de_parallel(model)).half(),
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
'optimizer': optimizer.state_dict(),
'opt': vars(opt),
'git': GIT_INFO, # {remote, branch, commit} if a git repo
'date': datetime.now().isoformat()}
# Save last, best and delete
torch.save(ckpt, last)
if best_fitness == fi:
torch.save(ckpt, best)
if opt.save_period > 0 and epoch % opt.save_period == 0:
torch.save(ckpt, w / f'epoch{epoch}.pt')
del ckpt
callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
# EarlyStopping
if RANK != -1: # if DDP training
broadcast_list = [stop if RANK == 0 else None]
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
if RANK != 0:
stop = broadcast_list[0]
if stop:
break # must break all DDP ranks
评价
# Metrics: for every image in the batch, match its predictions against its labels and record the result
for si, pred in enumerate(preds):
    labels = targets[targets[:, 0] == si, 1:]  # rows of `targets` whose first column equals si, that column dropped
    nl = labels.shape[0]  # number of labels
    npr = pred.shape[0]  # number of predictions
    shape_info = shapes[si]
    path = Path(paths[si])
    shape = shape_info[0]
    correct = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # init
    seen += 1
    tcls = labels[:, 0]  # target classes

    if npr == 0:
        # Nothing predicted for this image: record empty conf/pcls entries, but only when labels exist
        if nl:
            stats.append((correct, *torch.zeros((2, 0), device=device), tcls))
            if plots:
                confusion_matrix.process_batch(detections=None, labels=tcls)
        continue

    # Predictions
    if single_cls:
        pred[:, 5] = 0
    im_shape = im[si].shape[1:]
    predn = pred.clone()
    scale_boxes(im_shape, predn[:, :4], shape, shape_info[1])  # native-space pred

    # Evaluate
    if nl:
        tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
        scale_boxes(im_shape, tbox, shape, shape_info[1])  # native-space labels
        labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
        correct = process_batch(predn, labelsn, iouv)
        if plots:
            confusion_matrix.process_batch(predn, labelsn)
    stats.append((correct, pred[:, 4], pred[:, 5], tcls))  # (correct, conf, pcls, tcls)