1
# NOTE(review): only the class header and constructor signature are shown in
# these notes; the class body is elided. Judging by the parameter names, this
# presumably builds a dataset from parallel collections of image paths and
# mask paths (e.g. the original BraTS layout) — TODO confirm against the full
# implementation. `Dataset` is assumed to be torch.utils.data.Dataset — verify.
class BraTS_ori(Dataset):
def __init__(self, img_paths, mask_paths):
# NOTE(review): only the class header and constructor signature are shown in
# these notes; the class body is elided. Judging by the parameter names, this
# presumably reads samples listed in `image_file`, resolved relative to the
# directory `root` — TODO confirm against the full implementation.
# Callers going through get_segmentation_dataset must pass exactly these
# keyword names: root=..., image_file=...
class BraTS_new(Dataset):
def __init__(self, root,image_file):
2
# Registry mapping a lower-case dataset key to its Dataset class.
# get_segmentation_dataset lower-cases the requested name before the lookup,
# so every key here must itself be lower-case.
datasets = {
    'new': BraTS_new,
    'ori': BraTS_ori,
}


def get_segmentation_dataset(name, **kwargs):
    """Instantiate a segmentation dataset by its registry name.

    Args:
        name: Key into ``datasets``; matched case-insensitively
            (``'NEW'`` and ``'new'`` select the same class).
        **kwargs: Forwarded unchanged to the selected class's constructor.
            Only keyword arguments are accepted, so each keyword must match
            a parameter name of that constructor exactly (for example
            ``root`` and ``image_file`` for ``BraTS_new``).

    Returns:
        An instance of the class registered under ``name.lower()``.

    Raises:
        KeyError: If ``name`` is not a registered dataset key. The message
            lists the valid keys.
    """
    key = name.lower()
    # Keep the try around the lookup only, so a KeyError raised inside a
    # dataset constructor is not misreported as an unknown dataset name.
    try:
        dataset_cls = datasets[key]
    except KeyError:
        known = ", ".join(sorted(datasets))
        raise KeyError(
            f"Unknown dataset {name!r}; expected one of: {known}"
        ) from None
    return dataset_cls(**kwargs)
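3 调用 get_segmentation_dataset 时，除数据集名称 name 外的其他参数都必须以关键字参数的形式传入（即“参数名=要传入的值”）。原因是该函数只通过 **kwargs 接收这些参数并原样转发给数据集类，按位置传入会直接报 TypeError。此外，关键字必须与对应数据集类 __init__ 的参数名完全一致（如 BraTS_new 的 root、image_file），拼写错误同样会报 TypeError。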
# Build the training and validation splits from the same root directory.
# Extra arguments are keywords whose names must match BraTS_new.__init__
# exactly (root, image_file); a misspelled keyword raises TypeError.
train_dataset = get_segmentation_dataset('new', root=root, image_file=train_file)
val_dataset = get_segmentation_dataset('new', root=root, image_file=valid_file)