Source: datasets.py
- Get the image path;
def __getitem__(self, index):
    img_path = self.img_files[index % len(self.img_files)].rstrip()
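The lookup above does two things: `index % len(self.img_files)` wraps any index back into the list, and `.rstrip()` removes the trailing newline left over from reading the paths line by line. Both can be illustrated without a Dataset class. This is a minimal sketch. `img_files` and `get_img_path` are hypothetical names, and the list of paths with trailing `"\n"` is an assumed example.

```python
# Hypothetical list of paths as read line-by-line from a list file (trailing "\n" kept)
img_files = ["images/a.jpg\n", "images/b.jpg\n", "images/c.jpg\n"]

def get_img_path(index: int) -> str:
    """Mimic the lookup in __getitem__: wrap the index, strip trailing whitespace."""
    return img_files[index % len(img_files)].rstrip()

print(get_img_path(1))  # images/b.jpg
print(get_img_path(4))  # 4 % 3 == 1, so also images/b.jpg
```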
- Read the image, convert it to RGB, and turn it into a tensor;
    # ToTensor: PIL image (H, W, C) in [0, 255] -> float tensor (C, H, W) in [0.0, 1.0]
    img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))
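For a uint8 RGB image, `transforms.ToTensor()` reorders the axes from (H, W, C) to (C, H, W) and rescales values from [0, 255] to [0.0, 1.0]. The sketch below reproduces that conversion by hand on a NumPy array. It is an illustration only: `to_tensor_like` is a hypothetical helper, not the torchvision implementation.

```python
import numpy as np
import torch

def to_tensor_like(arr: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: (H, W, 3) uint8 array -> (3, H, W) float32 tensor in [0, 1]."""
    chw = np.ascontiguousarray(arr.transpose(2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(chw).float().div(255.0)     # uint8 -> float32, scaled to [0, 1]

# A 2x4 all-white RGB image
rgb = np.full((2, 4, 3), 255, dtype=np.uint8)
t = to_tensor_like(rgb)
print(t.shape, t.dtype, t.max().item())  # torch.Size([3, 2, 4]) torch.float32 1.0
```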
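- Preprocessing: if the image tensor is not three-dimensional (for example, a single-channel (H, W) grayscale tensor), add a channel axis and expand it to three channels. Because `.convert('RGB')` already yields three channels, this branch acts as a safeguard;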
    if len(img.shape) != 3:
        img = img.unsqueeze(0)                  # (H, W) -> (1, H, W)
        img = img.expand((3, *img.shape[1:]))   # (1, H, W) -> (3, H, W); spatial sizes must be unpacked
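The channel expansion can be checked on a small grayscale tensor. `Tensor.expand` takes the target sizes as separate integers or as one flat sequence of integers. For that reason the spatial sizes from `img.shape[1:]` have to be unpacked. A tuple such as `(3, img.shape[1:])` nests a `torch.Size` inside it and is rejected. The helper name `ensure_three_channels` is hypothetical.

```python
import torch

def ensure_three_channels(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn a 2-D (H, W) tensor into a 3-channel (3, H, W) view."""
    if len(img.shape) != 3:
        img = img.unsqueeze(0)                  # (H, W) -> (1, H, W)
        img = img.expand((3, *img.shape[1:]))   # (1, H, W) -> (3, H, W), no data copy
    return img

gray = torch.rand(4, 5)
out = ensure_three_channels(gray)
print(out.shape)                    # torch.Size([3, 4, 5])
print(torch.equal(out[0], out[2]))  # True: every channel shows the same data
```

Because `expand` returns a view, the three channels share memory. Use `.contiguous()` or `repeat` instead if an independent copy is needed.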
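- If the image is not square, pad it: the rectangle is turned into a square by filling the missing area along the shorter side (with the value 0 here);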
    _, h, w = img.shape
    # Original (unpadded) size, used as scale factors when the labels are normalized
    h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
    # Pad to square resolution
    img, pad = pad_to_square(img, 0)
    _, padded_h, padded_w = img.shape
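`pad_to_square` is defined elsewhere in the project, and its body is not shown here. The sketch below is one plausible implementation, an assumption rather than the repository's code. It splits the size difference between the two sides of the shorter dimension and fills that area with `pad_value` using `torch.nn.functional.pad`. It returns the padding tuple in `F.pad` order, which is (left, right, top, bottom) for the last two dimensions.

```python
import torch
import torch.nn.functional as F

def pad_to_square(img: torch.Tensor, pad_value: float = 0.0):
    """Sketch (assumed, not the original helper): pad a (C, H, W) tensor so that H == W.

    Returns the padded tensor and the padding tuple in F.pad order
    (left, right, top, bottom).
    """
    _, h, w = img.shape
    diff = abs(h - w)
    pad1, pad2 = diff // 2, diff - diff // 2   # the odd pixel, if any, goes to bottom/right
    # Wide image: pad top/bottom; tall image: pad left/right
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    img = F.pad(img, pad, mode="constant", value=pad_value)
    return img, pad

img = torch.ones(3, 4, 7)              # H=4, W=7
padded, pad = pad_to_square(img, 0)
print(padded.shape, pad)               # torch.Size([3, 7, 7]) (0, 0, 1, 2)
```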
- Read the labels;