多分类任务中 dice acc 的计算有两种常用方式,分别见第2、3部分。
两种方式传入 dice 计算函数的数据都是 one-hot 编码形式(或其中单个类别对应的二值掩码)。
1.one hot编码
def convert_to_one_hot_batch(seg, num_classes=4):
    """One-hot encode a batch of integer label maps.

    Args:
        seg: label array of shape (N, *spatial), e.g. (20, 256, 256).
        num_classes: number of output channels; channel c marks label c.
            Defaults to 4 (labels 0..3).

    Returns:
        Array of shape (N, num_classes, *spatial) with ``seg``'s dtype,
        e.g. (20, 4, 256, 256). Channel c is 1 where ``seg == c`` and 0
        elsewhere; labels outside [0, num_classes) give all-zero pixels.
    """
    seg = np.asarray(seg)
    # A fixed class range (rather than np.unique) guarantees every sample
    # gets the same number of channels.
    classes = np.arange(num_classes).reshape(
        (1, num_classes) + (1,) * (seg.ndim - 1)
    )
    # Broadcast (N, 1, *spatial) against (1, C, 1, ...) -> (N, C, *spatial).
    return (seg[:, np.newaxis] == classes).astype(seg.dtype)
def convert_to_one_hot(seg, num_classes=4):
    """One-hot encode a single integer label map.

    Args:
        seg: label array of any shape, e.g. (256, 256).
        num_classes: number of output channels; channel c marks label c.
            Defaults to 4 (labels 0..3). The channel count is fixed by this
            argument, not by the labels present, so every sample encodes to
            the same shape.

    Returns:
        Array of shape (num_classes, *seg.shape) with ``seg``'s dtype.
        Channel c is 1 where ``seg == c`` and 0 elsewhere; labels outside
        [0, num_classes) give all-zero pixels.
    """
    seg = np.asarray(seg)
    classes = np.arange(num_classes).reshape((num_classes,) + (1,) * seg.ndim)
    # Broadcast (1, *shape) against (C, 1, ...) -> (C, *shape).
    return (seg[np.newaxis] == classes).astype(seg.dtype)
2.multi class dice_acc
# Example: mean foreground Dice between softmax probabilities and one-hot labels.
dice = MulticlassDice()
seg_onehot = convert_to_one_hot_batch(seg.cpu().numpy())  # e.g. (20, 4, 256, 256)
# Normalize over the class channel explicitly: nn.Softmax() without `dim`
# is deprecated and picks the dimension implicitly.
# assumes pred is (N, C, H, W) logits, matching seg_onehot's layout — confirm
activation_fn = nn.Softmax(dim=1)
pref = activation_fn(pred)
acc, acc_list = dice(pref.detach().cpu().numpy(), seg_onehot)
class MulticlassDice(nn.Module):
    """Mean Dice accuracy over the foreground classes of a one-hot target.

    A binary Dice metric is applied to each class channel separately.
    Channel 0 (background) is skipped, and the per-class results are
    averaged over the remaining C - 1 channels.

    ``input`` and ``target`` must be indexable as (N, C, ...), where N is
    the batch size and C is the number of classes. ``target`` must be
    one-hot encoded.
    """

    def __init__(self, dice_fn=None):
        """
        Args:
            dice_fn: optional callable ``dice_fn(pred, target)`` that returns
                the binary Dice for one class channel. When omitted, a fresh
                ``Dice_acc()`` is created on every ``forward`` call.
                (``Dice_acc`` is not defined in this snippet.)
        """
        super().__init__()
        self.dice_fn = dice_fn

    def forward(self, input, target, weights=None):
        """
        Args:
            input: per-class predictions, shape (N, C, ...).
            target: one-hot ground truth, shape (N, C, ...).
            weights: optional per-class weights indexed by class id. Only the
                accumulated total is weighted; the returned per-class scores
                are not.

        Returns:
            (mean, per_class):
                mean: sum of the (optionally weighted) per-class Dice values
                    divided by C - 1.
                per_class: list of unweighted Dice values for classes 1..C-1.

        Raises:
            ValueError: if ``target`` has fewer than 2 channels, i.e. there is
                no foreground class to score.
        """
        num_classes = target.shape[1]
        if num_classes < 2:
            raise ValueError(
                "MulticlassDice needs at least 2 channels (background + foreground)"
            )
        dice = self.dice_fn if self.dice_fn is not None else Dice_acc()
        total = 0
        per_class = []
        for i in range(1, num_classes):  # channel 0 (background) is skipped
            class_dice = dice(input[:, i], target[:, i])
            per_class.append(class_dice)
            if weights is not None:
                # Deliberately not `*=`: an in-place multiply on an array
                # would also change the value already stored in per_class.
                class_dice = class_dice * weights[i]
            total = total + class_dice
        return total / (num_classes - 1), per_class
3. 测试部分 dice acc求解
from medpy.metric.binary import hd, dc, assd
# Per-structure Dice between ground truth `gt` and the argmax prediction.
# assumes class scores are on the last axis of `prediction` — confirm
pred = np.uint8(np.argmax(prediction, axis=-1))
dice_scores = {}
for struc in [3, 1, 2]:
    # Binary masks (one label vs. everything else) for the current structure.
    gt_binary = (gt == struc) * 1
    pred_binary = (pred == struc) * 1
    # Store the score per label so it can be reported afterwards.
    dice_scores[struc] = dc(gt_binary, pred_binary)
4. 通过 argmax 以及 keep_largest_connected_components(每个类别只保留最大连通区域)后处理来提高 acc 值
参考:https://github.com/baumgach/acdc_segmenter/blob/master/evaluate_patients.py
# Optional post-processing of the predicted label map: replace it with the
# output of image_utils.keep_largest_connected_components.
# NOTE(review): presumably image_utils' version matches the
# keep_largest_connected_components definition shown in these notes — confirm.
if do_postprocessing:
    prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)
def keep_largest_connected_components(mask, labels=(1, 2, 3)):
    """Keep only the largest connected component of each label in a mask.

    For every value in ``labels``, the pixels equal to that value are split
    into connected components with ``connectivity=1``. That means only
    orthogonal neighbours count: 4-connectivity in 2D, 6-connectivity in 3D.
    Only the component with the largest pixel count is kept. On a tie, the
    component that ``measure.label`` numbered first wins.

    Args:
        mask: integer segmentation mask (N-D array).
        labels: label values to process, default (1, 2, 3). All other values,
            including background 0, are 0 in the output. The values must fit
            in uint8.

    Returns:
        uint8 array with the same shape as ``mask``.
    """
    out_img = np.zeros(mask.shape, dtype=np.uint8)
    for struc_id in labels:
        blobs, num_blobs = measure.label(
            mask == struc_id, connectivity=1, return_num=True
        )
        if num_blobs == 0:
            continue
        # Pixel count per component label. Index 0 is the background, so it
        # is dropped. argmax returns the first maximum, i.e. the lowest
        # component label, so ties are broken the same way regionprops
        # ordering would.
        areas = np.bincount(blobs.ravel())[1:]
        largest_blob_label = int(np.argmax(areas)) + 1
        out_img[blobs == largest_blob_label] = struc_id
    return out_img
"""
对每个类别,求出该类别对应的所有连通区域,只保留其中面积最大的区域作为最终的预测分割结果
"""