import torch
import torch.nn as nn
import dgl
class GraphLSTAM(nn.Module):
def __init__(self, num_nodes, input_dim, hidden_dim, output_dim, num_layers):
    """Create the sub-modules of the model.

    Args:
        num_nodes: Accepted for interface compatibility; this constructor
            does not read it.
        input_dim: Input feature size given to the graph convolution.
        hidden_dim: Output size of the graph convolution, which is also
            the LSTM input and hidden size and the input size of both
            linear layers.
        output_dim: Output size of the final linear layer.
        num_layers: Number of stacked LSTM layers.
    """
    super().__init__()

    # NOTE: sub-modules are created in a fixed order (conv, LSTM,
    # attention, fc) so that parameter initialisation under a fixed
    # random seed stays reproducible.
    self.graph_conv = dgl.nn.GraphConv(
        in_feats=input_dim,
        out_feats=hidden_dim,
    )
    self.lstm = nn.LSTM(
        input_size=hidden_dim,
        hidden_size=hidden_dim,
        num_layers=num_layers,
        batch_first=True,
    )
    # Maps each hidden vector to a single scalar.
    self.attention = nn.Linear(in_features=hidden_dim, out_features=1)
    self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)
def forward(self, g, features):
# g: DGLGraph object
# features: tensor of shape (batch_size, num_nodes, input_dim)
# Graph convolution layer
h = self.graph_conv(g, features)
h = t