import torch
import torch.nn as nn
import torch.nn.functional as F
class dilated_cnn(nn.Module):
    """Stack of 1D dilated convolutions with a gated activation and a residual connection."""

    def __init__(self, in_channel):
        super(dilated_cnn, self).__init__()
        # Dilated convolutions with exponentially growing dilation (1, 2, 4, 8).
        # Padding is (kernel_size - 1) * dilation // 2 so the sequence length is preserved.
        self.fconv_d1 = nn.Conv1d(in_channels=in_channel, out_channels=64, kernel_size=5, dilation=1, padding=2)
        self.fconv_d2 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=5, dilation=2, padding=(5 - 1) * 2 // 2)
        self.fconv_d3 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=5, dilation=4, padding=(5 - 1) * 4 // 2)
        self.fconv_d4 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=5, dilation=8, padding=(5 - 1) * 8 // 2)
        # 1x1 convolution maps back to the input channel count so the residual addition is valid.
        self.fconv1x1 = nn.Conv1d(in_channels=64, out_channels=in_channel, kernel_size=1, dilation=1)
        self.bn1 = nn.BatchNorm1d(num_features=64)
        self.bn2 = nn.BatchNorm1d(num_features=64)
        self.bn3 = nn.BatchNorm1d(num_features=64)
        self.bn4 = nn.BatchNorm1d(num_features=64)
    def forward(self, x1):
        # Four conv -> batch norm -> ReLU blocks with increasing dilation.
        x = F.relu(self.bn1(self.fconv_d1(x1)))
        x = F.relu(self.bn2(self.fconv_d2(x)))
        x = F.relu(self.bn3(self.fconv_d3(x)))
        x = F.relu(self.bn4(self.fconv_d4(x)))
        # Gated activation: tanh branch modulated by a sigmoid gate
        # (torch.tanh replaces the deprecated F.tanh).
        res_glu_1 = torch.tanh(x) * torch.sigmoid(x)
        # Project back to the input channel count and add the residual connection.
        x = self.fconv1x1(res_glu_1)
        x = x + x1
        return x
if __name__ == "__main__":
    # Quick sanity check: output shape should match the input shape (batch, channels, length).
    dcnn = dilated_cnn(in_channel=16)
    sample = torch.randn(20, 16, 100)
    x = dcnn(sample)
    print("x:", x.shape)