pytorch Dataset 的ImageFolder
ImageFolder例子
def load_data(root_dir,domain,batch_size):
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Resize([28, 28]),
transforms.ToTensor(),
transforms.Normalize(mean=(0,0,0),std=(1,1,1)),
]
)
image_folder = datasets.ImageFolder(
root=root_dir + domain,
transform=transform
)
data_loader = torch.utils.data.DataLoader(dataset=image_folder,batch_size=batch_size,shuffle=True,num_workers=2,drop_last=True
)
return data_loader
data_src = data_loader.load_data(
root_dir=rootdir, domain='amazon', batch_size=BATCH_SIZE[0])
for e in tqdm(range(1, N_EPOCH + 1)):
model = train(model=model, optimizer=optimizer,
epoch=e, data_src=data_src, data_tar=data_tar)
【PyTorch学习笔记】14:划分训练-验证-测试集,使用正则化项
torch.utils.data.random_split源码
本文详细介绍如何使用PyTorch的ImageFolder加载图像数据集,包括数据预处理、批处理和加载过程。通过实例演示了如何定义数据转换,如灰度转换、调整大小、张量化及归一化,并展示如何创建DataLoader进行数据批量加载。
2092

被折叠的 条评论
为什么被折叠?



