多分类--dice acc计算

对于多分类的dice acc的计算有两种常用方式,见2/3。
传入dice()函数计算的都为one-not编码形式

1.one hot编码

def convert_to_one_hot_batch(seg):
    res = np.zeros([seg.shape[0],4,seg.shape[1],seg.shape[2]], seg.dtype)  #20,4,256,256
    for i in range(len(seg)):
        res[i]=convert_to_one_hot(seg[i])
    return res

def convert_to_one_hot(seg):
    vals = np.unique(seg)
    vals=np.array([0,1,2,3])  #这一步由label的个数决定,直接指定的话能保证每一个的尺寸是一样的
    res = np.zeros([len(vals)] + list(seg.shape), seg.dtype)
    for c in range(len(vals)):
        res[c][seg == c] = 1
    return res

2.multi class dice_acc

    dice = MulticlassDice()
    seg_onehot = convert_to_one_hot_batch(seg.cpu().numpy())  # 20,4,256,256
    activation_fn = nn.Softmax()
    pref = activation_fn(pred)
    acc, acc_list = dice(pref.cpu().detach().numpy(), seg_onehot)
class MulticlassDice(nn.Module):
    """
    requires one hot encoded target. Applies DiceLoss on each class iteratively.
    requires input.shape[0:1] and target.shape[0:1] to be (N, C) where N is
      batch size and C is number of classes
    """

    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        C = target.shape[1]
        # if weights is None:
        # 	weights = torch.ones(C) #uniform weights for all classes

        dice = Dice_acc()
        totalLoss = 0
        acc=[]
        for i in range(1,C):
            diceLoss = dice(input[:, i], target[:, i])
            acc.append(diceLoss)
            # print("dice",i,":",diceLoss.sum()/len(diceLoss))
            if weights is not None:
                diceLoss *= weights[i]
            totalLoss += diceLoss
        return totalLoss/(C-1),acc

3. 测试部分 dice acc求解

from medpy.metric.binary import hd, dc, assd

pred = np.uint8(np.argmax(prediction, axis=-1))

for struc in [3,1,2]:
   gt_binary = (gt == struc) * 1   #one hot编码
   pred_binary = (pred == struc) * 1  #one hot编码
   dc(gt_binary, pred_binary)

4.keep_largest_connected_components以及argmax的操作提高acc值

参考:https://github.com/baumgach/acdc_segmenter/blob/master/evaluate_patients.py

 if do_postprocessing:
       prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

def keep_largest_connected_components(mask):
    '''
    Keeps only the largest connected components of each label for a segmentation mask.
    '''

    out_img = np.zeros(mask.shape, dtype=np.uint8)

    for struc_id in [1, 2, 3]:

        binary_img = mask == struc_id
        blobs = measure.label(binary_img, connectivity=1)#  连通区域标记--4连通

        props = measure.regionprops(blobs)##分别在每个连通区域上进行操作

        if not props:
            continue

        area = [ele.area for ele in props]
        largest_blob_ind = np.argmax(area)
        largest_blob_label = props[largest_blob_ind].label

        out_img[blobs == largest_blob_label] = struc_id

    return out_img

"""
对每个类别,求出对应label上的所有连通区域,最后取出面积最大的那个区域作为最后的预测分割结果
"""
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值