网络结构:类 UNet 的左右对称网络,主要设计体现在 Decoder 阶段的上采样模块。文中提到,1D 的上采样 decoder 模块对稀疏点的输入比较鲁棒:"Our 1D decoder has similar effect at roads with sparse samples, and they are not affected by sampling intervals."
class DecoderBlock1DConv4(nn.Module):
    """Decoder block that doubles spatial resolution using 1D transposed convs.

    Pipeline (see ``forward``):

    1. A 1x1 convolution reduces ``in_channels`` to ``in_channels // 4``.
    2. Four transposed convolutions with 9-tap, one-dimensional kernels are
       applied in parallel and concatenated along the channel axis.
    3. The result is upsampled by a factor of 2, normalised, activated and
       projected to ``n_filters`` channels with a second 1x1 convolution.

    The activation is the module-level ``nonlinearity`` callable, which is
    expected to be defined elsewhere in this file. ``forward`` also relies on
    ``h_transform`` / ``inv_h_transform`` / ``v_transform`` /
    ``inv_v_transform``, which are expected to be further methods of this
    class.
    """

    def __init__(self, in_channels, n_filters):
        """Build the layers.

        Args:
            in_channels: Number of channels of the incoming feature map.
            n_filters: Number of channels of the produced feature map.
        """
        super(DecoderBlock1DConv4, self).__init__()

        reduced = in_channels // 4
        axis_out = in_channels // 8    # per-branch width of deconv1/deconv2
        shear_out = in_channels // 16  # per-branch width of deconv3/deconv4
        # Width of the concatenated branches. Computed from the actual branch
        # widths rather than as in_channels // 4 + in_channels // 8: the two
        # agree when in_channels is a multiple of 16, but floor division
        # makes them differ otherwise (e.g. 24 -> 8 vs 9), which would make
        # norm2 reject the concatenated tensor.
        merged = 2 * axis_out + 2 * shear_out

        self.conv1 = nn.Conv2d(in_channels, reduced, 1)
        self.norm1 = nn.BatchNorm2d(reduced)
        self.relu1 = nonlinearity

        # kernel_size is (height, width). With stride 1 and padding 4 along
        # the 9-tap axis, each transposed conv keeps the spatial size.
        self.deconv1 = nn.ConvTranspose2d(
            reduced, axis_out, (1, 9), padding=(0, 4))
        self.deconv2 = nn.ConvTranspose2d(
            reduced, axis_out, (9, 1), padding=(4, 0))
        self.deconv3 = nn.ConvTranspose2d(
            reduced, shear_out, (9, 1), padding=(4, 0))
        self.deconv4 = nn.ConvTranspose2d(
            reduced, shear_out, (1, 9), padding=(0, 4))

        self.norm2 = nn.BatchNorm2d(merged)
        self.relu2 = nonlinearity
        self.conv3 = nn.Conv2d(merged, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nonlinearity

    def forward(self, x):
        """Run the decoder block.

        Args:
            x: Input feature map; presumably shaped
                ``(N, in_channels, H, W)`` -- TODO confirm against callers.

        Returns:
            Tensor with ``n_filters`` channels and both spatial dimensions
            scaled by 2.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x1 = self.deconv1(x)  # (1, 9) kernel: horizontal
        x2 = self.deconv2(x)  # (9, 1) kernel: vertical
        # The remaining two branches run a 1D kernel on a transformed copy of
        # the input and then undo the transform; the original author labels
        # them "diag" and "anti-diag".
        x3 = self.inv_h_transform(self.deconv3(self.h_transform(x)))
        x4 = self.inv_v_transform(self.deconv4(self.v_transform(x)))
        x = torch.cat((x1, x2, x3, x4), 1)

        # F.interpolate's default mode is nearest-neighbour.
        x = F.interpolate(x, scale_factor=2)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu3(x)
        return x
defh_transform(self, x):
shape = x.size()
x = torch.nn.functional.pad(x,(0, shape[-1]))
x = x.reshape(shape[0], shape[1],-1)[...,:-shape[-1]]
x = x.reshape(shape[0