torch.utils.data.TensorDataset(data_tensor, target_tensor)
包装数据和目标张量的数据集。
通过沿着第一个维度索引两个张量来恢复每个样本。
参数:
data_tensor (Tensor) - 包含样本数据
target_tensor (Tensor) - 包含样本目标(标签)
class TensorDataset(Dataset):
r"""Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
Arguments:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
'''数据集包装张量。
每个样本将通过沿第一维索引张量来检索。
参数:
*张量(张量):具有与第一维相同大小的张量。'''
def __init__(self, *tensors):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
本文详细介绍了PyTorch中TensorDataset的使用方法,这是一种用于包装数据和目标张量的数据集,通过沿第一维度索引张量来恢复每个样本。文章深入解析了构造函数参数,包括data_tensor和target_tensor,并提供了类的实现代码。
1278

被折叠的 条评论
为什么被折叠?



