BiO-Net: Learning Recurrent Bi-directional Connections for Encoder-Decoder Architecture

This post covers BiO-Net, presented at MICCAI 2020, a model that uses recurrent bi-directional connections to improve the encoder-decoder architecture for medical image analysis. It walks through the design of the pre-processing block, the encoders, the middle block, and the decoders, and provides a complete PyTorch implementation.


MICCAI 2020: Learning Recurrent Bi-directional Connections for Encoder-Decoder Architecture

BiO series project page

Contents

MICCAI 2020: Learning Recurrent Bi-directional Connections for Encoder-Decoder Architecture

Network Architecture

Series Diagrams

Complete Network Code


Keywords: recurrent, bi-directional

Network Architecture

[Figure: BiO-Net network architecture]

Series Diagrams

[Figure: framework diagrams of the BiO series]
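The idea in one paragraph: BiO-Net is a U-Net-style encoder-decoder whose weights are reused for a fixed number of passes. In every pass, each decoder level feeds its output back to the encoder level at the same resolution (backward skip), while the encoder keeps supplying the decoder through the usual skip connections (forward skip). Below is a minimal, illustrative sketch of that recurrence only; downsampling/upsampling and the real channel bookkeeping are omitted, and encode, middle and decode are placeholder callables, not part of the implementation that follows.

import torch

def bionet_recurrence(x, encode, middle, decode, num_layers, iterations):
    dec = [None] * num_layers                         # backward skips from the previous pass
    h = x
    for _ in range(iterations):                       # the SAME encode/decode weights are reused in every pass
        enc = []
        h = x                                         # every pass restarts from the pre-processed input
        for layer in range(num_layers):
            back = h if dec[-1 - layer] is None else dec[-1 - layer]
            h = encode(torch.cat([h, back], dim=1))   # backward skip: decoder -> encoder
            enc.append(h)
        h = middle(h)
        dec = []
        for layer in range(num_layers):
            h = decode(torch.cat([h, enc[-1 - layer]], dim=1))  # forward skip: encoder -> decoder
            dec.append(h)
    return h

# toy check with channel-halving "layers" so that the concatenations keep shapes consistent
halve = lambda t: t[:, : t.shape[1] // 2]
print(bionet_recurrence(torch.rand(1, 8, 16, 16), halve, lambda t: t, halve,
                        num_layers=2, iterations=2).shape)  # torch.Size([1, 8, 16, 16])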

Complete Network Code

import numpy as np
import torch
from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ConvTranspose2d, ReLU, MaxPool2d, Sigmoid
from torch import Tensor, cat


class BiONet(Module):

    def __init__(self,
                 num_classes: int = 3,
                 iterations: int = 2,
                 multiplier: float = 1.0,
                 num_layers: int = 4,
                 integrate: bool = False):

        super(BiONet, self).__init__()
        #  hyperparameters
        self.iterations = iterations
        self.multiplier = multiplier
        self.num_layers = num_layers
        self.integrate = integrate
        self.batch_norm_momentum = 0.01
        #  channel widths, starting from the output of the first encoder and going down to the bottleneck features
        self.filters_list = [int(32 * (2 ** i) * self.multiplier) for i in range(self.num_layers + 1)]
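        #  with the defaults (multiplier=1.0, num_layers=4) this evaluates to [32, 64, 128, 256, 512]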
        #  pre-processing conv block; it is not part of the recurrence and its final output is 32*256*256 for a 512*512 input
        self.pre_transform_conv_block = Sequential(
            # judging from the implementation, this always matches the channel count of the first encoder's output
            Conv2d(3, self.filters_list[0], kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),  # produces filters_list[0]*512*512
            ReLU(),
            BatchNorm2d(self.filters_list[0], momentum=self.batch_norm_momentum),
            Conv2d(self.filters_list[0], self.filters_list[0], kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
            # produces filters_list[0]*512*512
            ReLU(),
            BatchNorm2d(self.filters_list[0], momentum=self.batch_norm_momentum),
            Conv2d(self.filters_list[0], self.filters_list[0], kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
            # produces filters_list[0]*512*512
            ReLU(),
            BatchNorm2d(self.filters_list[0], momentum=self.batch_norm_momentum),
            MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
        )
        self.reuse_convs = []  # conv kernels shared across iterations by the encoders; one tuple per encoder (3 convs, ReLU not included)
        self.encoders = []  # list of encoders; since part of each encoder stays outside the recurrence, every entry is a tuple (Sequential of two CONVs, DOWN)
        self.reuse_deconvs = []  # conv / transposed-conv kernels shared across iterations by the decoders; one tuple per decoder (3 kernels, ReLU not included)
        self.decoders = []  # list of decoders; since part of each decoder stays outside the recurrence, every entry is a tuple (Sequential of two CONVs, UP)
        for iteration in range(self.iterations):
            for layer in range(self.num_layers):

                #  build the encoders. Parts of this code could be merged with the decoder branch below, but they are kept separate for readability (the constructor is not performance-critical)
                #  layer-dependent constants
                in_channel = self.filters_list[layer] * 2  # the backward skip coming from the decoder is concatenated here, so the input channels are doubled
                mid_channel = self.filters_list[layer]
                out_channel = self.filters_list[layer + 1]
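                #  e.g. for layer 0 with the defaults: in_channel=64, mid_channel=32, out_channel=64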
                #  build the encoder modules
                if iteration == 0:
                    #  create and register the kernels reused across iterations
                    #  only the last conv raises the channel count
                    conv1 = Conv2d(in_channel, mid_channel, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
                    conv2 = Conv2d(mid_channel, mid_channel, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
                    conv3 = Conv2d(mid_channel, out_channel, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
                    self.reuse_convs.append((conv1, conv2, conv3))
                #  assemble the encoder
                #  first the Sequential of two CONVs
                convs = Sequential(
                    self.reuse_convs[layer][0],
                    ReLU(),
                    BatchNorm2d(mid_channel, momentum=self.batch_norm_momentum),
                    self.reuse_convs[layer][1],
                    ReLU(),
                    BatchNorm2d(mid_channel, momentum=self.batch_norm_momentum)
                )
                #  build DOWN
                down = Sequential(
                    self.reuse_convs[layer][2],
                    ReLU(),
                    BatchNorm2d(out_channel, momentum=self.batch_norm_momentum),
                    MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
                )
                self.add_module("iteration{0}_layer{1}_encoder_convs".format(iteration, layer), convs)
                self.add_module("iteration{0}_layer{1}_encoder_down".format(iteration, layer), down)
                self.encoders.append((convs, down))

                #  build the decoders, mirroring the encoders
                #  layer-dependent constants; no mid_channel is needed because the very first conv already maps to the target channel count
                in_channel = self.filters_list[self.num_layers - layer] + self.filters_list[self.num_layers - 1 - layer]
                out_channel = self.filters_list[self.num_layers - 1 - layer]
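                #  e.g. for layer 0 with the defaults: in_channel=512+256=768 (middle output + forward skip), out_channel=256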
                #  build the decoder modules
                if iteration == 0:
                    #  create and register the kernels reused across iterations
                    #  the channel count is changed right at the first conv
                    conv1 = Conv2d(in_channel, out_channel, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
                    conv2 = Conv2d(out_channel, out_channel, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
                    conv3 = ConvTranspose2d(out_channel, out_channel, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2),
                                            output_padding=(1, 1))  # slightly different from TensorFlow: output_padding is needed here to recover the full spatial shape
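                    #  kernel 3, stride 2, padding 1, output_padding 1 exactly doubles the spatial size (e.g. 16x16 -> 32x32)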
                    self.reuse_deconvs.append((conv1, conv2, conv3))
                #  assemble the decoder
                #  first the Sequential of two CONVs
                convs = Sequential(
                    self.reuse_deconvs[layer][0],
                    ReLU(),
                    BatchNorm2d(out_channel, momentum=self.batch_norm_momentum),
                    self.reuse_deconvs[layer][1],
                    ReLU(),
                    BatchNorm2d(out_channel, momentum=self.batch_norm_momentum)
                )
                #  build UP
                up = Sequential(
                    self.reuse_deconvs[layer][2],
                    ReLU(),
                    BatchNorm2d(out_channel, momentum=self.batch_norm_momentum)
                )
                self.add_module("iteration{0}_layer{1}_decoder_convs".format(iteration, layer), convs)
                self.add_module("iteration{0}_layer{1}_decoder_up".format(iteration, layer), up)
                self.decoders.append((convs, up))
        #  build the middle (bottleneck) block
        self.middles = Sequential(
            Conv2d(self.filters_list[-1], self.filters_list[-1], kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
            ReLU(),
            BatchNorm2d(self.filters_list[-1], momentum=self.batch_norm_momentum),
            Conv2d(self.filters_list[-1], self.filters_list[-1], kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
            ReLU(),
            BatchNorm2d(self.filters_list[-1], momentum=self.batch_norm_momentum),
            ConvTranspose2d(self.filters_list[-1], self.filters_list[-1], kernel_size=(3, 3), padding=(1, 1),
                            stride=(2, 2), output_padding=(1, 1)),
            ReLU(),
            BatchNorm2d(self.filters_list[-1], momentum=self.batch_norm_momentum)
        )
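        #  when integrate=True the outputs of all iterations are concatenated along the channel dimension,
        #  so the first conv of the post block must accept filters_list[0] * iterations input channels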
        self.post_transform_conv_block = Sequential(
            Conv2d(self.filters_list[0] * self.iterations, self.filters_list[0], kernel_size=(3, 3), padding=(1, 1),
                   stride=(1, 1)) if self.integrate else Conv2d(self.filters_list[0],
                                                                self.filters_list[0], kernel_size=(3, 3),
                                                                padding=(1, 1), stride=(1, 1)),
            ReLU(),
            BatchNorm2d(self.filters_list[0], momentum=self.batch_norm_momentum),
            Conv2d(self.filters_list[0], self.filters_list[0], kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
            ReLU(),
            BatchNorm2d(self.filters_list[0], momentum=self.batch_norm_momentum),
            Conv2d(self.filters_list[0], 1, kernel_size=(1, 1), stride=(1, 1)),
            Sigmoid(),
        )

    def forward(self, x: Tensor) -> Tensor:
        enc = [None for i in range(self.num_layers)]  # encoder features of the current pass (forward skips)
        dec = [None for i in range(self.num_layers)]  # decoder features of the previous pass (backward skips)
        all_output = [None for i in range(self.iterations)]
        x = self.pre_transform_conv_block(x)
        e_i = 0  # running index into self.encoders
        d_i = 0  # running index into self.decoders
        for iteration in range(self.iterations):
            #  encoding path: each level concatenates the backward skip (or the input itself in the first pass)
            for layer in range(self.num_layers):
                if layer == 0:
                    x_in = x
                x_in = self.encoders[e_i][0](cat([x_in, x_in if dec[-1 - layer] is None else dec[-1 - layer]], dim=1))
                enc[layer] = x_in
                x_in = self.encoders[e_i][1](x_in)
                e_i = e_i + 1
            x_in = self.middles(x_in)
            #  decoding path: each level concatenates the forward skip from the matching encoder level
            for layer in range(self.num_layers):
                x_in = self.decoders[d_i][0](cat([x_in, enc[-1 - layer]], dim=1))
                dec[layer] = x_in
                x_in = self.decoders[d_i][1](x_in)
                d_i = d_i + 1
            all_output[iteration] = x_in
        if self.integrate:
            x_in = cat(all_output, dim=1)  # fuse the outputs of all iterations
        x_in = self.post_transform_conv_block(x_in)
        return x_in


if __name__ == '__main__':
    Total_params = 0
    Trainable_params = 0
    NonTrainable_params = 0
    model = BiONet()
    x = torch.rand((1, 3, 512, 512))
    out = model(x)
    # iterate over all parameters returned by model.parameters()
    for param in model.parameters():
        mulValue = np.prod(param.size())  # number of elements in this parameter tensor, via numpy prod
        Total_params += mulValue  # total parameter count
        if param.requires_grad:
            Trainable_params += mulValue  # trainable parameters
        else:
            NonTrainable_params += mulValue  # non-trainable parameters

    print(f'Total params: {Total_params}')
    print(f'Trainable params: {Trainable_params}')
    print(f'Non-trainable params: {NonTrainable_params}')
    print(out.shape)
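For reference, with the default arguments this script should print an output shape of torch.Size([1, 1, 512, 512]): the pre-processing block and the four encoder levels reduce the 512*512 input to 16*16 at the bottleneck, the middle block and the four decoder levels bring it back to full resolution, and the final 1x1 convolution plus Sigmoid yields a single-channel probability map (the num_classes argument is accepted but not used by this implementation). Note also that, because the reused kernels are created only in the first iteration, the reported parameter count does not grow with the iterations argument.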
