问题描述:
在进行pytorch神经网络批训练的时候,有时会出现报错 TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>
解决办法:
第一步
检查(重点!!!!!):
train_dataset = Data.TensorDataset(train_x, train_y)
train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable
可以这样将数据变为tensor类:
train_x = torch.FloatTensor(train_x)
第二步:
train_loader = Data.DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True
)
实例化一个DataLoader对象
第三步:
for epoch in range(epochs):
for step, (batch_x, batch_y) in enumerate(train_loader):
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
这样就可以批训练了
需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable