如题,在用m1max的mps加速,对CIFAR10进行图片训练的时候出现上述情况。
前提方式:
已经引入了
device = torch.device("mps")
解决方式:添加—input.to(device)
在输入端加入如下代码
imgs = imgs.to(device)
targets = targets.to(device)
原来的代码:
for data in train_dataloader:
imgs, targets = data
output = model(imgs)
loss = loss_fn(output, targets)
加入后:
for data in train_dataloader:
imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)
output = model(imgs)
loss = loss_fn(output, targets)
亲测解决问题