Talk is cheap, show me the code! 哈哈，先上几段常用的代码，以语义分割任务中 DRIVE 数据集的加载为例：
DRIVE 数据集的目录结构如下（下载链接：DRIVE；如果官网无法下载，也可以从 Kaggle 上下载）：
1. 定义 DriveDataset 类。每行代码都加了注释，其中 collate_fn() 暂时看不懂也没关系：
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
class DriveDataset(Dataset): # 继承Dataset类
def __init__(self, root: str, train: bool, transforms=None):
super(DriveDataset, self).__init__()
self.flag = "training" if train else "test" # 根据train这个布尔类型确定需要处理的是训练集还是测试集
data_root = os.path.join(root, "DRIVE", self.flag) # 得到数据集根目录
assert os.path.exists(data_root), f"path '{data_root}' does not exists." # 判断路径是否存在
self.transforms = transforms # 初始化图像变换操作
img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")] # 遍历图像文件夹获取每个图像的文件名
self.img_list = [os.path.join(data_root, "images", i) for i in img_names] # 获取图像路径
self.manual = [os.path.join(data_root,