The forward pass

The forward function of the overall AutoDeeplab model

class AutoDeeplab(nn.Module):
    def __init__(self, num_classes, num_layers, criterion=None, filter_multiplier=8, block_multiplier=5, step=5, cell=cell_level_search.Cell):
        ....................................

    def forward(self, x):
        # TODO: GET RID OF THESE LISTS, we dont need to keep everything.
        # TODO: Is this the reason for the memory issue ?
        self.level_4 = []
        self.level_8 = []
        self.level_16 = []
        self.level_32 = []
        # First pass through the three stem nodes stem0, stem1 and stem2 (the white nodes in Figure 1).
        temp = self.stem0(x)
        temp = self.stem1(temp)
        self.level_4.append(self.stem2(temp))

        count = 0
        # With 12 layers the beta parameters have shape 12 x 4 x 3: 12 layers, 4 levels per layer,
        # and 3 outgoing directions per level (0 = up, 1 = same, 2 = down).
        normalized_betas = torch.randn(self._num_layers, 4, 3).cuda()
        # Softmax on alphas and betas
        if torch.cuda.device_count() > 1:
            ..........  # the multi-GPU branch is handled the same way as the else branch below, so only the else branch is kept here.
        else:
            # The alphas are simply softmax-normalized along the last dimension (over candidate operations).
            normalized_alphas = F.softmax(self.alphas, dim=-1)
            # Iterate over every layer.
            for layer in range(len(self.betas)):
                # Layer i = -1 in Figure 1 (the layer before i = 0): there is only one white node, with only the
                # 'same' and 'down' directions, hence the [1:] slice. The 2/3 factor rescales the two-way softmax
                # so that its total edge weight stays comparable to nodes with three outgoing directions.
                if layer == 0: 
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2/3)
                # Layer i = 0 in Figure 1: two nodes; the first still has only the 'same' and 'down' directions,
                # while the second has all three.
                elif layer == 1:
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2/3)
                    normalized_betas[layer][1] = F.softmax(self.betas[layer][1], dim=-1)
                # Layer i = 1 in Figure 1.
                elif layer == 2:
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2/3)
                    normalized_betas[layer][1] = F.softmax(self.betas[layer][1], dim=-1)
                    normalized_betas[layer][2] = F.softmax(self.betas[layer][2], dim=-1)
                # All remaining layers: the first and last levels of each layer have only two outgoing directions.
                else:
                    normalized_betas[layer][0][1:] = F.softmax(self.betas[layer][0][1:], dim=-1) * (2/3)
                    normalized_betas[layer][1] = F.softmax(self.betas[layer][1], dim=-1)
                    normalized_betas[layer][2] = F.softmax(self.betas[layer][2], dim=-1)
                    normalized_betas[layer][3][:2] = F.softmax(self.betas[layer][3][:2], dim=-1) * (2/3)

        for layer in range(self._num_layers):
            # At layer 0: compute the values of the two blue nodes.
            if layer == 0:
                level4_new, = self.cells[count](None, None, self.level_4[-1], None, normalized_alphas)
                count += 1
                level8_new, = self.cells[count](None, self.level_4[-1], None, None, normalized_alphas)
                count += 1
                # Relaxation: weight each new node by its incoming beta.
                level4_new = normalized_betas[layer][0][1] * level4_new
                level8_new = normalized_betas[layer][0][2] * level8_new
                self.level_4.append(level4_new)
                self.level_8.append(level8_new)
            # At layer 1: the level-4 and level-8 nodes each receive two incoming arrows, and the level-16 node receives one.
            elif layer == 1:
                level4_new_1, level4_new_2 = self.cells[count](self.level_4[-2],
                                                               None,
                                                               self.level_4[-1],
                                                               self.level_8[-1],
                                                               normalized_alphas)
                count += 1
                # Note how level4_new merges its two incoming arrows into one (level8_new and level16_new are analogous).
                # The beta parameters are optimized jointly with the weights; at decode time the path with the highest
                # beta values is selected (sketched after this function).
                level4_new = normalized_betas[layer][0][1] * level4_new_1 + normalized_betas[layer][1][0] * level4_new_2

                level8_new_1, level8_new_2 = self.cells[count](None,
                                                               self.level_4[-1],
                                                               self.level_8[-1],
                                                               None,
                                                               normalized_alphas)
                count += 1
                level8_new = normalized_betas[layer][0][2] * level8_new_1 + normalized_betas[layer][1][1] * level8_new_2

                level16_new, = self.cells[count](None,
                                                 self.level_8[-1],
                                                 None,
                                                 None,
                                                 normalized_alphas)
                level16_new = normalized_betas[layer][1][2] * level16_new
                count += 1

                self.level_4.append(level4_new)
                self.level_8.append(level8_new)
                self.level_16.append(level16_new)
            # The remaining cases are analogous and omitted here.
            elif layer == 2:
                ..............
            elif layer == 3:
                ..............
            else: # from layer 4 onward the overall structure stays essentially the same.
                ..............
        # Apply ASPP to the final output of each level.
        aspp_result_4 = self.aspp_4(self.level_4[-1])
        aspp_result_8 = self.aspp_8(self.level_8[-1])
        aspp_result_16 = self.aspp_16(self.level_16[-1])
        aspp_result_32 = self.aspp_32(self.level_32[-1])
        upsample = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=True)
        aspp_result_4 = upsample(aspp_result_4)
        aspp_result_8 = upsample(aspp_result_8)
        aspp_result_16 = upsample(aspp_result_16)
        aspp_result_32 = upsample(aspp_result_32)
        # The final result is simply the sum of the four upsampled ASPP outputs.
        sum_feature_map = aspp_result_4 + aspp_result_8 + aspp_result_16 + aspp_result_32

        return sum_feature_map
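
The comment inside the layer loop notes that, once the betas are trained, the network-level path is decoded by following the edges with the highest beta weights; the paper and repository do this with a Viterbi-style dynamic program. The snippet below is only a simplified, greedy sketch of that idea, assuming the [num_layers, 4, 3] beta layout and the direction indices (0 = up, 1 = same, 2 = down) described above; the function name and the greedy rule are illustrative and not part of the original code.

import torch


def greedy_decode_path(normalized_betas):
    """Follow the highest-beta outgoing edge at every layer (greedy sketch).

    Assumes normalized_betas has shape [num_layers, 4, 3] with direction
    indices 0 = up, 1 = same, 2 = down, as in the forward pass above.
    """
    level = 0                      # start at the level-4 node produced by stem2
    path = [level]
    for layer in range(normalized_betas.shape[0]):
        scores = normalized_betas[layer][level].clone()
        if level == 0:
            scores[0] = float('-inf')   # no 'up' edge at the highest-resolution level
        if level == 3:
            scores[2] = float('-inf')   # no 'down' edge at the lowest-resolution level
        direction = int(torch.argmax(scores))
        level += (-1, 0, 1)[direction]  # move up one level / stay / move down one level
        path.append(level)
    return path


# Example: path through a randomly initialised 12-layer beta tensor.
betas = torch.softmax(torch.randn(12, 4, 3), dim=-1)
print(greedy_decode_path(betas))  # e.g. [0, 1, 1, 2, ...]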

The forward function of the Cell structure

class Cell(nn.Module):
    def __init__(self, steps, block_multiplier, prev_prev_fmultiplier,
                 prev_fmultiplier_down, prev_fmultiplier_same, prev_fmultiplier_up,
                 filter_multiplier):
                 .......................
    
    # s0 is the output of the cell two layers back; s1_down is the output of the previous layer one level up,
    # whose arrow points down into the current node.
    def forward(self, s0, s1_down, s1_same, s1_up, n_alphas):
        # Preprocess the inputs the current cell receives from above (down arrow), from the same level, and from below (up arrow).
        if s1_down is not None:
            s1_down = self.prev_feature_resize(s1_down, 'down')
            s1_down = self.preprocess_down(s1_down)
            size_h, size_w = s1_down.shape[2], s1_down.shape[3]
        if s1_same is not None:
            s1_same = self.preprocess_same(s1_same)
            size_h, size_w = s1_same.shape[2], s1_same.shape[3]
        if s1_up is not None:
            s1_up = self.prev_feature_resize(s1_up, 'up')
            s1_up = self.preprocess_up(s1_up)
            size_h, size_w = s1_up.shape[2], s1_up.shape[3]
        # Container for all input-state pairs to be processed.
        all_states = []
        # The node two layers back has an output.
        if s0 is not None:
            s0 = F.interpolate(s0, (size_h, size_w), mode='bilinear', align_corners=True) if (s0.shape[2] != size_h) or (s0.shape[3] != size_w) else s0
            s0 = self.pre_preprocess(s0) if (s0.shape[1] != self.C_out) else s0
            # Build every possible input-state pair.
            if s1_down is not None:
                states_down = [s0, s1_down]
                all_states.append(states_down)
            if s1_same is not None:
                states_same = [s0, s1_same]
                all_states.append(states_same)
            if s1_up is not None:
                states_up = [s0, s1_up]
                all_states.append(states_up)
        else:
            if s1_down is not None:
                states_down = [0, s1_down]
                all_states.append(states_down)
            if s1_same is not None:
                states_same = [0, s1_same]
                all_states.append(states_same)
            if s1_up is not None:
                states_up = [0, s1_up]
                all_states.append(states_up)

        final_concates = []
        # For each input-state pair,
        for states in all_states:
            offset = 0
            # run _steps steps, producing _steps blocks.
            for i in range(self._steps):
                new_states = []
                # Iterate over states, which holds the input tensors accumulated so far.
                for j, h in enumerate(states):
                    branch_index = offset + j
                    if self._ops[branch_index] is None:
                        continue
                    # One branch of the current block: the alpha-weighted mixture of operations applied to h.
                    new_state = self._ops[branch_index](
                        h, n_alphas[branch_index])
                    new_states.append(new_state)
                # The summation after the relaxation within the cell.
                s = sum(new_states)
                offset += len(states)
                states.append(s)
            # Concatenate the outputs of the last block_multiplier blocks along the channel dimension.
            concat_feature = torch.cat(states[-self.block_multiplier:], dim=1)
            final_concates.append(concat_feature)
        # Return the concatenated result for each processed input branch.
        return final_concates
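
The calls self._ops[branch_index](h, n_alphas[branch_index]) above rely on a mixed operation: each edge applies every candidate operation and sums the results, weighted by one softmax-normalized row of the alphas (the DARTS-style continuous relaxation). The class below is a minimal sketch of that idea, not the repository's MixedOp; the three candidate operations are placeholders for the real search space of eight operations.

import torch
import torch.nn as nn


class MixedOpSketch(nn.Module):
    """Alpha-weighted mixture of candidate operations (illustrative only)."""

    def __init__(self, C):
        super().__init__()
        # Placeholder candidates; the actual search space uses eight
        # DARTS-style operations (separable/dilated convs, pooling,
        # skip-connect, zero).
        self._candidates = nn.ModuleList([
            nn.Identity(),
            nn.Conv2d(C, C, 3, padding=1, bias=False),
            nn.MaxPool2d(3, stride=1, padding=1),
        ])

    def forward(self, x, weights):
        # weights is one softmax-normalized row of the alphas, so every
        # candidate contributes in proportion to its architecture weight.
        return sum(w * op(x) for w, op in zip(weights, self._candidates))


# Minimal usage: one row of alphas with as many entries as candidates.
op = MixedOpSketch(C=16)
x = torch.randn(2, 16, 32, 32)
alphas_row = torch.softmax(torch.randn(3), dim=-1)
out = op(x, alphas_row)  # same spatial shape and channel count as x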