DataLoader
DataLoader converts your own data into Tensors and then iterates over them efficiently. It greatly simplifies the data-reading process and makes model training much more convenient.
1. A simple example
- Import the required packages:
import torch
import torch.utils.data as Data
torch.manual_seed(1)
- Generate some torch data:
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
- Wrap the data in a TensorDataset and build a DataLoader (note that BATCH_SIZE must be defined before it is used here):
BATCH_SIZE = 5
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,   # 5 samples per batch
    shuffle=True,            # reshuffle the data at every epoch
    num_workers=2            # load batches in 2 subprocesses
)
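A caveat before iterating: with num_workers > 0, the DataLoader loads batches in worker subprocesses, so on platforms that spawn fresh processes (e.g. Windows), the iteration code needs to live under a main guard, roughly like this sketch:

if __name__ == '__main__':
    for batchX, batchY in loader:
        pass  # training step goes here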
- Iterate over the data with the DataLoader:
for epoch in range(3):   # train over the whole dataset 3 times
    for step, (batchX, batchY) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batchX.numpy(), '| batch y: ', batchY.numpy())
Output:
Epoch: 0 | Step: 0 | batch x: [ 4. 6. 7. 10. 8.] | batch y: [7. 5. 4. 1. 3.]
Epoch: 0 | Step: 1 | batch x: [5. 3. 2. 1. 9.] | batch y: [ 6. 8. 9. 10. 2.]
Epoch: 1 | Step: 0 | batch x: [ 4. 2. 5. 6. 10.] | batch y: [7. 9. 6. 5. 1.]
Epoch: 1 | Step: 1 | batch x: [3. 9. 1. 8. 7.] | batch y: [ 8. 2. 10. 3. 4.]
Epoch: 2 | Step: 0 | batch x: [ 4. 10. 9. 8. 7.] | batch y: [7. 1. 2. 3. 4.]
Epoch: 2 | Step: 1 | batch x: [6. 1. 2. 5. 3.] | batch y: [ 5. 10. 9. 6. 8.]
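TensorDataset is just one ready-made Dataset. When your data does not fit neatly into a pair of tensors, you can subclass Dataset yourself by implementing __len__ and __getitem__. Here is a minimal sketch that mimics TensorDataset (the class name MyDataset is made up for illustration):

import torch
import torch.utils.data as Data

class MyDataset(Data.Dataset):
    # a hypothetical Dataset wrapping two tensors of equal length
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        # number of samples; DataLoader uses this for shuffling and batching
        return self.x.size(0)

    def __getitem__(self, index):
        # return one (input, target) pair; DataLoader collates pairs into batches
        return self.x[index], self.y[index]

loader = Data.DataLoader(MyDataset(x, y), batch_size=5, shuffle=True)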
2. When batch_size does not evenly divide the dataset length
In the toy example above, batch_size=5 and the dataset length is 10, so exactly two steps exhaust the data. What if batch_size=8? Then after the first step only 2 samples remain, so the second batch has length 2 instead of 8. Passing drop_last=True, as below, tells the DataLoader to simply discard that final incomplete batch (with the default drop_last=False, it is returned as a smaller batch).
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    drop_last=True   # discard the last incomplete batch
)
for epoch in range(3):
    for step, (batchX, batchY) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batchX.numpy(), '| batch y: ', batchY.numpy())
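With 10 samples and batch_size=8, drop_last=True leaves exactly one step per epoch (the full batch of 8), and the leftover 2 samples are thrown away. A quick way to check how many steps an epoch will have is len(loader); a small sketch comparing both settings:

# with 10 samples and batch_size=8:
loader_keep = Data.DataLoader(torch_dataset, batch_size=8, drop_last=False)
loader_drop = Data.DataLoader(torch_dataset, batch_size=8, drop_last=True)
print(len(loader_keep))  # 2 -> batches of size 8 and 2
print(len(loader_drop))  # 1 -> only the full batch of 8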