关注B站查看更多手把手教学:
基本用法
torch.utils.data.Dataset
是 PyTorch 中一个非常重要的抽象类,它用于表示数据集,方便数据加载和预处理。通过实现这个类的两个方法 __len__
和 __getitem__
,你可以自定义自己的数据集类。__len__
方法应返回数据集的大小(即样本数),而 __getitem__
方法则根据给定的索引返回一个样本。
以下是一个简单的示例,说明如何使用 torch.utils.data.Dataset
创建一个自定义的数据集类:
import torch
from torch.utils.data import Dataset
class MyCustomDataset(Dataset):
def __init__(self, data, targets):
"""
参数:
data: 样本数据, 形状为 [num_samples, ...] (例如 [num_samples, num_channels, height, width])
targets: 样本标签, 形状为 [num_samples, ...] (例如 [num_samples])
"""
self.data = data
self.targ