MetricType(helpers文件中的metrics.py)

这段代码实现了一个通用的框架,用于处理不同类型的损失函数和评估指标,特别适用于图神经网络的分类和回归任务。通过 MetricType 枚举类可以灵活地选择合适的损失和评估方法,并根据任务类型动态调整输出维度和评估标准。

from helpers.metrics import MetricType
可以查看ROOT_DIR(helpers文件中的constants.py)-优快云博客

from enum import Enum, auto
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, L1Loss
import torch
from typing import NamedTuple, List
from torchmetrics import Accuracy, AUROC, MeanAbsoluteError, MeanSquaredError, F1Score, AveragePrecision
import math
from torch_geometric.data import Data
import numpy as np


class LossesAndMetrics(NamedTuple):
    train_loss: float
    val_loss: float
    test_loss: float
    train_metric: float
    val_metric: float
    test_metric: float

    def get_fold_metrics(self):
        return torch.tensor([self.train_metric, self.val_metric, self.test_metric])


class MetricType(Enum):
    """
        an object for the different metrics
    """
    # classification
    ACCURACY = auto()
    MULTI_LABEL_AP = auto()
    AUC_ROC = auto()

    # regression
    MSE_MAE = auto()

    def apply_metric(self, scores: np.ndarray, target: np.ndarray) -> float:
        if isinstance(scores, np.ndarray):
            scores = torch.from_numpy(scores)
        if isinstance(target, np.ndarray):
            target = torch.from_numpy(target)
        num_classes = scores.size(1)  # target.max().item() + 1
        if self is MetricType.ACCURACY:
            metric = Accuracy(task="multiclass", num_classes=num_classes)
        elif self is MetricType.MULTI_LABEL_AP:
            metric = AveragePrecision(task="multilabel", num_labels=num_classes).to(scores.device)
            result = metric(scores, target.int())
            return result.item()
        elif self is MetricType.MSE_MAE:
            metric = MeanAbsoluteError()
        elif self is MetricType.AUC_ROC:
            metric = AUROC(task="multiclass", num_classes=num_classes)
        else:
            raise ValueError(f'MetricType {self.name} not supported')

        metric = metric.to(scores.device)
        result = metric(scores, target)
        return result.item()

    def is_classification(self) -> bool:
        if self in [MetricType.AUC_ROC, MetricType.ACCURACY, MetricType.MULTI_LABEL_AP]:
            return True
        elif self is MetricType.MSE_MAE:
            return False
        else:
            raise ValueError(f'MetricType {self.name} not supported')

    def is_multilabel(self) -> bool:
        return self is MetricType.MULTI_LABEL_AP

    def get_task_loss(self):
        if self.is_classification():
            if self.is_multilabel():
                return BCEWithLogitsLoss()
            else:
                return CrossEntropyLoss()
        elif self is MetricType.MSE_MAE:
            return MSELoss()
        else:
            raise ValueError(f'MetricType {self.name} not supported')

    def get_out_dim(self, dataset: List[Data]) -> int:
        if self.is_classification():
            if self.is_multilabel():
                return dataset[0].y.shape[1]
            else:
                return int(max([data.y.max().item() for data in dataset]) + 1)
        else:
            return dataset[0].y.shape[-1]

    def higher_is_better(self):
        return self.is_classification()

    def src_better_than_other(self, src: float, other: float) -> bool:
        if self.higher_is_better():
            return src > other
        else:
            return src <
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值