默认是CPU,如果想要用GPU需要:
- 安装配置cuda,然后更新/下载支持gpu版本的pytorch,可以参考:https://blog.youkuaiyun.com/weixin_35757704/article/details/124315569
- 设置device:
然后将数据与模型后面都额外加上device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
.to(device)
即可
示例程序
import torch
import torch.nn as nn
# 一个简单的模型
class LinearRegressionModel(nn.Module):
def __init__(self, input_shape, output_shape):
super(LinearRegressionModel, self).__init__()
self.linear = nn.Linear(input_shape, output_shape)
def forward(self, x):
out