pytorch自定义Dataset并使用torchvision的Transform

最近用了pytorch, 使用上比Tensorflow爽的多,尤其是在读取数据的部分,冗长而繁杂的api令人望而却步,而且由于Tensorflow不支持与numpy的无缝切换,导致难以使用现成的pandas等格式化数据读取工具,造成了很多不必要的麻烦

pytorch自定义读取数据和进行Transform的部分请见文档:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

但是按照文档中所描述所完成的自定义Dataset只能够使用自定义的Transform步骤,而torchvision包中已经给我们提供了很多图像transform步骤的实现,为了使用这些已经实现的Transform步骤,我们可以使用如下方法定义Dataset:

class FaceLandmarkDataset(Dataset):
    def __len__(self) -> int:
        return len(self.
### 如何在 PyTorch使用 Dataset 为了理解如何在 PyTorch使用 `Dataset` 类,可以考虑一个具体的例子来加载图像数据集对其进行转换处理。 通过 `torchvision.datasets.ImageFolder` 可以方便地将文件夹中的图片读取到一个临时的数据集中。这允许快速创建自定义数据集而无需编写额外的类[^1]: ```python import torch from torch.utils.data import Dataset, DataLoader from torchvision.datasets import ImageFolder from torchvision import transforms train_augs = transforms.Compose([ transforms.Resize(224), transforms.ToTensor() ]) data_images = ImageFolder(root='../classify-leaves', transform=train_augs) ``` 上述代码展示了如何设置基本变换(调整大小和转为张量),将这些变换应用于从指定路径加载的所有图像。这里的关键在于传递给 `ImageFolder` 的两个参数:根目录 (`root`) 和变换函数 (`transform`)。 对于更复杂的场景,可能需要继承 `Dataset` 实现自己的子类。通常情况下只需要重写三个方法即可完成定制化需求: - 初始化方法 `__init__()`: 设置初始化属性,比如加载数据列表、标签等; - 获取长度的方法 `__len__()`: 返回整个数据集样本数量; - 下标索引访问方法 `__getitem__(idx)`: 支持按索引获取单个条目; 下面是一个简单的模板用于构建自定义数据集: ```python class CustomDataset(Dataset): def __init__(self, data_list, labels=None, transform=None): self.data = data_list self.labels = labels self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] if self.transform: sample = self.transform(sample) if self.labels is not None: label = self.labels[idx] return sample, label return sample ``` 此结构使得能够灵活地适应不同类型的输入源,轻松集成各种预处理操作。当与 `DataLoader` 结合使用时,还可以简化批量采样过程以及多线程加速等功能[^3]。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值