- 包装数据和目标张量的数据集。
torch.utils.data.TensorDataset(data_tensor, target_tensor)
x_1 = torch.arange(30).reshape(-1,3)
y_1 = torch.arange(10)*3
# TensorDataset对tensor进行打包
dataset = data.TensorDataset(x_1, y_1)
for x_1_train, y_1_train in dataset:
print(x_1_train, y_1_train)
# dataloader进行数据封装
print('=' * 80)
train_loader = data.DataLoader(dataset=dataset, batch_size=4, shuffle=True)
for i, data_ in enumerate(train_loader):
# enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
x_1_train, y_1_train = data_
print(f' batch:{i+1} x_data:{x_1_train} y_data:{y_1_train}')
运行解果:
tensor([0, 1, 2]) tensor(0)
tensor([3, 4, 5]) tensor(3)
tensor([6, 7, 8]) tensor(6)
tensor([ 9, 10, 11]) tensor(9)
tensor([12, 13, 14]) tensor(12)
tensor([15, 16, 17]) tensor(15)
tensor([18, 19, 20]) tensor(18)
tensor([21, 22