def get_data_set(args, train_flag=True):
if train_flag:
data_set = torchvision.datasets.__dict__[args.data_set]
(root=args.data_path, train=True,
transform=get_transformer(args.data_set, args.imsize,
args.cropsize, args.crop_padding, args.hflip), download=True)
else:
data_set = torchvision.datasets.__dict__[args.data_set]
(root=args.data_path, train=False,
transform=get_transformer(args.data_set), download=True)
return data_set
data_set = get_data_set(args, train_flag=True)
此篇博客介绍了如何使用PyTorch的torchvision模块加载数据集,区分了训练和验证数据集,并应用了自定义的transform参数。讲解了根据不同数据集选择对应的dataset类并设置训练状态和预处理步骤。
1126

被折叠的 条评论
为什么被折叠?



