nnMamba用于糖尿病视网膜病变检测测试

1.代码修改

源码是针对3D单通道图像的,只需要简单改写为2D就行,修改nnMamba4cls.py代码如下:

# -*- coding: utf-8 -*-
# 作者: Mr Cun
# 文件名: nnMamba4cls.py
# 创建时间: 2024-10-25
# 文件描述:修改nnmamba,使其适应3通道2分类DR分类任务


import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Create a bias-free 2D convolution with a 3x3 kernel.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation; also used as the padding amount.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": dilation,
        "groups": groups,
        "bias": False,
        "dilation": dilation,
    }
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Create a bias-free 2D point-wise (1x1) convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with batch normalisation.

    The block computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``
    where ``shortcut`` is ``x`` itself, or ``downsample(x)`` when a
    ``downsample`` module is supplied.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to build the
            shortcut branch; it has to produce a tensor that can be added
            to the output of the second batch norm.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()

        # Only the first convolution carries the stride; when stride != 1 the
        # ``downsample`` module is what brings the shortcut to the same size.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        shortcut = x if self.downsample is None else self.downsample(x)

        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # In-place accumulation of the shortcut, then the final activation.
        residual += shortcut
        return self.relu(residual)


def make_res_layer(inplanes, planes, blocks, stride=1):
    """Stack ``BasicBlock`` modules into one residual stage.

    The first block maps ``inplanes`` to ``planes`` channels with the given
    stride and always receives a 1x1-conv + batch-norm shortcut projection.
    The remaining ``blocks - 1`` blocks keep ``planes`` channels and use the
    default stride. At least one block is always created.

    Args:
        inplanes: Number of channels entering the stage.
        planes: Number of channels produced by the stage.
        blocks: Total number of blocks in the stage.
        stride: Stride of the first block and of its shortcut projection.

    Returns:
        An ``nn.Sequential`` holding the blocks in order.
    """
    shortcut_proj = nn.Sequential(
        conv1x1(inplanes, planes, stride),
        nn.BatchNorm2d(planes),
    )

    entry_block = BasicBlock(inplanes, planes, stride, shortcut_proj)
    follow_up = [BasicBlock(planes, planes) for _ in range(1, blocks)]

    return nn.Sequential(entry_block, *follow_up)


class MambaLayer(nn.Module):
    """Mamba token mixer wrapped between two 1x1-conv / batch-norm / ReLU stages.

    ``forward`` projects the input with ``nin`` + ``norm`` + ``relu``,
    flattens all dimensions after the channel axis into a token sequence,
    runs ``Mamba`` over that sequence, adds the projected input back, and
    finishes with ``nin2`` + ``norm2`` + ``relu2``.

    Args:
        dim: Channel count of the input; used as the Mamba model dimension.
        d_state: SSM state expansion factor passed to ``Mamba``.
        d_conv: Local convolution width passed to ``Mamba``.
        expand: Block expansion factor passed to ``Mamba``.
    """

    def __init__(self, dim, d_state=8, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim

        # Output projection stage (applied after the Mamba residual).
        self.nin = conv1x1(dim, dim)
        self.nin2 = conv1x1(dim, dim)
        self.norm2 = nn.BatchNorm2d(dim)
        self.relu2 = nn.ReLU(inplace=True)
        # Registered but not used by forward().
        self.relu3 = nn.ReLU(inplace=True)

        # Input projection stage (applied before the Mamba mixer).
        self.norm = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU(inplace=True)

        ssm_config = dict(d_model=dim, d_state=d_state, d_conv=d_conv, expand=expand)
        self.mamba = Mamba(**ssm_config)

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` after Mamba mixing."""
        batch, channels = x.shape[:2]

        projected = self.relu(self.norm(self.nin(x)))
        assert channels == self.dim

        # (B, C, *dims) -> (B, n_tokens, C): one token per position.
        trailing_dims = projected.shape[2:]
        tokens = projected.reshape(batch, channels, trailing_dims.numel()).transpose(-1, -2)

        mixed = self.mamba(tokens).transpose(-1, -2).reshape(batch, channels, *trailing_dims)

        # Residual connection around the Mamba mixer.
        mixed += projected
        return self.relu2(self.norm2(self.nin2(mixed)))


class MambaSeq(nn.Module):
    """ReLU followed by a Mamba mixer over the flattened trailing dimensions.

    The input is treated as ``(batch, channels, *dims)``; everything after
    the channel axis is flattened into a token sequence, processed by
    ``Mamba`` and reshaped back to the input shape. The ReLU is in-place,
    so the tensor handed to ``forward`` is modified.

    Args:
        dim: Channel count of the input; used as the Mamba model dimension.
        d_state: SSM state expansion factor passed to ``Mamba``.
        d_conv: Local convolution width passed to ``Mamba``.
        expand: Block expansion factor passed to ``Mamba``.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.relu = nn.ReLU(inplace=True)
        self.mamba = Mamba(d_model=dim, d_state=d_state, d_conv=d_conv, expand=expand)

    def forward(self, x):
        """Return the Mamba-mixed tensor, shaped like ``x``."""
        batch, channels = x.shape[:2]
        x = self.relu(x)
        assert channels == self.dim

        # (B, C, *dims) -> (B, n_tokens, C) for the sequence model.
        trailing_dims = x.shape[2:]
        sequence = x.reshape(batch, channels, trailing_dims.numel()).transpose(-1, -2)

        mixed = self.mamba(sequence)

        # Back to channels-first with the original trailing dimensions.
        return mixed.transpose(-1, -2).reshape(batch, channels, *trailing_dims)

class DoubleConv(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    The first convolution uses the configurable ``kernel_size`` and
    ``stride`` with padding of half the kernel size (rounded down); the
    second is a fixed 3x3 convolution with padding 1.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.
        stride: Stride of the first convolution.
        kernel_size: Kernel size of the first convolution.
    """

    def __init__(self, in_ch, out_ch, stride=1, kernel_size=3):
        super().__init__()
        stages = [
            nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=kernel_size // 2,
            ),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, dilation=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through both convolution stages."""
        return self.conv(input)


class SingleConv(nn.Module):
    """A single 3x3 Conv2d (padding 1) -> BatchNorm2d -> ReLU stage.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through the convolution stage."""
        return self.conv(input)


class nnMambaEncoder(nn.Module):
    """2D nnMamba classifier: conv stem, Mamba stem layer, three residual stages.

    The outputs of the three residual stages are globally average-pooled,
    each reshaped to ``2 * channels`` feature rows (1, 2 and 4 columns
    respectively), concatenated into a 7-token sequence, mixed by
    ``MambaSeq`` with a residual connection, flattened to
    ``14 * channels`` features and classified by a two-layer MLP.

    Args:
        in_ch: Number of channels of the input image.
        channels: Base channel width produced by the stem.
        blocks: Number of ``BasicBlock`` modules per residual stage.
        number_classes: Size of the classifier output.
    """

    def __init__(self, in_ch=3, channels=32, blocks=3, number_classes=2):
        super().__init__()

        # Stem: strided double convolution followed by a Mamba layer.
        self.in_conv = DoubleConv(in_ch, channels, stride=2, kernel_size=3)
        self.mamba_layer_stem = MambaLayer(dim=channels, d_state=8, d_conv=4, expand=2)

        # Residual stages; each one doubles the width and uses stride 2.
        self.layer1 = make_res_layer(channels, channels * 2, blocks, stride=2)
        self.layer2 = make_res_layer(channels * 2, channels * 4, blocks, stride=2)
        self.layer3 = make_res_layer(channels * 4, channels * 8, blocks, stride=2)

        self.pooling = nn.AdaptiveAvgPool2d((1, 1))

        # Sequence mixer over the pooled multi-scale tokens.
        self.mamba_seq = MambaSeq(dim=channels * 2, d_state=8, d_conv=2, expand=2)

        # 7 tokens of width 2 * channels -> 14 * channels input features.
        head = [
            nn.Linear(channels * 14, channels),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(channels, number_classes),
        ]
        self.mlp = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits computed by the MLP head for ``x``."""
        c1 = self.in_conv(x)
        batch, stem_width = c1.shape[0], c1.shape[1]

        # Stem Mamba layer with a residual connection around it.
        feat = self.mamba_layer_stem(c1) + c1

        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3):
            feat = stage(feat)
            stage_outputs.append(feat)

        # Pool every stage to 1x1 and fold its channels into rows of width
        # 2 * stem_width, giving 1, 2 and 4 tokens for the three stages.
        token_groups = [
            self.pooling(stage_out).reshape(batch, stem_width * 2, n_tokens)
            for stage_out, n_tokens in zip(stage_outputs, (1, 2, 4))
        ]
        h_feature = torch.cat(token_groups, dim=2)

        fused = self.mamba_seq(h_feature) + h_feature
        return self.mlp(fused.reshape(batch, -1))
 

if __name__ == "__main__":
    # Smoke test: push one all-zero batch of eight 3x224x224 images through
    # the model on the GPU and print the shape of the result.
    model = nnMambaEncoder()
    model = model.cuda()

    input = torch.zeros(8, 3, 224, 224).cuda()
    output = model(input)
    print(output.shape)

2.增加训练代码和数据集代码

  • dr_dataset.py
# -*- coding: utf-8 -*-
# 作者: Mr.Cun
# 文件名: dr_dataset.py
# 创建时间: 2024-10-25
# 文件描述:视网膜数据处理

import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# Dataset root directory (machine-specific path; adjust for your environment).
root_path = '/home/aic/deep_learning_data/retino_data'
batch_size = 64  # Mini-batch size; lower it if your machine cannot handle this much
# Maps the integer class index to a human-readable label.
class_labels = {
   0: 'Diabetic Retinopathy', 1: 'No Diabetic Retinopathy'}
# Seed the PyTorch and NumPy random number generators.
torch.manual_seed(42)
np.random.seed(42)


class RetinaDataset:
    def __init__(self, root_path, batch_size,class_labels):
        """Load the train/valid/test image folders and build their loaders.

        Args:
            root_path: Directory containing the ``train``, ``valid`` and
                ``test`` sub-folders read by ``_load_dataset``.
            batch_size: Batch size used by all three data loaders.
            class_labels: Mapping from integer label to display name.
        """
        self.root_path = root_path
        self.batch_size = batch_size
        self.class_labels = class_labels
        # ``transform`` keeps its original meaning (the augmenting training
        # pipeline); evaluation splits get a deterministic pipeline so that
        # validation/test metrics are reproducible.
        self.transform = self._set_transforms()
        self.eval_transform = self._set_transforms(train=False)
        self.train_dataset = self._load_dataset('train')
        self.val_dataset = self._load_dataset('valid')
        self.test_dataset = self._load_dataset('test')
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.valid_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

    def _set_transforms(self, train=True):
        """Build the preprocessing pipeline.

        Args:
            train: When true (the default), include the random flips and
                rotation used for augmentation. When false, only the
                deterministic resize / tensor conversion / normalisation
                steps are applied.

        Returns:
            A ``transforms.Compose`` with the selected steps.
        """
        steps = [transforms.Resize((224, 224))]
        if train:
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(30),
            ]
        steps += [
            transforms.ToTensor(),
            # These values match the widely used ImageNet channel statistics.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def _load_dataset(self, split):
        """Return the ``ImageFolder`` dataset stored under ``root_path/split``.

        Only the ``'train'`` split is randomly augmented; every other split
        uses the deterministic evaluation pipeline.
        """
        root = os.path.join(self.root_path, split)
        transform = self.transform if split == 'train' else self.eval_transform
        return datasets.ImageFolder(root=root, transform=transform)

    def visualize_samples(self, loader):
        figure = plt.figure(figsize=(12, 12))
        cols, rows = 4, 4
        for i in range(1, cols * rows + 1):
            sample_idx = np.random.randint(len(loader.dataset))
            img, label = loader.dataset[sample_idx]
            figure.add_subplot(rows, cols, i)
            plt.title(self.class_labels[label])
            plt.axis("off")
            img_np = img.numpy().transpose((1, 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值