MLP_BiGRU模型构建—自备份用

class MLP_BiGRU(nn.Module):
    '''基本参数定义'''
    def __init__(self, num_layers, rnn_hidden_size, encoder_input_size, encoder_hidden_size, depth=300):
    	#super 调用父类的方法,解决多继承问题
        super(MLP_BiGRU, self).__init__()
        self.depth = depth #深度
        self.num_layers = num_layers #层数
        self.rnn_hidden_size = rnn_hidden_size #隐藏
        
        '''编码,对lon,lat,sst,date进行编码'''
        #num_embeddings可以理解为待编码数据的(变化大小),例如经度有360个变量
        #embedding_dim编码向量的大小,这里统一设置为16
        self.embedding_lat = nn.Embedding(90, 16) 
        self.embedding_lon = nn.Embedding(360, 16) 
        self.embedding_sst = nn.Embedding(64, 16) #SST变化范围为64
        self.embedding_date = nn.Embedding(7, 16) #1993.01-1993
		#Sequential顺序类容器
        self.mlp = nn.Sequential(
            nn.Linear(encoder_input_size, encoder_hidden_size),
            #双曲正切(tanh)激活函数,将输入数据的每个元素映射到(-1,1)区间
            nn.Tanh(),
            nn.Linear(encoder_hidden_size, encoder_hidden_size),
            nn.Tanh(),
            nn.Linear(encoder_hidden_size, encoder_hidden_size),
            nn.Tanh(),
            nn.Linear(encoder_hidden_size, depth),
        )
		#input_size,输入数据的维度, hidden_size, 处于隐藏状态的特征数量
		#num_layers, 循环层数; bidirectional, True, 双向GRU
        self.bi_gru = nn.GRU(input_size=1, hidden_size=rnn_hidden_size,
                                  num_layers=num_layers, batch_first=True, bidirectional=True)
		#创建线性层模块,实现仿射变换
        self.fc = nn.Linear(rnn_hidden_size * 2 * depth, depth)

        self.flatten = nn.Flatten(1, -1)

    def forward(self, x): #输入的x的大小为[6,300]
        '''cat 进行不同张量的连接,维度为1; trunc返回新张量,截取整数; 
        frac只取小数; unsquee拓展维度,一维表变二维等;long长整数'''
        '''纬度 sst没问题;此代码中使用unsqueeze是为了解决维度匹配的问题'''
        #x0 纬度
        x0 = torch.cat((self.embedding_lat(torch.trunc(x[:, 0]).long()), torch.frac(x[:, 0]).unsqueeze(1)), 1) #torch.unsqueeze(num),在第num+1增加新维度,size为1,如(3,4)转换为(3,1,4)
        #x1 经度
        x1 = torch.cat((self.embedding_lon(torch.trunc(x[:, 1]).long()), torch.frac(x[:, 1]).unsqueeze(1)), 1)
        #x2 日期
        x2 = self.embedding_date(x[:, 2].long())
        #x3温度
        x3 = torch.cat((self.embedding_sst(torch.trunc(x[:, -2]).long()), torch.frac(x[:, -2]).unsqueeze(1)), 1)
        #x的大小为batch_size*input_size
        x = torch.cat((x0, x1, x2, x3, x[:, -1].unsqueeze(1)), 1)
       
        x = self.mlp(x.float())
        output, h = self.bi_gru(x.unsqueeze(-1)) #扩展维度,-1表示在最后增加新维度。如(3,4)增加至(3,4,1)
        output = self.flatten(output)
        output = self.fc(output)

        return output
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值