这段代码实现了一个通用框架,用于处理不同类型的损失函数和评估指标,特别适用于图神经网络的分类和回归任务。
通过 MetricType 枚举类,可以灵活地选择合适的损失函数和评估方法,并根据任务类型动态确定输出维度和评估标准。
from helpers.metrics import MetricType
可以查看:ROOT_DIR(helpers 文件中的 constants.py)- CSDN博客
from enum import Enum, auto
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, L1Loss
import torch
from typing import NamedTuple, List
from torchmetrics import Accuracy, AUROC, MeanAbsoluteError, MeanSquaredError, F1Score, AveragePrecision
import math
from torch_geometric.data import Data
import numpy as np
class LossesAndMetrics(NamedTuple):
    """Immutable record of the loss and metric values of the three data splits.

    The fields are plain floats; the first three hold losses and the last
    three hold metric values, each in train / val / test order.
    """

    # Loss values, one per split.
    train_loss: float
    val_loss: float
    test_loss: float
    # Metric values, one per split.
    train_metric: float
    val_metric: float
    test_metric: float

    def get_fold_metrics(self):
        """Return the train, val and test metric values as a 1-D torch tensor.

        The order of the entries is (train_metric, val_metric, test_metric).
        """
        fold_metrics = [self.train_metric, self.val_metric, self.test_metric]
        return torch.tensor(fold_metrics)
# NOTE(review): the body of this class has lost its indentation in this copy
# of the file; it must be re-indented under the class header before it can run.
class MetricType(Enum):
"""
Enumeration of the evaluation metrics supported by this module.

The members fall into two groups: classification metrics (ACCURACY,
MULTI_LABEL_AP, AUC_ROC) and a single regression metric (MSE_MAE).
Member values are generated with auto(), so only member identity matters.
"""
# classification metrics
ACCURACY = auto()
MULTI_LABEL_AP = auto()
AUC_ROC = auto()
# regression metric
MSE_MAE = auto()
def apply_metric(self, scores: np.ndarray, target: np.ndarray) -> float:
    """Evaluate this metric on a batch of predictions and return it as a float.

    Args:
        scores: Model outputs. NumPy arrays are converted with
            ``torch.from_numpy``; any other input is used as-is and must
            behave like a torch tensor. For the classification metrics,
            dimension 1 of ``scores`` is taken as the number of classes
            (or labels, for MULTI_LABEL_AP).
        target: Ground-truth values, converted the same way as ``scores``.
            For MULTI_LABEL_AP the target is cast to int before scoring.

    Returns:
        The scalar metric value: multiclass accuracy for ACCURACY, multilabel
        average precision for MULTI_LABEL_AP, multiclass AUROC for AUC_ROC
        and mean absolute error for MSE_MAE.

    Raises:
        ValueError: If this member has no metric implementation here.
    """
    if isinstance(scores, np.ndarray):
        scores = torch.from_numpy(scores)
    if isinstance(target, np.ndarray):
        target = torch.from_numpy(target)

    if self is MetricType.MSE_MAE:
        # Regression needs no class count, so dimension 1 of `scores` is not
        # read here; this lets 1-D predictions be scored without an IndexError.
        metric = MeanAbsoluteError()
    elif self is MetricType.ACCURACY:
        metric = Accuracy(task="multiclass", num_classes=scores.size(1))
    elif self is MetricType.MULTI_LABEL_AP:
        metric = AveragePrecision(task="multilabel", num_labels=scores.size(1))
        # The multilabel branch scores against an integer target.
        target = target.int()
    elif self is MetricType.AUC_ROC:
        metric = AUROC(task="multiclass", num_classes=scores.size(1))
    else:
        raise ValueError(f'MetricType {self.name} not supported')

    # Keep the metric's internal state on the same device as the predictions.
    metric = metric.to(scores.device)
    return metric(scores, target).item()
def is_classification(self) -> bool:
    """Tell whether this metric belongs to a classification task.

    Returns:
        True for AUC_ROC, ACCURACY and MULTI_LABEL_AP; False for MSE_MAE.

    Raises:
        ValueError: If this member is in neither group.
    """
    classification_metrics = (
        MetricType.AUC_ROC,
        MetricType.ACCURACY,
        MetricType.MULTI_LABEL_AP,
    )
    if self in classification_metrics:
        return True
    if self is MetricType.MSE_MAE:
        return False
    raise ValueError(f'MetricType {self.name} not supported')
def is_multilabel(self) -> bool:
    """Return True only when this metric is MULTI_LABEL_AP."""
    multilabel_metric = MetricType.MULTI_LABEL_AP
    return self is multilabel_metric
def get_task_loss(self):
    """Build the loss module that goes with this metric.

    Returns:
        ``BCEWithLogitsLoss`` for the multilabel classification metric,
        ``CrossEntropyLoss`` for the other classification metrics and
        ``MSELoss`` for MSE_MAE. A new instance is created on every call.

    Raises:
        ValueError: If this member is neither a classification metric nor
            MSE_MAE.
    """
    if not self.is_classification():
        # Non-classification: only the regression metric has a loss here.
        if self is MetricType.MSE_MAE:
            return MSELoss()
        raise ValueError(f'MetricType {self.name} not supported')

    return BCEWithLogitsLoss() if self.is_multilabel() else CrossEntropyLoss()
def get_out_dim(self, dataset: List[Data]) -> int:
    """Derive the model output dimension from the targets in ``dataset``.

    Args:
        dataset: Indexable collection of graph samples, each exposing its
            target as ``.y``.

    Returns:
        - regression: the size of the last dimension of the first sample's ``y``;
        - multilabel classification: the size of dimension 1 of the first
          sample's ``y``;
        - other classification: the largest value found in any sample's ``y``
          plus one, converted to int.
    """
    if not self.is_classification():
        return dataset[0].y.shape[-1]

    if self.is_multilabel():
        return dataset[0].y.shape[1]

    # Class count is inferred from the largest value seen across all samples.
    largest_label = max(data.y.max().item() for data in dataset)
    return int(largest_label + 1)
def higher_is_better(self):
    """Report whether a larger metric value means a better model.

    This is exactly the result of ``is_classification()``: classification
    metrics are treated as higher-is-better, everything else as
    lower-is-better.
    """
    larger_wins = self.is_classification()
    return larger_wins
def src_better_than_other(self, src: float, other: float) -> bool:
if self.higher_is_better():
return src > other
else:
return src <