学习笔记：数据集 Dataset
代码：datasets.py
import math
import os
import shutil
import random
from pathlib import Path
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ExifTags
from tqdm import tqdm
from build_utils.utils import xyxy2xywh, xywh2xyxy
# Reference page (per its URL) on training with a custom dataset.
help_url = 'https://github.com/ultralytics/yolov3/wiki/Train-Custom-Data'

# Image file extensions (lower-case, with leading dot) accepted by the dataset
# loader; listed paths whose extension is not in here are filtered out.
img_formats = [
    '.bmp',
    '.jpg',
    '.jpeg',
    '.png',
    '.tif',
    '.dng',
]
# Find the numeric EXIF tag id whose name is "Orientation" and leave it bound
# to the module-level name `orientation`; exif_size() uses it as the key into
# an image's EXIF data. Iterating the dict directly visits the same keys as
# .keys(). If no tag matched, `orientation` would end up as the last key seen.
for orientation in ExifTags.TAGS:
    if ExifTags.TAGS[orientation] == "Orientation":
        break
def exif_size(img):
    """
    Return the (w, h) size of a PIL image as it is meant to be displayed.

    The EXIF "Orientation" tag (looked up via the module-level ``orientation``
    tag id) describes how the stored pixels must be transformed for display.
    Orientation values 5, 6, 7 and 8 all include a 90-degree rotation, so for
    those the stored width and height are swapped. Images without EXIF data,
    without an Orientation entry, or with unreadable EXIF data return
    ``img.size`` unchanged.

    :param img: PIL image (anything exposing ``size`` and ``_getexif()``)
    :return: tuple (w, h)
    """
    s = img.size  # (w, h) of the stored pixel data
    try:
        exif = img._getexif()  # expected: mapping tag_id -> value, or None
        rotation = exif[orientation] if exif else None
    except Exception:
        # Best effort: a missing/malformed EXIF block, a missing Orientation
        # entry, or an image type without _getexif() all mean "no rotation
        # information". Catching Exception (not a bare except) keeps
        # KeyboardInterrupt / SystemExit propagating.
        return s
    # 5: transpose, 6: rotate 90 CW, 7: transverse, 8: rotate 90 CCW.
    # Each of these swaps width and height of the displayed image.
    if rotation in (5, 6, 7, 8):
        s = (s[1], s[0])
    return s
class LoadImagesAndLabels(Dataset):
def __init__(self, path, img_size, batch_size, augment=False, hyp=None, rect=False, cache_images=False,
single_cls=False,
pad=0.0, rank=-1):
"""
path: data / my_train_data.txt路径 或 data/my_val_data.txt路径
argment: 训练集设置为True(augment_hsv),验证集设置为False
hpy: 超参数字典
rect: 是否使用rectangular training
cache_images: 是否缓存图片到内存中
"""
try:
path = str(Path(path))
if os.path.isfile(path):
with open(path, 'r') as r:
f = r.read().splitlines()
else:
raise Exception("%s does not exist" % path)
self.img_files = [x for x in f if os.path.splitext(x)[-1].lower() in img_formats] # 是支持的图片格式
self.img_files.sort() # 防止不同系统排序不同,导致shape文件出现差异
except Exception as e:
raise FileNotFoundError("error...")
n = len(self.img_files)
assert n > 0, "not image in self.img_files"
# 按照 batch_size 划分batch
bi = np.floor(np.arange(n) / batch_size).astype(np.int)
nb = bi[-1] + 1 # number of batch
self.n = n # 图像总数
self.batch = bi # 记录哪些图片属于哪个batch
self.img_size = img_size
self.augment = augment # 是否启用augment_hsv
self.hyp = hyp # 超参数字典
self.rect = rect # 是否使用rectangular training
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
# 遍历设置图像对应的label路径
# (./my_yolo_dataset/train/images/2009_004012.jpg) -> (./my_yolo_dataset/train/labels/2009_004012.txt)
self.label_files = [x.replace("images", "labels").replace(os.path.splitext(x)[-1], ".txt")
for x in self.img_files]
# 查看data文件下是否缓存有对应数据集的.shapes文件,里面存储了每张图像的width, height
sp = path.replace(".txt", ".shape")
try:
with open(sp, "r") as f:
s = [x.split() for x in f.read().splitlines()]
assert len(s) == n, "shapefile out of aync" # shape文件中的行数 应该与 图像个数对应
except Exception as e:
# 读取每张图片的size信息
if rank in [-1, 0]:
image_files = tqdm(self.img_files, desc='Reading image shapes')
else:
image_files = self.img_files
s = [exif_size(Image.open(f)) for f in image_files]
np.savetxt(sp, s, fmt="%g") # 将所有图片的shape信息保存在.shape文件中
self.shapes = np.array(s, dtype=np.float64)
if self.rect:
# 如果为ture,训练网络时,会使用类似原图像比例的矩形(让最长边为img_size),而不是img_size x img_size
s = self.shapes # wh
# 计算每个图片的高/宽比
ar = s[:, 1] / s[:, 0] # aspect ratio
# argsort函数返回的是数组值从小到大的索引值
# 按照高宽比例进行排序,这样后面划分的每个batch中的图像就拥有类似的高宽比
irect = ar.argsort()
# 根据排序后的顺序重新设置图像顺序、标签顺序以及shape顺序
self.img_files = [self.img_files[i] for i in irect]
self.label_files = [self.label_files[i] for i in irect]
self.shapes = s[irect] # wh
ar = ar[irect]
# set training image shapes
# 计算每个batch采用的统一尺度
shapes = [[1, 1]] * nb # nb: number of batches
for i in range(nb):