第一种情况:比较简单的处理方式是把输入数据和模型都移动到同一个设备上:
# Tensor.to() is not in-place: it returns the moved tensor, so x must be rebound.
x = x.to(device)
# nn.Module.to() moves the module's parameters in place; no rebinding is needed.
model.to(device)
第二种情况:模型本身没有定义好——有一部分参数(先验)在定义时被遗漏,没有放到目标设备上。
报错信息
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 951, in forward
result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu
错误的代码
# Faulty version, kept on purpose to illustrate the error above.
# The weight prior is built from plain Python scalars, so its tensors are
# presumably created on the default (CPU) device, while the bias prior below
# is explicitly placed on self.device.
self.lstm.weight_ih_l0 = PyroSample(
    dist.Normal(0, prior_scale)
    .expand([4 * hidden_size, input_size]).to_event(2))
self.lstm.bias_ih_l0 = PyroSample(
    dist.Normal(
        torch.tensor([0.], device=self.device),
        torch.tensor([prior_scale], device=self.device)
    ).expand([4 * hidden_size]).to_event(1))
修改后的代码
import torch
import pyro
from pyro.distributions import Normal, Gamma
from pyro.nn import PyroModule, PyroSample
import torch.nn as nn
class Model(PyroModule):
    """LSTM followed by two linear layers, with Normal priors on their weights.

    Every prior is ``Normal(0, prior_scale)`` built from tensors that live on
    ``device``, so sampled weights end up on the same device as the inputs.
    Only the layer-0 LSTM weights/biases (``*_l0``) receive priors here; when
    ``num_layers > 1`` the remaining layers keep their ordinary parameters.

    Args:
        input_size: Number of features per time step fed to the LSTM.
        num_classes: Output width of the final linear layer.
        hidden_size: LSTM hidden state width.
        num_layers: Number of stacked LSTM layers.
        prior_scale: Standard deviation of every Normal prior.
        device: Device on which the layers and the prior tensors are created.
    """

    def __init__(self, input_size=1, num_classes=1, hidden_size=3, num_layers=1,
                 prior_scale=50.0, device='cuda:0'):
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = nn.ReLU()

        # Width shared by `linear` (output) and `fc` (input). The layers used
        # to be declared with 512 while their priors were shaped for 128; the
        # sampled weights are the ones that replace the layer parameters, so
        # 128 is kept as the single source of truth.
        # NOTE(review): confirm 128 (not 512) is the intended width.
        fc_width = 128

        # Unidirectional, batch-first LSTM.
        self.lstm = PyroModule[nn.LSTM](input_size,
                                        hidden_size,
                                        num_layers,
                                        batch_first=True,
                                        bidirectional=False).to(device)
        self.linear = PyroModule[nn.Linear](hidden_size, fc_width).to(device)
        self.fc = PyroModule[nn.Linear](fc_width, num_classes).to(device)

        def normal_prior(*shape):
            """Return a PyroSample of Normal(0, prior_scale) with event shape `shape`."""
            # loc/scale are created on self.device so samples never land on CPU
            # when the rest of the model is on the GPU.
            loc = torch.tensor([0.], device=self.device)
            scale = torch.tensor([prior_scale], device=self.device)
            return PyroSample(
                Normal(loc, scale).expand(list(shape)).to_event(len(shape)))

        # nn.LSTM stacks the four gates along dim 0 of each weight/bias.
        gate_rows = 4 * hidden_size

        # Input-to-hidden, layer 0.
        self.lstm.weight_ih_l0 = normal_prior(gate_rows, input_size)
        self.lstm.bias_ih_l0 = normal_prior(gate_rows)
        # Hidden-to-hidden, layer 0.
        self.lstm.weight_hh_l0 = normal_prior(gate_rows, hidden_size)
        self.lstm.bias_hh_l0 = normal_prior(gate_rows)

        # Dense layers.
        self.linear.weight = normal_prior(fc_width, hidden_size)
        self.linear.bias = normal_prior(fc_width)
        self.fc.weight = normal_prior(num_classes, fc_width)
        self.fc.bias = normal_prior(num_classes)