Code Background
The code below defines an evaluation metric class named IoUMetric, which inherits from mmengine.evaluator.BaseMetric and computes a set of evaluation metrics for semantic segmentation. In semantic segmentation we need to measure how well the model's predicted masks match the ground-truth labels, and IoUMetric provides the commonly used metrics for this: intersection over union (IoU), the Dice coefficient, the F-score, precision, recall and pixel accuracy (PA), giving a fairly complete picture of model performance. mmsegmentation is OpenMMLab's open-source semantic segmentation toolbox; this code is part of its evaluation module and lets users evaluate different segmentation models in a uniform way.
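For reference, here is a minimal sketch (not part of the library code) that spells out what these metrics mean for a single class in terms of per-class pixel counts; the function name and the example counts are made up for illustration.
def toy_metrics(tp: int, fp: int, fn: int, beta: float = 1.0) -> dict:
    """Per-class metric definitions for one class of a segmentation mask.

    tp: pixels correctly predicted as the class,
    fp: pixels wrongly predicted as the class,
    fn: pixels of the class missed by the prediction.
    """
    iou = tp / (tp + fp + fn)                    # intersection over union
    dice = 2 * tp / (2 * tp + fp + fn)           # Dice coefficient
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)                      # per-class PA/Acc in this code
    fscore = ((1 + beta ** 2) * precision * recall /
              (beta ** 2 * precision + recall))  # F1 when beta == 1
    return dict(IoU=iou, Dice=dice, Precision=precision,
                Recall=recall, Fscore=fscore)

print(toy_metrics(tp=8, fp=2, fn=4))  # IoU = 8/14 ≈ 0.571, Dice = 16/22 ≈ 0.727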
Code Walkthrough
1. Importing the required libraries
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence
from mmengine.evaluator import BaseMetric
import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTable
from mmseg.registry import METRICS
Several libraries are imported here: file-path helpers, data structures, type hints, the metric base class, numerical computation, the deep-learning framework, distributed-training utilities, logging, directory creation, image I/O and pretty-printed tables. Together they provide everything the rest of the code needs.
2. Defining the IoUMetric class
@METRICS.register_module()
class IoUMetric(BaseMetric):
The @METRICS.register_module() decorator registers IoUMetric in mmseg's metric registry, so the class can be referenced by name from a config file. It inherits from BaseMetric, which provides the basic evaluation machinery such as result collection and the evaluate entry point.
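Once registered, the metric can be built from a plain config dict. A minimal sketch of how this is typically wired up in an mmsegmentation config (the exact field values here are illustrative):
# In a config file, the evaluator is described by a dict whose 'type' is the
# registered name; the runner builds it through the METRICS registry.
val_evaluator = dict(
    type='IoUMetric',                 # must match the registered class name
    iou_metrics=['mIoU', 'mFscore'],  # forwarded to IoUMetric.__init__
    ignore_index=255)
test_evaluator = val_evaluator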
3. The __init__ method
def __init__(self,
ignore_index: int = 255,
iou_metrics: List[str] = ['mIoU'],
nan_to_num: Optional[int] = None,
beta: int = 1,
collect_device: str = 'cpu',
output_dir: Optional[str] = None,
format_only: bool = False,
prefix: Optional[str] = None,
**kwargs) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.ignore_index = ignore_index
self.metrics = iou_metrics
self.nan_to_num = nan_to_num
self.beta = beta
self.output_dir = output_dir
if self.output_dir and is_main_process():
mkdir_or_exist(self.output_dir)
self.format_only = format_only
This is the class initializer. Its arguments are the label index to ignore, the list of metrics to compute, an optional value used to replace NaNs, the beta weight of the F-score, the device on which results are collected, an output directory for predictions, a flag that only formats results without evaluating them, and a metric-name prefix. It calls the parent initializer and stores these options on the instance; if an output directory is given and the current process is the main process, the directory is created.
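The class can also be constructed directly for a quick local check. A minimal sketch, assuming IoUMetric has been imported and using two made-up class names (a Runner would normally set dataset_meta for us):
metric = IoUMetric(
    iou_metrics=['mIoU', 'mDice'],  # compute per-class IoU and Dice
    ignore_index=255,               # label value excluded from evaluation
    nan_to_num=0,                   # replace NaN (e.g. absent classes) with 0
    collect_device='cpu')
# process() reads the class names from dataset_meta; set it by hand here.
metric.dataset_meta = dict(classes=('background', 'foreground'))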
4. The process method
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
num_classes = len(self.dataset_meta['classes'])
for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'].squeeze()
if not self.format_only:
label = data_sample['gt_sem_seg']['data'].squeeze().to(
pred_label)
self.results.append(
self.intersect_and_union(pred_label, label, num_classes,
self.ignore_index))
if self.output_dir is not None:
basename = osp.splitext(osp.basename(
data_sample['img_path']))[0]
png_filename = osp.abspath(
osp.join(self.output_dir, f'{basename}.png'))
output_mask = pred_label.cpu().numpy()
if data_sample.get('reduce_zero_label', False):
output_mask = output_mask + 1
output = Image.fromarray(output_mask.astype(np.uint8))
output.save(png_filename)
This method processes one batch of data and samples. For each sample it extracts the predicted label map and, unless format_only is set, the ground-truth label map, calls intersect_and_union to compute the per-class intersection and union areas, and appends the result to self.results. If an output directory was specified, the prediction is also saved as a PNG image, with 1 added to the mask when reduce_zero_label is set so that the saved indices match the original dataset convention.
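Continuing the sketch from section 3, a single hand-built sample can be pushed through process; the nested dicts below mimic the 'pred_sem_seg' / 'gt_sem_seg' structure the method expects:
import torch
# A fake 4x4 prediction and ground truth for the two hypothetical classes.
pred = torch.tensor([[0, 0, 1, 1]] * 4)
gt = torch.tensor([[0, 1, 1, 1]] * 4)
data_sample = dict(
    pred_sem_seg=dict(data=pred.unsqueeze(0)),  # (1, H, W); squeezed inside
    gt_sem_seg=dict(data=gt.unsqueeze(0)))
metric.process(data_batch={}, data_samples=[data_sample])
# self.results now holds one (intersect, union, pred_area, label_area) tuple.
print(metric.results[-1])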
5. The compute_metrics method
def compute_metrics(self, results: list) -> Dict[str, float]:
logger: MMLogger = MMLogger.get_current_instance()
if self.format_only:
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
return OrderedDict()
results = tuple(zip(*results))
assert len(results) == 4
total_area_intersect = sum(results[0])
total_area_union = sum(results[1])
total_area_pred_label = sum(results[2])
total_area_label = sum(results[3])
ret_metrics = self.total_area_to_metrics(
total_area_intersect, total_area_union, total_area_pred_label,
total_area_label, self.metrics, self.nan_to_num, self.beta)
class_names = self.dataset_meta['classes']
ret_metrics_summary = OrderedDict({
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
metrics = dict()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
metrics[key] = val
elif key == 'PA':
metrics[key] = val
else:
metrics['m' + key] = val
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)
print_log('per class results:', logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
return metrics
This method computes the final metrics from the processed results. If format_only is set, it only logs where the results were saved and returns an empty dict. Otherwise it transposes the list of per-sample tuples, sums the intersection, union, prediction and ground-truth areas over all samples, and passes the totals to total_area_to_metrics. The per-class values are then summarized: each metric is averaged with np.nanmean and scaled to a percentage, and every key except aAcc and PA gets an 'm' prefix. Finally the per-class results are printed as a PrettyTable (with a leading Class column) and the summary dict is returned.
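Continuing the same sketch, the accumulated results can be reduced to final numbers. In a real run this happens through metric.evaluate(), which also gathers results across processes; calling compute_metrics directly is enough for a local check:
scores = metric.compute_metrics(metric.results)
print(scores)
# For the toy sample above this comes out to roughly
# {'aAcc': 75.0, 'mIoU': 58.33, 'mAcc': 83.33, 'mDice': 73.33},
# and a per-class PrettyTable is written to the logger.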
6. The intersect_and_union method
@staticmethod
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
num_classes: int, ignore_index: int):
mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]
intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_label = torch.histc(
label.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label
This static method computes the intersection and union between a prediction and its ground truth. It first masks out pixels whose label equals ignore_index, keeps the pixels where prediction and ground truth agree as the intersection, uses torch.histc to count the per-class areas of the intersection, the prediction and the ground truth, and finally derives the union as prediction area + label area - intersection area.
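The counting is easy to verify on a tiny example. A sketch, assuming the class is available, with a 1-D "image" of three classes plus one ignored pixel:
import torch
pred = torch.tensor([0, 0, 1, 1, 2, 255])
gt = torch.tensor([0, 1, 1, 1, 2, 255])
inter, union, pred_area, gt_area = IoUMetric.intersect_and_union(
    pred, gt, num_classes=3, ignore_index=255)
print(inter)      # tensor([1., 2., 1.]): pixels where pred == gt, per class
print(pred_area)  # tensor([2., 2., 1.]): predicted pixels per class
print(gt_area)    # tensor([1., 3., 1.]): ground-truth pixels per class
print(union)      # pred_area + gt_area - inter = tensor([2., 3., 1.])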
7. The total_area_to_metrics method
@staticmethod
def total_area_to_metrics(total_area_intersect: np.ndarray,
total_area_union: np.ndarray,
total_area_pred_label: np.ndarray,
total_area_label: np.ndarray,
metrics: List[str] = ['mIoU'],
nan_to_num: Optional[int] = None,
beta: int = 1):
def f_score(precision, recall, beta=1):
score = (1 + beta**2) * (precision * recall) / (
(beta**2 * precision) + recall)
return score
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice', 'mFscore', 'Precision', 'Recall', 'PA']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError(f'metrics {metrics} is not supported')
all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
acc = total_area_intersect / total_area_label
ret_metrics['IoU'] = iou
ret_metrics['Acc'] = acc
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
acc = total_area_intersect / total_area_label
ret_metrics['Dice'] = dice
ret_metrics['Acc'] = acc
elif metric == 'mFscore':
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor([
f_score(x[0], x[1], beta) for x in zip(precision, recall)
])
ret_metrics['Fscore'] = f_value
ret_metrics['Precision'] = precision
ret_metrics['Recall'] = recall
elif metric == 'Precision':
precision = total_area_intersect / total_area_pred_label
ret_metrics['Precision'] = precision
elif metric == 'Recall':
recall = total_area_intersect / total_area_label
ret_metrics['Recall'] = recall
elif metric == 'PA':
pa = total_area_intersect / total_area_label
ret_metrics['PA'] = pa
ret_metrics = {
metric: value.numpy()
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = OrderedDict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
})
return ret_metrics
This static method turns the accumulated intersection, union, prediction and ground-truth areas into the requested metrics. It first defines a helper f_score for the F-beta score, then validates the requested metric names against the allowed list. It computes the overall pixel accuracy aAcc, fills ret_metrics with the per-class values of each requested metric, converts everything to NumPy arrays (in practice the inputs are the torch tensors returned by intersect_and_union, despite the np.ndarray type hints), optionally replaces NaNs with nan_to_num, and returns the resulting dict.
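Feeding the toy counts from the previous sketch into total_area_to_metrics shows how the per-class values come out:
ret = IoUMetric.total_area_to_metrics(
    total_area_intersect=inter,
    total_area_union=union,
    total_area_pred_label=pred_area,
    total_area_label=gt_area,
    metrics=['mIoU', 'mFscore'],
    beta=1)
print(ret['IoU'])        # inter / union     -> [0.5, 0.667, 1.0]
print(ret['Precision'])  # inter / pred_area -> [0.5, 1.0, 1.0]
print(ret['Recall'])     # inter / gt_area   -> [1.0, 0.667, 1.0]
print(ret['Fscore'])     # F1 per class      -> [0.667, 0.8, 1.0]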
In summary, this code implements a complete evaluation-metric class for semantic segmentation: it accumulates statistics from predictions and ground-truth labels, computes several commonly used metrics, and reports per-class results in a table.
Source Code
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence
from mmengine.evaluator import BaseMetric
import numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTable
from mmseg.registry import METRICS
@METRICS.register_module()
class IoUMetric(BaseMetric):
"""IoU evaluation metric.
Args:
ignore_index (int): Index that will be ignored in evaluation.
Default: 255.
iou_metrics (list[str] | str): Metrics to be calculated. The options
include 'mIoU', 'mDice', 'mFscore', 'Precision', 'Recall' and 'PA'.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
beta (int): Determines the weight of recall in the combined score.
Default: 1.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
output_dir (str): The directory for output prediction. Defaults to
None.
format_only (bool): Only format result for results commit without
perform evaluation. It is useful when you want to save the result
to a specific format and submit it to the test server.
Defaults to False.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
"""
def __init__(self,
ignore_index: int = 255,
iou_metrics: List[str] = ['mIoU'],
nan_to_num: Optional[int] = None,
beta: int = 1,
collect_device: str = 'cpu',
output_dir: Optional[str] = None,
format_only: bool = False,
prefix: Optional[str] = None,
**kwargs) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.ignore_index = ignore_index
self.metrics = iou_metrics
self.nan_to_num = nan_to_num
self.beta = beta
self.output_dir = output_dir
if self.output_dir and is_main_process():
mkdir_or_exist(self.output_dir)
self.format_only = format_only
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
num_classes = len(self.dataset_meta['classes'])
for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'].squeeze()
# format_only always for test dataset without ground truth
if not self.format_only:
label = data_sample['gt_sem_seg']['data'].squeeze().to(
pred_label)
self.results.append(
self.intersect_and_union(pred_label, label, num_classes,
self.ignore_index))
# format_result
if self.output_dir is not None:
basename = osp.splitext(osp.basename(
data_sample['img_path']))[0]
png_filename = osp.abspath(
osp.join(self.output_dir, f'{basename}.png'))
output_mask = pred_label.cpu().numpy()
# The index range of official ADE20k dataset is from 0 to 150.
# But the index range of output is from 0 to 149.
# That is because we set reduce_zero_label=True.
if data_sample.get('reduce_zero_label', False):
output_mask = output_mask + 1
output = Image.fromarray(output_mask.astype(np.uint8))
output.save(png_filename)
def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict[str, float]: The computed metrics. The keys are the names of
the metrics, and the values are corresponding results. The key
mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision,
mRecall, PA.
"""
logger: MMLogger = MMLogger.get_current_instance()
if self.format_only:
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
return OrderedDict()
# convert list of tuples to tuple of lists, e.g.
# [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
# ([A_1, ..., A_n], ..., [D_1, ..., D_n])
results = tuple(zip(*results))
assert len(results) == 4
total_area_intersect = sum(results[0])
total_area_union = sum(results[1])
total_area_pred_label = sum(results[2])
total_area_label = sum(results[3])
ret_metrics = self.total_area_to_metrics(
total_area_intersect, total_area_union, total_area_pred_label,
total_area_label, self.metrics, self.nan_to_num, self.beta)
class_names = self.dataset_meta['classes']
# summary table
ret_metrics_summary = OrderedDict({
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
metrics = dict()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
metrics[key] = val
elif key == 'PA':
metrics[key] = val
else:
metrics['m' + key] = val
# each class table
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)
print_log('per class results:', logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
return metrics
@staticmethod
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
num_classes: int, ignore_index: int):
"""Calculate Intersection and Union.
Args:
pred_label (torch.tensor): Prediction segmentation map
or predict result filename. The shape is (H, W).
label (torch.tensor): Ground truth segmentation map
or label filename. The shape is (H, W).
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
Returns:
torch.Tensor: The intersection of prediction and ground truth
histogram on all classes.
torch.Tensor: The union of prediction and ground truth histogram on
all classes.
torch.Tensor: The prediction histogram on all classes.
torch.Tensor: The ground truth histogram on all classes.
"""
mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]
intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_label = torch.histc(
label.float(), bins=(num_classes), min=0,
max=num_classes - 1).cpu()
area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label
@staticmethod
def total_area_to_metrics(total_area_intersect: np.ndarray,
total_area_union: np.ndarray,
total_area_pred_label: np.ndarray,
total_area_label: np.ndarray,
metrics: List[str] = ['mIoU'],
nan_to_num: Optional[int] = None,
beta: int = 1):
"""Calculate evaluation metrics
Args:
total_area_intersect (np.ndarray): The intersection of prediction
and ground truth histogram on all classes.
total_area_union (np.ndarray): The union of prediction and ground
truth histogram on all classes.
total_area_pred_label (np.ndarray): The prediction histogram on
all classes.
total_area_label (np.ndarray): The ground truth histogram on
all classes.
metrics (List[str] | str): Metrics to be evaluated. The options
include 'mIoU', 'mDice', 'mFscore', 'Precision', 'Recall' and 'PA'.
nan_to_num (int, optional): If specified, NaN values will be
replaced by the numbers defined by the user. Default: None.
beta (int): Determines the weight of recall in the combined score.
Default: 1.
Returns:
Dict[str, np.ndarray]: per category evaluation metrics,
shape (num_classes, ).
"""
def f_score(precision, recall, beta=1):
"""calculate the f-score value.
Args:
precision (float | torch.Tensor): The precision value.
recall (float | torch.Tensor): The recall value.
beta (int): Determines the weight of recall in the combined
score. Default: 1.
Returns:
[torch.tensor]: The f-score value.
"""
score = (1 + beta**2) * (precision * recall) / (
(beta**2 * precision) + recall)
return score
if isinstance(metrics, str):
metrics = [metrics]
allowed_metrics = ['mIoU', 'mDice', 'mFscore', 'Precision', 'Recall', 'PA']
if not set(metrics).issubset(set(allowed_metrics)):
raise KeyError(f'metrics {metrics} is not supported')
all_acc = total_area_intersect.sum() / total_area_label.sum()
ret_metrics = OrderedDict({'aAcc': all_acc})
for metric in metrics:
if metric == 'mIoU':
iou = total_area_intersect / total_area_union
acc = total_area_intersect / total_area_label
ret_metrics['IoU'] = iou
ret_metrics['Acc'] = acc
elif metric == 'mDice':
dice = 2 * total_area_intersect / (
total_area_pred_label + total_area_label)
acc = total_area_intersect / total_area_label
ret_metrics['Dice'] = dice
ret_metrics['Acc'] = acc
elif metric == 'mFscore':
precision = total_area_intersect / total_area_pred_label
recall = total_area_intersect / total_area_label
f_value = torch.tensor([
f_score(x[0], x[1], beta) for x in zip(precision, recall)
])
ret_metrics['Fscore'] = f_value
ret_metrics['Precision'] = precision
ret_metrics['Recall'] = recall
elif metric == 'Precision':
precision = total_area_intersect / total_area_pred_label
ret_metrics['Precision'] = precision
elif metric == 'Recall':
recall = total_area_intersect / total_area_label
ret_metrics['Recall'] = recall
elif metric == 'PA':
pa = total_area_intersect / total_area_label
ret_metrics['PA'] = pa
ret_metrics = {
metric: value.numpy()
for metric, value in ret_metrics.items()
}
if nan_to_num is not None:
ret_metrics = OrderedDict({
metric: np.nan_to_num(metric_value, nan=nan_to_num)
for metric, metric_value in ret_metrics.items()
})
return ret_metrics