import torch
import torch.nn as nn
import torch.nn.functional as F
# from inplace_abn import InPlaceABN, InPlaceABNSync
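# Feature-alignment fusion blocks (CAB_low / CAB_high). The offset-based
# warping follows the alignment scheme from AlignSeg
# (https://github.com/speedinghzl/AlignSeg): a small conv head predicts a
# 2-channel offset field (delta) that is used with F.grid_sample to warp one
# feature map onto the other before fusion.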


class CAB_low(nn.Module):
    def __init__(self, features):
        super(CAB_low, self).__init__()
        self.conv1 = nn.Conv2d(features * 4, features, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(features)
        self.maxpooling = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(features * 4, 1, kernel_size=1)
        self.conv3 = nn.Conv2d(features * 2, features, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(features)
        self.delta_gen1 = nn.Sequential(
            nn.Conv2d(features * 2, features, kernel_size=1, bias=False),
            # InPlaceABNSync(features),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
        )
        self.delta_gen2 = nn.Sequential(
            nn.Conv2d(features * 2, features, kernel_size=1, bias=False),
            # InPlaceABNSync(features),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
        )
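        # Zero-init the final offset convs so both warps start as the identity
        # mapping (predicted delta == 0 at the beginning of training).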
        self.delta_gen1[3].weight.data.zero_()
        self.delta_gen2[3].weight.data.zero_()

    # https://github.com/speedinghzl/AlignSeg/issues/7
    # The normalization term is set to [w/s, h/s] rather than [h/s, w/s].
    # bilinear_interpolate_torch_gridsample2 is the standard implementation;
    # use it for training.
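    # Concretely, for a 64x64 output: gridsample (s = 1.0) divides delta by the
    # input size [w, h] = [64, 64], while gridsample2 (s = 2.0) divides by
    # [(out_w - 1) / 2, (out_h - 1) / 2] = [31.5, 31.5], which matches the
    # 2 / (out - 1) spacing of the align_corners=True sampling grid.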
    def bilinear_interpolate_torch_gridsample(self, input, size, delta):
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 1.0
        norm = torch.tensor([[[[w / s, h / s]]]]).type_as(input).to(input.device)
        w_list = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        h_list = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
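        # Despite the names, h_list varies along the width axis and w_list along
        # the height axis: grid[..., 0] is the x (width) coordinate and
        # grid[..., 1] is y (height), matching F.grid_sample's (x, y) layout and
        # the [x_norm, y_norm] order of `norm` above.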
        grid = torch.cat((h_list.unsqueeze(2), w_list.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)
        grid = grid + delta.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(input, grid)
        return output

    def bilinear_interpolate_torch_gridsample2(self, input, size, delta):
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 2.0
        norm = torch.tensor([[[[(out_w - 1) / s, (out_h - 1) / s]]]]).type_as(input).to(input.device)  # not [h/s, w/s]
        w_list = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        h_list = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
        grid = torch.cat((h_list.unsqueeze(2), w_list.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)
        grid = grid + delta.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(input, grid, align_corners=True)
        return output

    def forward(self, low_stage, high_stage):
        high_stage1 = F.relu(self.bn1(self.conv1(high_stage)))
        h, w = low_stage.size(2), low_stage.size(3)
        high_stage1 = F.interpolate(input=high_stage1, size=(h, w), mode='bilinear', align_corners=True)
        concat = torch.cat((low_stage, high_stage1), 1)
        delta1 = self.delta_gen1(concat)
        delta2 = self.delta_gen2(concat)
        high_stage1 = self.bilinear_interpolate_torch_gridsample(high_stage1, (h, w), delta1)
        low_stage1 = self.bilinear_interpolate_torch_gridsample(low_stage, (h, w), delta2)
        high_stage1 += low_stage1
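        # Second path: build a single-channel sigmoid gate from one stage and
        # use it to reweight the other before a 1x1 fusion conv. The active
        # (else) branch assumes high_stage is at half the spatial resolution
        # of low_stage.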
        if low_stage.size(2) == high_stage.size(2):
            low_stage2 = low_stage
            low_stage2 = torch.sigmoid(self.conv2(low_stage2))
            high_stage2 = torch.mul(high_stage, low_stage2)
            high_stage2 = torch.cat((high_stage2, high_stage), dim=1)
            high_stage2 = F.relu(self.bn3(self.conv3(high_stage2)))
        else:
            # low_stage2 = self.maxpooling(low_stage)
            # low_stage2 = F.sigmoid(self.conv2(low_stage2))
            high_stage2 = F.interpolate(input=high_stage, scale_factor=2, mode='bilinear', align_corners=True)
            high_stage2 = torch.sigmoid(self.conv2(high_stage2))
            # high_stage2 = torch.mul(high_stage, low_stage2)
            # high_stage2 = torch.cat((high_stage2, high_stage), dim=1)
            # high_stage2 = F.relu(self.bn3(self.conv3(high_stage2)))
            low_stage2 = torch.mul(low_stage, high_stage2)
            low_stage2 = torch.cat((low_stage2, low_stage), dim=1)
            low_stage2 = F.relu(self.bn3(self.conv3(low_stage2)))
        return high_stage1, low_stage2


class CAB_high(nn.Module):
    def __init__(self, features):
        super(CAB_high, self).__init__()
        self.conv1 = nn.Conv2d(features * 4, features, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(features)
        self.maxpooling = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(features, 1, kernel_size=1)
        self.conv3 = nn.Conv2d(features * 8, features, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(features)
        self.delta_gen1 = nn.Sequential(
            nn.Conv2d(features * 2, features, kernel_size=1, bias=False),
            # InPlaceABNSync(features),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
        )
        self.delta_gen2 = nn.Sequential(
            nn.Conv2d(features * 2, features, kernel_size=1, bias=False),
            # InPlaceABNSync(features),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
        )
        self.delta_gen1[3].weight.data.zero_()
        self.delta_gen2[3].weight.data.zero_()

    # https://github.com/speedinghzl/AlignSeg/issues/7
    # The normalization term is set to [w/s, h/s] rather than [h/s, w/s].
    # bilinear_interpolate_torch_gridsample2 is the standard implementation;
    # use it for training.
    def bilinear_interpolate_torch_gridsample(self, input, size, delta):
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 1.0
        norm = torch.tensor([[[[w / s, h / s]]]]).type_as(input).to(input.device)
        w_list = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        h_list = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
        grid = torch.cat((h_list.unsqueeze(2), w_list.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)
        grid = grid + delta.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(input, grid)
        return output

    def bilinear_interpolate_torch_gridsample2(self, input, size, delta):
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 2.0
        norm = torch.tensor([[[[(out_w - 1) / s, (out_h - 1) / s]]]]).type_as(input).to(input.device)  # not [h/s, w/s]
        w_list = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        h_list = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
        grid = torch.cat((h_list.unsqueeze(2), w_list.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)
        grid = grid + delta.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(input, grid, align_corners=True)
        return output
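

# Minimal smoke test (a sketch, not from the original source): the `features`
# value and tensor shapes below are illustrative assumptions. In the active
# (else) branch of CAB_low.forward, `low_stage` carries `features` channels and
# `high_stage` carries `features * 4` channels at half the spatial resolution.
if __name__ == "__main__":
    features = 16
    low = torch.randn(2, features, 64, 64)        # fine stage: (N, F, H, W)
    high = torch.randn(2, features * 4, 32, 32)   # coarse stage: (N, 4F, H/2, W/2)
    cab = CAB_low(features).eval()
    with torch.no_grad():
        fused_high, gated_low = cab(low, high)
    print(fused_high.shape)  # expected: torch.Size([2, 16, 64, 64])
    print(gated_low.shape)   # expected: torch.Size([2, 16, 64, 64])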