TensorDataset
TensorDataset
是PyTorch中torch.utils.data
模块的一部分,它包装张量到一个数据集中,并允许对这些张量进行索引,以便能够以批量的方式加载它们。
当你有多个数据源(如特征和标签)时,TensorDataset
能够让你把它们打包成一个数据集,这在训练模型时非常有用。
介绍
TensorDataset
接收任意数量
的张量作为输入,前提
是这些张量的第一维度大小(也就是数据点的数量
)相同。
每个张量的第一维被视为数据的长度。当对TensorDataset
进行索引时,它会返回一个元组,其中包含每个张量在对应索引处的数据。
用法示例
下面是一个使用TensorDataset
的简单示例,包括如何创建它,以及如何与DataLoader
结合使用,以便于批量加载数据
。
首先,你需要有一些数据。在这个例子中,我们将创建一些随机数据来模拟特征(X
)和标签(y
)。
import torch