解决方法如下: _, y_hat = torch.max(y_hat, 1) # or y_hat = torch.argmax(y_hat.to('cpu'), axis=1).to('mps') 实测,第一个快一秒