MedMamba代码解释及用于糖尿病视网膜病变分类

MedMamba原理和用于糖尿病视网膜病变检测尝试

1.MedMamba原理

image-20241010110028101

MedMamba发表于2024.9.28,是构建在Vision Mamba基础之上,融合了卷积神经网的架构,结构如下图:

image-20241010110201286

原理简述就是图片输入后按通道输入后切分为两部分,一部分走二维分组卷积提取局部特征,一部分利用Vision Mamba中的SS2D模块提取所谓的全局特征,两个分支的输出通过通道维度的拼接后,经过channel shuffle增加信息融合。

2.代码解释

模型代码就在源码的MedMamba.py文件下,对涉及到的代码我进行了详细注释:

  • mamba部分

    基本上是使用Vision Mamaba的SS2D:

class SS2D(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        # d_state="auto", # 20240109
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        # 设置设备和数据类型的关键参数
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model # 模型维度
        self.d_state = d_state # 状态维度
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv # 卷积核的大小
        self.expand = expand  # 扩展因子
        self.d_inner = int(self.expand * self.d_model)  # 内部维度,等于模型维度乘以扩展因子
        # 时间步长的秩,默认为模型维度除以16
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        # 输入投影层,将模型维度投影到内部维度的两倍,用于后续操作
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        # 深度卷积层,输入和输出通道数相同,组数等于内部维度,用于空间特征提取
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2, # 保证输出的空间维度与输入相同
            **factory_kwargs,
        )
        self.act = nn.SiLU() # 激活函数使用 SiLU
        # 定义多个线性投影层,将内部维度投影到不同大小的向量,用于时间步长和状态
        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
        )
        # 将四个线性投影层的权重合并为一个参数,方便计算
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
        # 删除单独的投影层以节省内存
        del self.x_proj
        # 初始化时间步长的线性投影,定义四组时间步长投影参数
        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        # 将时间步长的权重和偏置参数合并为可训练参数
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
        del self.dt_projs
        # 初始化 S4D 的 A 参数,用于状态更新计算
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
        # 初始化 D 参数,用于跳跃连接的计算
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N)
        # 选择核心的前向计算函数版本,默认为 forward_corev0
        # self.selective_scan = selective_scan_fn
        self.forward_core = self.forward_corev0
        # 输出层的层归一化,归一化到内部维度
        self.out_norm = nn.LayerNorm(self.d_inner)
        # 输出投影层,将内部维度投影回原始模型维度
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        # 设置 dropout 层,如果 dropout 参数大于 0,则应用随机失活以防止过拟合
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        # 初始化用于时间步长计算的线性投影层
        # Initialize special dt projection to preserve variance at initialization
        # 特殊初始化方法,用于保持初始化时的方差不变
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant": # 初始化为常数
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random": # 初始化为均匀随机数
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        # 初始化偏置,以便在使用 F.softplus 时,结果处于 dt_min 和 dt_max 之间
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        # softplus 的逆操作,确保偏置初始化在合适范围内
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)  # 设置偏置参数
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True # 将该偏置标记为不重新初始化
        
        return dt_proj
  • SS_Conv_SSM

    这部分就是论文提出的创新点,图片中的结构

    class SS_Conv_SSM(nn.Module):
        def __init__(
            self,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            attn_drop_rate: float = 0,
            d_state: int = 16,
            **kwargs,
        ):
            super().__init__()
            # 初始化第一个归一化层,归一化的维度是隐藏维度的一半
            self.ln_1 = norm_layer(hidden_dim//2)
            # 初始化自注意力模块 SS2D,输入维度为隐藏维度的一半
            self.self_attention = SS2D(d_model=hidden_dim//2,
                                       dropout=attn_drop_rate,
                                       d_state=d_state,
                                       **kwargs)
            # DropPath 层,用于随机丢弃路径,提高模型的泛化能力
            self.drop_path = DropPath(drop_path)
            # 定义卷积模块,由多个卷积层和批量归一化层组成,用于特征提取
            self.conv33conv33conv11 = nn.Sequential(
                nn.BatchNorm2d(hidden_dim // 2),
                nn.Conv2d(in_channels=hidden_dim//2,out_channels=hidden_dim//2,kernel_size=3,stride=1,padding=1),
                nn.BatchNorm2d(hidden_dim//2),
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(hidden_dim // 2),
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=1, stride=1),
                nn.ReLU()
            )
            # 注释掉的最终卷积层,可能用于进一步调整输出维度
            # self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1)
        def forward(self, input: torch.Tensor):
            # 将输入张量沿最后一个维度分割为左右两部分
            input_left, input_right = input.chunk(2,dim=-1)
            # 对右侧输入进行归一化和自注意力操作,之后应用 DropPath 随机丢弃
            x = self.drop_path(self.self_attention(self.ln_1(input_right)))
            # 将左侧输入从 (batch_size, height, width, channels)
            # 转换为 (batch_size, channels, height, width) 以适应卷积操作
            input_left = input_left.permute(0,3,1,2).contiguous()
            input_left = self.conv33conv33conv11(input_left)
            # 将卷积后的左侧输入转换回原来的形状 (batch_size, height, width, channels)
            input_left = input_left.permute(0,2,3,1).contiguous()
            # 将左侧和右侧的输出在最后一个维度上拼接起来
            output = torch.cat((input_left,x),dim=-1)
            # 对拼接后的输出进行通道混洗,增加特征的融合
            output = channel_shuffle(output,groups=2)
            # 返回最终的输出,增加残差连接,将输入与输出相加
            return output+input
    
  • VSSLayer

    有以上结构堆叠构成网络结构

    class VSSLayer(nn.Module):
        """ A basic Swin Transformer layer for one stage.
        Args:
            dim (int): Number of input channels.
            depth (int): Number of blocks.
            drop (float, optional): Dropout rate. Default: 0.0
            attn_drop (float, optional): Attention dropout rate. Default: 0.0
            drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
            norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
            downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
            use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        """
    
        def __init__(
            self, 
            dim, 
            depth, 
            attn_drop=0.,
            drop_path=0., 
            norm_layer=nn.LayerNorm, 
            downsample=None, 
            use_checkpoint=False, 
            d_state=16,
            **kwargs,
        ):
            super().__init__()
            # 设置输入通道数
            self.dim = dim
            # 是否使用检查点
            self.use_checkpoint = use_checkpoint
            # 创建 SS_Conv_SSM 块列表,数量为 depth
            self.blocks = nn.ModuleList([
                SS_Conv_SSM(
                    hidden_dim=dim, # 隐藏层维度等于输入维度
                    # 处理随机深度的丢弃率
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer, # 使用的归一化层
                    attn_drop_rate=attn_drop, # 注意力丢弃率
                    d_state=d_state, # 状态维度
                )
                for i in range(depth)]) # 重复 depth 次构建块
            # 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值