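Converting a feature ndarray to a tensor with torch.Tensor(train_examples[0][0])
raises:
TypeError: new(): data must be a sequence (got numpy.float64)
The message indicates that the value passed in is a single numpy.float64 scalar rather than a sequence, which torch.Tensor() does not accept. Reshape it into a 2-D array first:
torch.Tensor(train_examples[0][0].reshape(1, n_feature))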
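A minimal sketch of the fix, for illustration only. The `features` array and the value of `n_feature` are made-up stand-ins, not the original `train_examples` data, whose layout is not shown in the note.

```python
import numpy as np
import torch

# Hypothetical data (an assumption): 3 samples with 4 features each.
n_feature = 4
features = np.arange(12, dtype=np.float64).reshape(3, n_feature)

row = features[0]        # 1-D ndarray of shape (n_feature,)
scalar = features[0][0]  # numpy.float64 scalar, not a sequence

# Passing a bare scalar to the torch.Tensor constructor is what triggers
# "data must be a sequence"; catch it here so the script keeps running.
try:
    torch.Tensor(scalar)
except TypeError as err:
    print("scalar input rejected:", err)

# Reshaping to (1, n_feature) gives torch.Tensor a 2-D array to work with.
x = torch.Tensor(row.reshape(1, n_feature))
print(x.shape, x.dtype)

# If a single value really is what you want, torch.tensor (lowercase)
# accepts scalars and returns a 0-dimensional tensor.
s = torch.tensor(scalar)
print(s.dim())
```

Note that torch.Tensor() always produces float32, so a float64 ndarray is downcast. torch.from_numpy() is an alternative that keeps the original dtype.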