【mmsegmentation】How to add more evaluation metrics in mmsegmentation? (with source code)

Background

The code below defines an evaluation-metric class named IoUMetric, which inherits from mmengine.evaluator.BaseMetric and computes several evaluation metrics for semantic segmentation. In semantic segmentation we need to measure how well the predicted segmentation maps match the ground-truth labels; IoUMetric provides a set of commonly used metrics, such as intersection over union (IoU), the Dice coefficient, F-score, precision, recall and pixel accuracy (PA), giving a well-rounded view of model performance. mmsegmentation is OpenMMLab's open-source semantic segmentation toolbox, and this code is part of its evaluation module, so it can be reused across different segmentation models.
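
Before diving into the code, a quick pointer on usage: assuming the standard mmsegmentation 1.x config layout, the extra metrics are selected simply by listing them in iou_metrics of the evaluator config. A minimal sketch (the metric names must match the allowed_metrics list discussed below):

# Config fragment (sketch): choose which metrics to compute at evaluation time.
val_evaluator = dict(
    type='IoUMetric',
    iou_metrics=['mIoU', 'mFscore', 'Precision', 'Recall', 'PA'])
test_evaluator = val_evaluator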

Code walkthrough

1. Import the necessary libraries
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence
from mmengine.evaluator import BaseMetric
import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTable

from mmseg.registry import METRICS

These imports cover file-path handling, data structures, type hints, the evaluation base class, numerical computation, PyTorch, distributed-training utilities, logging, directory creation, image I/O and table output, all of which are needed later in the implementation.

2. Define the IoUMetric class
@METRICS.register_module()
class IoUMetric(BaseMetric):

The @METRICS.register_module() decorator registers IoUMetric in mmseg's metric registry so that it can be referenced by name from a config file. The class inherits from BaseMetric and therefore gets the basic evaluation plumbing (result collection, distributed gathering, etc.) for free.
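
As a brief illustration of what registration buys you, the metric can then be built from a plain config dict through the registry instead of calling the class directly:

from mmseg.registry import METRICS

# Build the metric from a config dict; 'type' is looked up in the registry.
metric_cfg = dict(type='IoUMetric', iou_metrics=['mIoU', 'mFscore', 'PA'])
iou_metric = METRICS.build(metric_cfg)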

3. The __init__ method
def __init__(self,
             ignore_index: int = 255,
             iou_metrics: List[str] = ['mIoU'],
             nan_to_num: Optional[int] = None,
             beta: int = 1,
             collect_device: str = 'cpu',
             output_dir: Optional[str] = None,
             format_only: bool = False,
             prefix: Optional[str] = None,
             **kwargs) -> None:
    super().__init__(collect_device=collect_device, prefix=prefix)

    self.ignore_index = ignore_index
    self.metrics = iou_metrics
    self.nan_to_num = nan_to_num
    self.beta = beta
    self.output_dir = output_dir
    if self.output_dir and is_main_process():
        mkdir_or_exist(self.output_dir)
    self.format_only = format_only

The constructor takes the index to ignore, the list of metrics to compute, an optional replacement value for NaN entries, the beta weight used in the F-score, the device on which results are collected, an output directory, a flag to only format results without evaluating, and a metric-name prefix. It calls the parent constructor, stores these options as attributes and, if an output directory is given and the current process is the main process, creates that directory.
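
For reference, a direct instantiation with the less common arguments could look like the following sketch (the values, including the output path, are arbitrary examples rather than recommendations):

# nan_to_num=0 replaces the NaN produced for classes with zero union;
# output_dir / format_only are mainly used when predicting on a test set
# that has no ground truth.
metric = IoUMetric(
    ignore_index=255,
    iou_metrics=['mIoU', 'mDice', 'PA'],
    nan_to_num=0,
    beta=1,
    output_dir='work_dirs/format_results',  # hypothetical path
    format_only=False)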

4. The process method
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
    num_classes = len(self.dataset_meta['classes'])
    for data_sample in data_samples:
        pred_label = data_sample['pred_sem_seg']['data'].squeeze()
        if not self.format_only:
            label = data_sample['gt_sem_seg']['data'].squeeze().to(
                pred_label)
            self.results.append(
                self.intersect_and_union(pred_label, label, num_classes,
                                         self.ignore_index))
        if self.output_dir is not None:
            basename = osp.splitext(osp.basename(
                data_sample['img_path']))[0]
            png_filename = osp.abspath(
                osp.join(self.output_dir, f'{basename}.png'))
            output_mask = pred_label.cpu().numpy()
            if data_sample.get('reduce_zero_label', False):
                output_mask = output_mask + 1
            output = Image.fromarray(output_mask.astype(np.uint8))
            output.save(png_filename)

This method handles one batch of data and data samples. For each sample it extracts the predicted label map and, unless format_only is set, the ground-truth label map, calls intersect_and_union to compute the per-class intersection and union areas, and appends the result to self.results. If an output directory is specified, the prediction is also saved as a PNG image.

5. The compute_metrics method
def compute_metrics(self, results: list) -> Dict[str, float]:
    logger: MMLogger = MMLogger.get_current_instance()
    if self.format_only:
        logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
        return OrderedDict()
    results = tuple(zip(*results))
    assert len(results) == 4

    total_area_intersect = sum(results[0])
    total_area_union = sum(results[1])
    total_area_pred_label = sum(results[2])
    total_area_label = sum(results[3])
    ret_metrics = self.total_area_to_metrics(
        total_area_intersect, total_area_union, total_area_pred_label,
        total_area_label, self.metrics, self.nan_to_num, self.beta)

    class_names = self.dataset_meta['classes']

    ret_metrics_summary = OrderedDict({
        ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
        for ret_metric, ret_metric_value in ret_metrics.items()
    })
    metrics = dict()
    for key, val in ret_metrics_summary.items():
        if key == 'aAcc':
            metrics[key] = val
        elif key == 'PA':
            metrics[key] = val
        else:
            metrics['m' + key] = val

    ret_metrics.pop('aAcc', None)
    ret_metrics_class = OrderedDict({
        ret_metric: np.round(ret_metric_value * 100, 2)
        for ret_metric, ret_metric_value in ret_metrics.items()
    })
    ret_metrics_class.update({'Class': class_names})
    ret_metrics_class.move_to_end('Class', last=False)
    class_table_data = PrettyTable()
    for key, val in ret_metrics_class.items():
        class_table_data.add_column(key, val)

    print_log('per class results:', logger)
    print_log('\n' + class_table_data.get_string(), logger=logger)

    return metrics

This method computes the final metrics from the processed results. If format_only is set, it only logs where the results were saved and returns an empty dict. Otherwise it transposes the per-sample results, sums the per-class areas over the whole dataset, and calls total_area_to_metrics to compute each requested metric. It then builds a summary (the mean over classes, with an 'm' prefix for every key except aAcc and PA), prints a per-class table via PrettyTable, and returns the summary dict.
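
To make the whole flow concrete, here is a minimal smoke test that drives process and compute_metrics by hand, outside of a Runner. It assumes the extended class is importable from mmseg.evaluation (adjust the import to wherever you placed the file) and feeds plain dicts shaped like the data_samples the evaluator normally passes in:

import torch

from mmseg.evaluation import IoUMetric  # adjust to your module path

metric = IoUMetric(iou_metrics=['mIoU', 'mFscore', 'PA'])
metric.dataset_meta = {'classes': ('background', 'foreground')}

pred = torch.randint(0, 2, (1, 32, 32))  # fake prediction, shape (1, H, W)
gt = torch.randint(0, 2, (1, 32, 32))    # fake ground truth
data_sample = {
    'pred_sem_seg': {'data': pred},
    'gt_sem_seg': {'data': gt},
}

metric.process(data_batch={}, data_samples=[data_sample])
# self.results now holds one 4-tuple of per-class area histograms per sample.
print(metric.compute_metrics(metric.results))
# Prints a dict with keys such as aAcc, mIoU, mAcc, mFscore, mPrecision,
# mRecall and PA, plus a per-class PrettyTable in the log output.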

6. The intersect_and_union method
@staticmethod
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
                        num_classes: int, ignore_index: int):
    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    area_intersect = torch.histc(
        intersect.float(), bins=(num_classes), min=0,
        max=num_classes - 1).cpu()
    area_pred_label = torch.histc(
        pred_label.float(), bins=(num_classes), min=0,
        max=num_classes - 1).cpu()
    area_label = torch.histc(
        label.float(), bins=(num_classes), min=0,
        max=num_classes - 1).cpu()
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label

This static method computes the intersection and union between a predicted label map and a ground-truth label map. It first masks out pixels whose label equals ignore_index, then takes the pixels where prediction and ground truth agree as the intersection, uses torch.histc to build per-class histograms of the intersection, the prediction and the ground truth, and finally derives the union as pred + label - intersection.
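
A tiny worked example (values made up for illustration) makes the four return values easy to check by hand:

import torch

# IoUMetric is the class defined above; import it from your module if needed.
pred = torch.tensor([[0, 0, 1],
                     [1, 2, 2]])
label = torch.tensor([[0, 1, 1],
                      [1, 2, 255]])  # the 255 pixel is ignored

inter, union, pred_area, label_area = IoUMetric.intersect_and_union(
    pred, label, num_classes=3, ignore_index=255)
# After masking out the ignored pixel, 5 pixels remain:
# inter      = [1., 2., 1.]   (correctly predicted pixels per class)
# pred_area  = [2., 2., 1.]   label_area = [1., 3., 1.]
# union      = pred_area + label_area - inter = [2., 3., 1.]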

7. The total_area_to_metrics method
@staticmethod
def total_area_to_metrics(total_area_intersect: np.ndarray,
                          total_area_union: np.ndarray,
                          total_area_pred_label: np.ndarray,
                          total_area_label: np.ndarray,
                          metrics: List[str] = ['mIoU'],
                          nan_to_num: Optional[int] = None,
                          beta: int = 1):
    def f_score(precision, recall, beta=1):
        score = (1 + beta**2) * (precision * recall) / (
            (beta**2 * precision) + recall)
        return score

    if isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore', 'Precision', 'Recall', 'PA']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError(f'metrics {metrics} is not supported')

    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            f_value = torch.tensor([
                f_score(x[0], x[1], beta) for x in zip(precision, recall)
            ])
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall
        elif metric == 'Precision':
            precision = total_area_intersect / total_area_pred_label
            ret_metrics['Precision'] = precision
        elif metric == 'Recall':
            recall = total_area_intersect / total_area_label
            ret_metrics['Recall'] = recall
        elif metric == 'PA':
            pa = total_area_intersect / total_area_label
            ret_metrics['PA'] = pa

    ret_metrics = {
        metric: value.numpy()
        for metric, value in ret_metrics.items()
    }
    if nan_to_num is not None:
        ret_metrics = OrderedDict({
            metric: np.nan_to_num(metric_value, nan=nan_to_num)
            for metric, metric_value in ret_metrics.items()
        })
    return ret_metrics

This static method turns the accumulated intersection, union, prediction and ground-truth areas into the requested metrics. It first defines the helper f_score (the general F-beta score; with beta=1 it reduces to F1), then validates the requested metric names against allowed_metrics. It computes the overall accuracy aAcc, fills ret_metrics according to the requested metrics, converts everything to NumPy arrays (note that despite the np.ndarray annotations the inputs are actually the torch tensors returned by intersect_and_union, since .numpy() is called on them), and finally replaces NaN values if nan_to_num is set.
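
A short numerical sanity check of the formulas, using made-up per-class areas that are consistent with each other:

import torch

# IoUMetric is the class defined above; import it from your module if needed.
inter = torch.tensor([30., 10.])       # per-class true positives
union = torch.tensor([50., 40.])       # = pred_area + label_area - inter
pred_area = torch.tensor([40., 20.])   # TP + FP per class
label_area = torch.tensor([40., 30.])  # TP + FN per class

out = IoUMetric.total_area_to_metrics(
    inter, union, pred_area, label_area,
    metrics=['mIoU', 'mFscore', 'PA'], beta=1)
# IoU                = inter / union      -> [0.60, 0.25]
# Acc, Recall, PA    = inter / label_area -> [0.75, 0.333]
# Precision          = inter / pred_area  -> [0.75, 0.50]
# Fscore (beta=1, i.e. F1)                -> [0.75, 0.40]
# aAcc = inter.sum() / label_area.sum() = 40 / 70 ≈ 0.571
print(out)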

In summary, this code implements a complete metric class for semantic segmentation: it processes predictions and ground-truth labels, computes several commonly used evaluation metrics, and prints the per-class results as a table.

Source code

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence
from mmengine.evaluator import BaseMetric
import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTable

from mmseg.registry import METRICS


@METRICS.register_module()
class IoUMetric(BaseMetric):
    """IoU evaluation metric.

    Args:
        ignore_index (int): Index that will be ignored in evaluation.
            Default: 255.
        iou_metrics (list[str] | str): Metrics to be calculated. The options
            include 'mIoU', 'mDice', 'mFscore', 'Precision', 'Recall' and
            'PA'. Default: ['mIoU'].
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        output_dir (str): The directory for output prediction. Defaults to
            None.
        format_only (bool): Only format the results for submission without
            performing evaluation. It is useful when you want to save the
            results in a specific format and submit them to a test server.
            Defaults to False.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 ignore_index: int = 255,
                 iou_metrics: List[str] = ['mIoU'],
                 nan_to_num: Optional[int] = None,
                 beta: int = 1,
                 collect_device: str = 'cpu',
                 output_dir: Optional[str] = None,
                 format_only: bool = False,
                 prefix: Optional[str] = None,
                 **kwargs) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

        self.ignore_index = ignore_index
        self.metrics = iou_metrics
        self.nan_to_num = nan_to_num
        self.beta = beta
        self.output_dir = output_dir
        if self.output_dir and is_main_process():
            mkdir_or_exist(self.output_dir)
        self.format_only = format_only

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data and data_samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        num_classes = len(self.dataset_meta['classes'])
        for data_sample in data_samples:
            pred_label = data_sample['pred_sem_seg']['data'].squeeze()
            # format_only always for test dataset without ground truth
            if not self.format_only:
                label = data_sample['gt_sem_seg']['data'].squeeze().to(
                    pred_label)
                self.results.append(
                    self.intersect_and_union(pred_label, label, num_classes,
                                             self.ignore_index))
            # format_result
            if self.output_dir is not None:
                basename = osp.splitext(osp.basename(
                    data_sample['img_path']))[0]
                png_filename = osp.abspath(
                    osp.join(self.output_dir, f'{basename}.png'))
                output_mask = pred_label.cpu().numpy()
                # The index range of official ADE20k dataset is from 0 to 150.
                # But the index range of output is from 0 to 149.
                # That is because we set reduce_zero_label=True.
                if data_sample.get('reduce_zero_label', False):
                    output_mask = output_mask + 1
                output = Image.fromarray(output_mask.astype(np.uint8))
                output.save(png_filename)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the metric
                names and the values are the corresponding results. The keys
                mainly include aAcc, mIoU, mAcc, mDice, mFscore, mPrecision,
                mRecall and PA.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        if self.format_only:
            logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
            return OrderedDict()
        # convert list of tuples to tuple of lists, e.g.
        # [(A_1, B_1, C_1, D_1), ...,  (A_n, B_n, C_n, D_n)] to
        # ([A_1, ..., A_n], ..., [D_1, ..., D_n])
        results = tuple(zip(*results))
        assert len(results) == 4

        total_area_intersect = sum(results[0])
        total_area_union = sum(results[1])
        total_area_pred_label = sum(results[2])
        total_area_label = sum(results[3])
        ret_metrics = self.total_area_to_metrics(
            total_area_intersect, total_area_union, total_area_pred_label,
            total_area_label, self.metrics, self.nan_to_num, self.beta)

        class_names = self.dataset_meta['classes']

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        metrics = dict()
        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                metrics[key] = val
            elif key == 'PA':
                metrics[key] = val
            else:
                metrics['m' + key] = val

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)

        return metrics

    @staticmethod
    def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
                            num_classes: int, ignore_index: int):
        """Calculate Intersection and Union.

        Args:
            pred_label (torch.tensor): Prediction segmentation map
                or predict result filename. The shape is (H, W).
            label (torch.tensor): Ground truth segmentation map
                or label filename. The shape is (H, W).
            num_classes (int): Number of categories.
            ignore_index (int): Index that will be ignored in evaluation.

        Returns:
            torch.Tensor: The intersection of prediction and ground truth
                histogram on all classes.
            torch.Tensor: The union of prediction and ground truth histogram on
                all classes.
            torch.Tensor: The prediction histogram on all classes.
            torch.Tensor: The ground truth histogram on all classes.
        """

        mask = (label != ignore_index)
        pred_label = pred_label[mask]
        label = label[mask]

        intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_pred_label = torch.histc(
            pred_label.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_label = torch.histc(
            label.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label

    @staticmethod
    def total_area_to_metrics(total_area_intersect: np.ndarray,
                              total_area_union: np.ndarray,
                              total_area_pred_label: np.ndarray,
                              total_area_label: np.ndarray,
                              metrics: List[str] = ['mIoU'],
                              nan_to_num: Optional[int] = None,
                              beta: int = 1):
        """Calculate evaluation metrics
        Args:
            total_area_intersect (np.ndarray): The intersection of prediction
                and ground truth histogram on all classes.
            total_area_union (np.ndarray): The union of prediction and ground
                truth histogram on all classes.
            total_area_pred_label (np.ndarray): The prediction histogram on
                all classes.
            total_area_label (np.ndarray): The ground truth histogram on
                all classes.
            metrics (List[str] | str): Metrics to be evaluated: 'mIoU',
                'mDice', 'mFscore', 'Precision', 'Recall' and 'PA'.
            nan_to_num (int, optional): If specified, NaN values will be
                replaced by the numbers defined by the user. Default: None.
            beta (int): Determines the weight of recall in the combined score.
                Default: 1.
        Returns:
            Dict[str, np.ndarray]: per category evaluation metrics,
                shape (num_classes, ).
        """

        def f_score(precision, recall, beta=1):
            """calculate the f-score value.

            Args:
                precision (float | torch.Tensor): The precision value.
                recall (float | torch.Tensor): The recall value.
                beta (int): Determines the weight of recall in the combined
                    score. Default: 1.

            Returns:
                [torch.tensor]: The f-score value.
            """
            score = (1 + beta**2) * (precision * recall) / (
                (beta**2 * precision) + recall)
            return score

        if isinstance(metrics, str):
            metrics = [metrics]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore', 'Precision', 'Recall', 'PA']
        if not set(metrics).issubset(set(allowed_metrics)):
            raise KeyError(f'metrics {metrics} is not supported')

        all_acc = total_area_intersect.sum() / total_area_label.sum()
        ret_metrics = OrderedDict({'aAcc': all_acc})
        for metric in metrics:
            if metric == 'mIoU':
                iou = total_area_intersect / total_area_union
                acc = total_area_intersect / total_area_label
                ret_metrics['IoU'] = iou
                ret_metrics['Acc'] = acc
            elif metric == 'mDice':
                dice = 2 * total_area_intersect / (
                    total_area_pred_label + total_area_label)
                acc = total_area_intersect / total_area_label
                ret_metrics['Dice'] = dice
                ret_metrics['Acc'] = acc
            elif metric == 'mFscore':
                precision = total_area_intersect / total_area_pred_label
                recall = total_area_intersect / total_area_label
                f_value = torch.tensor([
                    f_score(x[0], x[1], beta) for x in zip(precision, recall)
                ])
                ret_metrics['Fscore'] = f_value
                ret_metrics['Precision'] = precision
                ret_metrics['Recall'] = recall
            elif metric == 'Precision':
                precision = total_area_intersect / total_area_pred_label
                ret_metrics['Precision'] = precision
            elif metric == 'Recall':
                recall = total_area_intersect / total_area_label
                ret_metrics['Recall'] = recall
            elif metric == 'PA':
                pa = total_area_intersect / total_area_label
                ret_metrics['PA'] = pa

        ret_metrics = {
            metric: value.numpy()
            for metric, value in ret_metrics.items()
        }
        if nan_to_num is not None:
            ret_metrics = OrderedDict({
                metric: np.nan_to_num(metric_value, nan=nan_to_num)
                for metric, metric_value in ret_metrics.items()
            })
        return ret_metrics