import math

import torch
from torch import nn

class GLU(nn.Module):
    """Gated Linear Unit: sigmoid(fc1(x)) acts as a gate on fc2(x)."""
    def __init__(self,input_size):
        super(GLU, self).__init__()
        self.fc1=nn.Linear(input_size,input_size)
        self.fc2=nn.Linear(input_size,input_size)
        self.sigmoid=nn.Sigmoid()

    def forward(self,x):
        sig=self.sigmoid(self.fc1(x))
        x=self.fc2(x)
        # Element-wise gating; the output keeps the input's shape.
        return torch.mul(sig,x)
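
# Usage sketch (shapes are illustrative assumptions, not from the original):
# GLU preserves the feature dimension, e.g.
#   gate = GLU(input_size=16)
#   y = gate(torch.randn(4, 16))   # y.shape == (4, 16)
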
class TimeDistributed(nn.Module):
    """Applies a module to every timestep by folding the time dimension into the batch."""
    def __init__(self, module, batch_first=False):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if len(x.size()) <= 2:
            # No time dimension present; apply the module directly.
            return self.module(x)
        # Merge time and batch dimensions: (T, B, F) or (B, T, F) -> (T*B, F).
        x_reshape = x.contiguous().view(-1, x.size(-1))
        y = self.module(x_reshape)
        if self.batch_first:
            # Restore (B, T, F_out).
            y = y.contiguous().view(x.size(0), -1, y.size(-1))
        else:
            # Restore (T, B, F_out).
            y = y.view(-1, x.size(1), y.size(-1))
        return y
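
# Usage sketch (sizes are illustrative assumptions): wrap a per-step layer so it
# runs over a whole sequence, e.g. with the default time-first layout
#   td = TimeDistributed(nn.Linear(8, 3))
#   y = td(torch.randn(20, 4, 8))   # y.shape == (20, 4, 3)
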
class GRN(nn.Module):
    """Gated Residual Network: dense -> ELU -> dense -> dropout -> GLU gate,
    wrapped in a residual connection and layer normalisation, with an optional
    context added before the nonlinearity."""
    def __init__(self,input_size,hidden_state_size,output_size,drop_out,hidden_context_size=None,batch_first=False):
        super(GRN, self).__init__()
        self.input_size=input_size
        self.output_size=output_size
        self.hidden_context_size=hidden_context_size
        self.hidden_state_size=hidden_state_size
        self.drop_out=drop_out
        if self.input_size!=self.output_size:
            # Project the residual so it can be added to the output.
            self.skip_layer=TimeDistributed(nn.Linear(self.input_size,self.output_size))
        self.fc1=TimeDistributed(nn.Linear(self.input_size,self.hidden_state_size),batch_first=batch_first)
        self.elu1=nn.ELU()
        if self.hidden_context_size is not None:
            self.context=TimeDistributed(nn.Linear(self.hidden_context_size,self.hidden_state_size),batch_first=batch_first)
        self.fc2=TimeDistributed(nn.Linear(self.hidden_state_size,self.output_size),batch_first=batch_first)
        self.dropout=nn.Dropout(self.drop_out)
        self.ln=TimeDistributed(nn.LayerNorm(self.output_size),batch_first=batch_first)
        self.gate=TimeDistributed(GLU(self.output_size),batch_first=batch_first)

    def forward(self,x,context=None):
        if self.input_size!=self.output_size:
            residual=self.skip_layer(x)
        else:
            residual=x
        x=self.fc1(x)
        if context is not None:
            # Project the context and add it before the nonlinearity.
            context=self.context(context)
            x=x+context
        x=self.elu1(x)
        x=self.fc2(x)
        x=self.dropout(x)
        x=self.gate(x)
        # Residual connection followed by layer normalisation.
        x=x+residual
        x=self.ln(x)
        return x
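
# Usage sketch (sizes are illustrative assumptions): a GRN maps (T, B, input_size)
# to (T, B, output_size), optionally conditioned on a context tensor, e.g.
#   grn = GRN(input_size=8,hidden_state_size=16,output_size=8,drop_out=0.1)
#   y = grn(torch.randn(20, 4, 8))   # y.shape == (20, 4, 8)
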
class PositionalEncoder(nn.Module):
    """Standard sinusoidal positional encoding, applied to time-first inputs
    of shape (seq_len, batch, d_model)."""
    def __init__(self,d_model,max_seq_len=160):
        super(PositionalEncoder, self).__init__()
        self.d_model=d_model
        pe=torch.zeros(max_seq_len,d_model)
        for pos in range(max_seq_len):
            for i in range(0,d_model,2):
                # i already indexes the even dimensions, so the sin/cos pair
                # shares the frequency 1 / 10000 ** (i / d_model).
                pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
        pe = pe.unsqueeze(0)
        # Registered as a buffer: moves with the module but is not trained.
        self.register_buffer('pe', pe)

    def forward(self,x):
        with torch.no_grad():
            # Scale embeddings before adding the encoding, as in the Transformer paper.
            x=x*math.sqrt(self.d_model)
            seq_len=x.size(0)
            # Broadcast over the batch dimension: (seq_len, 1, d_model).
            pe=self.pe[:,:seq_len].view(seq_len,1,self.d_model)
            x=x+pe
        return x
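
# Usage sketch (sizes are illustrative assumptions): the encoder expects
# time-first tensors, e.g.
#   penc = PositionalEncoder(d_model=32)
#   y = penc(torch.randn(20, 4, 32))   # y.shape == (20, 4, 32)
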
class VSN(nn.Module):
    """Variable Selection Network: a GRN over the flattened embedding produces
    per-variable selection weights, while a separate GRN processes each variable."""
    def __init__(self,input_size,num_inputs,hidden_size,drop_out,context=None):
        super(VSN, self).__init__()
        self.hidden_size=hidden_size
        self.input_size=input_size
        self.num_inputs=num_inputs
        self.drop_out=drop_out
        # `context` is the size of the optional static context vector, if any.
        self.context=context
        self.flattened_grn=GRN(input_size=self.num_inputs*self.input_size,hidden_state_size=self.hidden_size,output_size=self.num_inputs,drop_out=self.drop_out,hidden_context_size=self.context)
        self.single_variable_grns=nn.ModuleList()
        for i in range(self.num_inputs):
            self.single_variable_grns.append(GRN(self.input_size,self.hidden_size,self.hidden_size,self.drop_out))
        # Softmax over the variable dimension (last axis of the flattened GRN output).
        self.softmax=nn.Softmax(dim=-1)

    def forward(self,embedding,context=None):
        # Selection weights: (T, B, num_inputs) -> (T, B, 1, num_inputs).
        sparse_weights=self.flattened_grn(embedding,context)
        sparse_weights=self.softmax(sparse_weights).unsqueeze(2)
        # Run each variable's slice of the embedding through its own GRN.
        var_outputs=[]
        for i in range(self.num_inputs):
            var_outputs.append(self.single_variable_grns[i](embedding[:,:,(i*self.input_size):(i+1)*self.input_size]))
        var_outputs=torch.stack(var_outputs,dim=-1)