2d和3d cnn 解决医疗影像分析问题

项目介绍

本次项目我们对肿瘤病人的医学图像进行分类,从而预测其疗程后的康复情况。数据来源为真实医院中病人肠壁肿瘤的MRI图像,我们选取T2图像作为实验图像。每一组图像对应一个病人接受疗程前的肠壁肿瘤情况,而在医院进行新辅助疗程后,会再次检查病人的肿瘤情况,从而将其肿瘤康复情况分为0,1,2,3四个等级(0为最优,3为最差)。我们的任务就是对疗程前的图像进行四分类任务,看看能否在疗程前就预测他们疗程后康复的情况,从而对病人进行针对性治疗。

预处理方法

由于图像是dcm形式,而每个病人的slice张数不一(少则十几张,多则三十张),因此无法做到张数的统一。再者分辨率的标准化也很难达到,因为每个slice的层厚也不一。slice不一,层厚不一,层间距也不一,使得我们很难达成数据在Z轴(空间)上的统一。因此最终还是选择了2d的做法(3d的尝试后面再提)。2d做法及取每个病人单张slice进行训练,最后预测时则是分别取一个病人的所有slice进行分类预测,而病人的结果则是所有结果的平均值。这样做法的好处是不用做到数据在3维上的统一,并且网络的训练量也小了很多。可能的缺点是没有用到图像在空间上的联系。

spacing

在取slice时,我们去除了病人图像中没有肿瘤信息的slice,相当于做了个过滤。而对于单张slice,由于它们的xy像素值不一,为了适应模型resnet,我们选择resize至244x244。

        # 1. spacing resize
        if self.is_spacing is True:
            shape = self.dataset[index]["shape"]
            shape_spc = self.dataset[index]["shape_spc"]
            if shape[0] != shape_spc[0]:
                img = img.resize(size=(shape_spc[0], shape_spc[1]), resample=Image.NEAREST) # spacing resize

        # 2. 以肿瘤中心切割
        h_min, h_max, w_min, w_max = self.dataset[index]["tumor_hw_min_max_spc"]
        tumor_origin = ( (h_min + h_max) / 2, (w_min + w_max) / 2 )         # 肿瘤中心点坐标

        if self.is_train is True:
            crop_size = 224     # 切割后图片大小
        else:
            crop_size = 224     # 切割后图片大小

        img = TF.crop(
            img=img,        # Image to be cropped.
            i=int(round(tumor_origin[0] - crop_size / 2)),   # Upper pixel coordinate.
            j=int(round(tumor_origin[1] - crop_size / 2)),   # Left pixel coordinate.
            h=crop_size,    # Height of the cropped image.
            w=crop_size     # Width of the cropped image.
        )

        return img

交叉验证

采用5折交叉验证。对于不同的病人类型,按照比例随机划分。

augmentation

采用flip, resize, crop,直接使用pytorch的相关函数即可。

train_transform = transforms.Compose([
    # transforms.CenterCrop(size=224),
    # transforms.RandomRotation(degrees=[-10, 10]),
    # transforms.CenterCrop(size=512)
    # transforms.RandomCrop(size=224),
    # transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    # transforms.ColorJitter(brightness=0.1, contrast=0.1),
])

模型

采取resnet残差网络,具体来说是resnet34。pytorch官网就有resnet的源代码,稍作修改即可。

class ResNet(nn.Module):
 
    def __init__(self, block, layers, num_classes=
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值