class MLP_BiGRU(nn.Module):
'''基本参数定义'''
def __init__(self, num_layers, rnn_hidden_size, encoder_input_size, encoder_hidden_size, depth=300):
#super 调用父类的方法,解决多继承问题
super(MLP_BiGRU, self).__init__()
self.depth = depth #深度
self.num_layers = num_layers #层数
self.rnn_hidden_size = rnn_hidden_size #隐藏
'''编码,对lon,lat,sst,date进行编码'''
#num_embeddings可以理解为待编码数据的(变化大小),例如经度有360个变量
#embedding_dim编码向量的大小,这里统一设置为16
self.embedding_lat = nn.Embedding(90, 16)
self.embedding_lon = nn.Embedding(360, 16)
self.embedding_sst = nn.Embedding(64, 16) #SST变化范围为64
self.embedding_date = nn.Embedding(7, 16) #1993.01-1993
#Sequential顺序类容器
self.mlp = nn.Sequential(
nn.Linear(encoder_input_size, encoder_hidden_size),
#双曲正切(tanh)激活函数,将输入数据的每个元素映射到(-1,1)区间
nn.Tanh(),
nn.Linear(encoder_hidden_size, encoder_hidden_size),
nn.Tanh(),
nn.Linear(encoder_hidden_size, encoder_hidden_size),
nn.Tanh(),
nn.Linear(encoder_hidden_size, depth),
)
#input_size,输入数据的维度, hidden_size, 处于隐藏状态的特征数量
#num_layers, 循环层数; bidirectional, True, 双向GRU
self.bi_gru = nn.GRU(input_size=1, hidden_size=rnn_hidden_size,
num_layers=num_layers, batch_first=True, bidirectional=True)
#创建线性层模块,实现仿射变换
self.fc = nn.Linear(rnn_hidden_size * 2 * depth, depth)
self.flatten = nn.Flatten(1, -1)
def forward(self, x): #输入的x的大小为[6,300]
'''cat 进行不同张量的连接,维度为1; trunc返回新张量,截取整数;
frac只取小数; unsquee拓展维度,一维表变二维等;long长整数'''
'''纬度 sst没问题;此代码中使用unsqueeze是为了解决维度匹配的问题'''
#x0 纬度
x0 = torch.cat((self.embedding_lat(torch.trunc(x[:, 0]).long()), torch.frac(x[:, 0]).unsqueeze(1)), 1) #torch.unsqueeze(num),在第num+1增加新维度,size为1,如(3,4)转换为(3,1,4)
#x1 经度
x1 = torch.cat((self.embedding_lon(torch.trunc(x[:, 1]).long()), torch.frac(x[:, 1]).unsqueeze(1)), 1)
#x2 日期
x2 = self.embedding_date(x[:, 2].long())
#x3温度
x3 = torch.cat((self.embedding_sst(torch.trunc(x[:, -2]).long()), torch.frac(x[:, -2]).unsqueeze(1)), 1)
#x的大小为batch_size*input_size
x = torch.cat((x0, x1, x2, x3, x[:, -1].unsqueeze(1)), 1)
x = self.mlp(x.float())
output, h = self.bi_gru(x.unsqueeze(-1)) #扩展维度,-1表示在最后增加新维度。如(3,4)增加至(3,4,1)
output = self.flatten(output)
output = self.fc(output)
return output
MLP_BiGRU模型构建—自备份用
于 2024-10-11 15:23:33 首次发布