class GluConv2d(nn.Module):
    """Gated 2-D convolution computing ``conv1(x) * sigmoid(conv2(x))``.

    Two ``nn.Conv2d`` layers with identical hyper-parameters read the same
    input. The second one, squashed through a sigmoid, gates the first
    element-wise. No padding is passed, so ``nn.Conv2d``'s default applies.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions (and the output).
        kernel_size: kernel size forwarded to both convolutions.
        stride: stride forwarded to both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.out_channels = out_channels
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.conv1 = nn.Conv2d(**conv_kwargs)  # value path
        self.conv2 = nn.Conv2d(**conv_kwargs)  # gate path (before sigmoid)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the gated convolution of ``x``."""
        value = self.conv1(x)
        gate = self.sigmoid(self.conv2(x))
        return value * gate
class GluConvTranspose2d(nn.Module):
    """Gated 2-D transposed convolution: ``conv1(x) * sigmoid(conv2(x))``.

    Transposed-convolution counterpart of the gated conv. Two
    ``nn.ConvTranspose2d`` layers share every hyper-parameter, and the
    sigmoid of the second gates the first element-wise.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both layers (and the output).
        kernel_size: kernel size forwarded to both layers.
        stride: stride forwarded to both layers.
        output_padding: extra size added to one side of the output, forwarded
            to both layers. Defaults to ``(0, 0)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=(0,0)):
        super().__init__()
        self.out_channels = out_channels
        deconv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            output_padding=output_padding,
        )
        self.conv1 = nn.ConvTranspose2d(**deconv_kwargs)  # value path
        self.conv2 = nn.ConvTranspose2d(**deconv_kwargs)  # gate path (before sigmoid)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the gated transposed convolution of ``x``."""
        value = self.conv1(x)
        gate = self.sigmoid(self.conv2(x))
        return value * gate
class Net(nn.Module):
    """Encoder / dual-decoder network built from gated (GLU) convolutions.

    A shared five-stage encoder (``conv1``..``conv5``, kernel ``(1, 3)``,
    stride ``(1, 2)``) roughly halves the last axis at every stage. Two
    parallel decoders (suffix ``_1`` and ``_2``) upsample it back. Each
    concatenates the matching encoder feature map as a skip connection
    after stages 5..2. Each decoder ends in a one-channel map that goes
    through its own ``nn.Linear(161, 161)`` over the last axis. The two
    results are concatenated along ``dim=1``.

    Shapes derived from the layer definitions (NOTE(review): confirm
    against the data pipeline):
        input ``(B, 6, T, 161)`` -> output ``(B, 2, T, 161)``.
    The last-axis width must be 161. ``fc1``/``fc2`` require it, and so do
    the skip-connection widths 161 -> 80 -> 39 -> 19 -> 9 -> 4 -> 9 -> 19
    -> 39 -> 80 -> 161.
    """

    def __init__(self):
        super().__init__()
        # Shared encoder.
        self.conv1 = GluConv2d(6, 32, (1,3), (1,2))
        self.conv2 = GluConv2d(32, 32, (1,3), (1,2))
        self.conv3 = GluConv2d(32, 32, (1,3), (1,2))
        self.conv4 = GluConv2d(32, 64, (1,3), (1,2))
        self.conv5 = GluConv2d(64, 64, (1,3), (1,2))
        # Decoder branch 1. Stage 5 consumes the encoder bottleneck e5, which
        # has 64 channels.
        # NOTE(review): the original declared 192 input channels here and fed
        # the raw 6-channel input, which cannot run. 64 is the only width
        # consistent with conv5/bn5. If an upstream bottleneck producing 192
        # channels was intended, restore it together with this value.
        # The original also referenced `QuantGluConvTranspose2d`, which is
        # not defined in this file. These call sites match
        # GluConvTranspose2d's signature.
        self.conv5_t_1 = GluConvTranspose2d(64, 32, (1,3), (1,2))
        self.conv4_t_1 = GluConvTranspose2d(96, 16, (1,3), (1,2))  # 32 + e4(64)
        self.conv3_t_1 = GluConvTranspose2d(48, 16, (1,3), (1,2))  # 16 + e3(32)
        # output_padding brings the width from 79 up to e1's width of 80.
        self.conv2_t_1 = GluConvTranspose2d(48, 16, (1,3), stride=(1,2), output_padding=(0,1))
        self.conv1_t_1 = GluConvTranspose2d(48, 1, (1,3), (1,2))  # 16 + e1(32)
        # Decoder branch 2: same layout as branch 1.
        self.conv5_t_2 = GluConvTranspose2d(64, 32, (1,3), (1,2))
        self.conv4_t_2 = GluConvTranspose2d(96, 16, (1,3), (1,2))
        self.conv3_t_2 = GluConvTranspose2d(48, 16, (1,3), (1,2))
        self.conv2_t_2 = GluConvTranspose2d(48, 16, (1,3), stride=(1,2), output_padding=(0,1))
        self.conv1_t_2 = GluConvTranspose2d(48, 1, (1,3), (1,2))
        # Batch norms, one per conv stage.
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(32)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn5_t_1 = nn.BatchNorm2d(32)
        self.bn4_t_1 = nn.BatchNorm2d(16)
        self.bn3_t_1 = nn.BatchNorm2d(16)
        self.bn2_t_1 = nn.BatchNorm2d(16)
        self.bn1_t_1 = nn.BatchNorm2d(1)
        self.bn5_t_2 = nn.BatchNorm2d(32)
        self.bn4_t_2 = nn.BatchNorm2d(16)
        self.bn3_t_2 = nn.BatchNorm2d(16)
        self.bn2_t_2 = nn.BatchNorm2d(16)
        self.bn1_t_2 = nn.BatchNorm2d(1)
        self.elu = nn.ELU(inplace=True)
        # Per-branch linear maps over the last (width-161) axis.
        self.fc1 = nn.Linear(in_features=161, out_features=161)
        self.fc2 = nn.Linear(in_features=161, out_features=161)

    def _decode(self, z, skips, deconvs, bns):
        """Run one decoder branch starting from the bottleneck ``z``.

        ``deconvs`` and ``bns`` list the five stages deepest-first
        (stage 5 .. stage 1). ``skips`` holds the encoder maps
        ``(e4, e3, e2, e1)``, concatenated on ``dim=1`` after stages 5..2.
        The final stage has no skip connection.
        """
        d = z
        for deconv, bn, skip in zip(deconvs[:-1], bns[:-1], skips):
            d = self.elu(torch.cat((bn(deconv(d)), skip), dim=1))
        return self.elu(bns[-1](deconvs[-1](d)))

    def forward(self, x):
        """Encode ``x`` once, decode it twice, and stack the two outputs.

        Args:
            x: input tensor, presumably ``(B, 6, T, 161)`` (TODO confirm).

        Returns:
            ``torch.cat([fc1(branch_1), fc2(branch_2)], dim=1)``.
        """
        e1 = self.elu(self.bn1(self.conv1(x)))
        e2 = self.elu(self.bn2(self.conv2(e1)))
        e3 = self.elu(self.bn3(self.conv3(e2)))
        e4 = self.elu(self.bn4(self.conv4(e3)))
        e5 = self.elu(self.bn5(self.conv5(e4)))
        skips = (e4, e3, e2, e1)

        d1_1 = self._decode(
            e5,
            skips,
            (self.conv5_t_1, self.conv4_t_1, self.conv3_t_1,
             self.conv2_t_1, self.conv1_t_1),
            (self.bn5_t_1, self.bn4_t_1, self.bn3_t_1,
             self.bn2_t_1, self.bn1_t_1),
        )
        d1_2 = self._decode(
            e5,
            skips,
            (self.conv5_t_2, self.conv4_t_2, self.conv3_t_2,
             self.conv2_t_2, self.conv1_t_2),
            (self.bn5_t_2, self.bn4_t_2, self.bn3_t_2,
             self.bn2_t_2, self.bn1_t_2),
        )

        out1 = self.fc1(d1_1)
        out2 = self.fc2(d1_2)
        return torch.cat([out1, out2], dim=1)
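# Question (translated from the original Chinese, "网络结构如上，如何将此网络进行量化？"):
# the network architecture is shown above — how can this network be quantized?
# ("最新发布" / "Latest release" that followed the question is web-page residue.)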