The forward function of the overall AutoDeeplab model

class AutoDeeplab(nn.Module):
    def __init__(self, num_classes, num_layers, criterion=None, filter_multiplier=8,
                 block_multiplier=5, step=5, cell=cell_level_search.Cell):
        ...  # initialization omitted
    def forward(self, x):
        # feature maps at the four network levels (downsampling rates 4 / 8 / 16 / 32)
        self.level_4 = []
        self.level_8 = []
        self.level_16 = []
        self.level_32 = []
        temp = self.stem0(x)
        temp = self.stem1(temp)
        self.level_4.append(self.stem2(temp))
        count = 0
        # buffer for the normalized betas; every used entry is overwritten below
        normalized_betas = torch.randn(self._num_layers, 4, 3).cuda()
        if torch.cuda.device_count() > 1:
            ...  # the multi-GPU case is handled the same way as below, so only the else branch is kept
        else:
            # cell-level architecture weights: softmax over the candidate operations
            normalized_alphas = F.softmax(self.alphas, dim=-1)
            # network-level path weights: for each layer, softmax each level's
            # betas over its valid transitions (up / same / down); the top and
            # bottom levels have only two valid transitions, so their softmax
            # runs over two entries and is rescaled by 2/3 (the two weights
            # then sum to 2/3, an average of 1/3 per edge, as at interior levels)
            for layer in range(len(self.betas)):
                if layer == 0:
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2 / 3)
                elif layer == 1:
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2 / 3)
                    normalized_betas[layer][1] = F.softmax(self.betas[layer][1], dim=-1)
                elif layer == 2:
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2 / 3)
                    normalized_betas[layer][1] = F.softmax(self.betas[layer][1], dim=-1)
                    normalized_betas[layer][2] = F.softmax(self.betas[layer][2], dim=-1)
                else:
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2 / 3)
                    normalized_betas[layer][1] = F.softmax(self.betas[layer][1], dim=-1)
                    normalized_betas[layer][2] = F.softmax(self.betas[layer][2], dim=-1)
                    normalized_betas[layer][3][:2] = F.softmax(self.betas[layer][3][:2], dim=-1) * (2 / 3)
        for layer in range(self._num_layers):
            if layer == 0:
                # only level 4 exists so far: produce level 4 (same path) and level 8 (down path)
                level4_new, = self.cells[count](None, None, self.level_4[-1], None, normalized_alphas)
                count += 1
                level8_new, = self.cells[count](None, self.level_4[-1], None, None, normalized_alphas)
                count += 1
                level4_new = normalized_betas[layer][0][1] * level4_new
                level8_new = normalized_betas[layer][0][2] * level8_new
                self.level_4.append(level4_new)
                self.level_8.append(level8_new)
            elif layer == 1:
                # level 4 receives a same-level path and an up path from level 8
                level4_new_1, level4_new_2 = self.cells[count](self.level_4[-2],
                                                               None,
                                                               self.level_4[-1],
                                                               self.level_8[-1],
                                                               normalized_alphas)
                count += 1
                level4_new = normalized_betas[layer][0][1] * level4_new_1 + normalized_betas[layer][1][0] * level4_new_2
                # level 8 receives a down path from level 4 and a same-level path
                level8_new_1, level8_new_2 = self.cells[count](None,
                                                               self.level_4[-1],
                                                               self.level_8[-1],
                                                               None,
                                                               normalized_alphas)
                count += 1
                # betas[layer][1][1] is level 8's same-level weight; [1][2] is its
                # down-transition weight, used for level16_new below
                level8_new = normalized_betas[layer][0][2] * level8_new_1 + normalized_betas[layer][1][1] * level8_new_2
                # level 16 appears for the first time, fed only by the down path from level 8
                level16_new, = self.cells[count](None,
                                                 self.level_8[-1],
                                                 None,
                                                 None,
                                                 normalized_alphas)
                level16_new = normalized_betas[layer][1][2] * level16_new
                count += 1
                self.level_4.append(level4_new)
                self.level_8.append(level8_new)
                self.level_16.append(level16_new)
            elif layer == 2:
                ...  # analogous, now also producing level 32 for the first time
            elif layer == 3:
                ...  # analogous, with all four levels active
            else:
                ...  # analogous, with all four levels active
        # run each resolution through its own ASPP head, upsample everything
        # back to the input size, and fuse by summation
        aspp_result_4 = self.aspp_4(self.level_4[-1])
        aspp_result_8 = self.aspp_8(self.level_8[-1])
        aspp_result_16 = self.aspp_16(self.level_16[-1])
        aspp_result_32 = self.aspp_32(self.level_32[-1])
        upsample = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=True)
        aspp_result_4 = upsample(aspp_result_4)
        aspp_result_8 = upsample(aspp_result_8)
        aspp_result_16 = upsample(aspp_result_16)
        aspp_result_32 = upsample(aspp_result_32)
        sum_feature_map = aspp_result_4 + aspp_result_8 + aspp_result_16 + aspp_result_32
        return sum_feature_map
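The only subtle step above is the beta normalization. Interior levels take a softmax over all three transitions (up / same / down), while the boundary levels (downsampling rates 4 and 32) have only two valid transitions, so their softmax runs over two entries and is rescaled by 2/3; each outgoing edge then carries 1/3 of the probability mass on average, just as at the interior levels. A minimal, self-contained sketch of this normalization for a single layer (illustrative code, not taken from the repo):

import torch
import torch.nn.functional as F

betas = torch.randn(4, 3)          # one layer: 4 levels x 3 transitions (up / same / down)
normalized = torch.zeros_like(betas)
normalized[0, 1:] = F.softmax(betas[0, 1:], dim=-1) * (2 / 3)  # level 4 has no "up" edge
normalized[1] = F.softmax(betas[1], dim=-1)
normalized[2] = F.softmax(betas[2], dim=-1)
normalized[3, :2] = F.softmax(betas[3, :2], dim=-1) * (2 / 3)  # level 32 has no "down" edge
print(normalized.sum(dim=-1))      # ~tensor([0.6667, 1.0000, 1.0000, 0.6667])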
The forward function of the Cell structure
class Cell(nn.Module):
    def __init__(self, steps, block_multiplier, prev_prev_fmultiplier,
                 prev_fmultiplier_down, prev_fmultiplier_same, prev_fmultiplier_up,
                 filter_multiplier):
        ...  # initialization omitted
    def forward(self, s0, s1_down, s1_same, s1_up, n_alphas):
        # s1_* are the possible previous-layer inputs (from the level below,
        # the same level, and the level above); each is resized and
        # preprocessed to this cell's resolution and channel count
        if s1_down is not None:
            s1_down = self.prev_feature_resize(s1_down, 'down')
            s1_down = self.preprocess_down(s1_down)
            size_h, size_w = s1_down.shape[2], s1_down.shape[3]
        if s1_same is not None:
            s1_same = self.preprocess_same(s1_same)
            size_h, size_w = s1_same.shape[2], s1_same.shape[3]
        if s1_up is not None:
            s1_up = self.prev_feature_resize(s1_up, 'up')
            s1_up = self.preprocess_up(s1_up)
            size_h, size_w = s1_up.shape[2], s1_up.shape[3]

        all_states = []
        if s0 is not None:
            # bring the layer-before-last input s0 to the same spatial size and
            # channel count, then pair it with each available s1 input
            s0 = F.interpolate(s0, (size_h, size_w), mode='bilinear', align_corners=True) if (s0.shape[2] != size_h) or (s0.shape[3] != size_w) else s0
            s0 = self.pre_preprocess(s0) if (s0.shape[1] != self.C_out) else s0
            if s1_down is not None:
                states_down = [s0, s1_down]
                all_states.append(states_down)
            if s1_same is not None:
                states_same = [s0, s1_same]
                all_states.append(states_same)
            if s1_up is not None:
                states_up = [s0, s1_up]
                all_states.append(states_up)
        else:
            # no s0 in the first layers: use 0 as a placeholder first state
            if s1_down is not None:
                states_down = [0, s1_down]
                all_states.append(states_down)
            if s1_same is not None:
                states_same = [0, s1_same]
                all_states.append(states_same)
            if s1_up is not None:
                states_up = [0, s1_up]
                all_states.append(states_up)

        final_concates = []
        for states in all_states:
            offset = 0
            for i in range(self._steps):
                # each intermediate node sums one mixed op's output per
                # existing state, weighted by the softmaxed alphas
                new_states = []
                for j, h in enumerate(states):
                    branch_index = offset + j
                    if self._ops[branch_index] is None:
                        continue
                    new_state = self._ops[branch_index](h, n_alphas[branch_index])
                    new_states.append(new_state)
                s = sum(new_states)
                offset += len(states)
                states.append(s)
            # the cell output concatenates the last block_multiplier nodes
            concat_feature = torch.cat(states[-self.block_multiplier:], dim=1)
            final_concates.append(concat_feature)
        return final_concates
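Each entry of self._ops is called as self._ops[branch_index](h, n_alphas[branch_index]), i.e. it receives both an input tensor and one row of the softmaxed alphas. This matches the DARTS-style mixed-operation pattern, where every candidate operation runs on the input and the results are blended by the architecture weights. The actual candidate set is defined elsewhere in the repo; the MixedOp below is only a hypothetical minimal sketch of that pattern, with candidate_ops standing in for the real operations:

import torch
import torch.nn as nn

class MixedOp(nn.Module):
    """Sketch of a DARTS-style mixed op: blend all candidates with alphas."""
    def __init__(self, candidate_ops):
        super().__init__()
        self.ops = nn.ModuleList(candidate_ops)  # e.g. convs, pooling, identity

    def forward(self, x, weights):
        # weights is one row of the softmaxed alphas (n_alphas[branch_index]);
        # the output is the weighted sum of every candidate applied to x
        return sum(w * op(x) for w, op in zip(weights, self.ops))

# illustrative usage with two trivial candidates
mixed = MixedOp([nn.Identity(), nn.Conv2d(8, 8, 3, padding=1)])
x = torch.randn(1, 8, 32, 32)
weights = torch.softmax(torch.randn(2), dim=-1)
out = mixed(x, weights)            # same shape as x: (1, 8, 32, 32)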