MXNet框架如何自定义evaluation metric

本文介绍了在MXNet框架下如何自定义评价标准(Evaluation Metric),详细讲解了自定义metric的步骤,包括初始化、更新函数的实现,并通过代码示例展示了如何在训练过程中调用自定义的metric。此外,还分享了MXNet官方提供的metric类的使用方法和一些默认metric类的写法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用深度学习框架训练模型的时候都需要用到评价标准,比如准确率等,那么在MXNet框架下,这些评价标准(Evaluation Metric)是怎么实现的呢?如果我们要自定义一个不一样的评价标准要怎么做?一起来了解下吧。

首先来看看在MXNet框架下关于evaluation metric的最基本的类和脚本。mxnet.metric.EvalMetric是MXNet框架中计算评价标准(evaluation metric)的基础类,这个基础类是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py中定义的。metric.py脚本中不仅包含类基础类的定义,还包含MXNet所有和metric相关的类的定义,所以如果你想更深入了解的话,可以看看这个脚本。

先来介绍下如何自定义一个evaluation metric。作者在文档中说过,这个类没有办法直接使用,当你要用的时候,应该要先自定义一个继承该基础类的类,将你想实现的评价标准写在里面,然后再在你的训练代码中调用该类。先看一个简单的定义一个evaluation metric的例子:

import mxnet as mx
class Accuracy(mx.metric.EvalMetric): # 在定义类名称的时候,括号里面表示继承哪个类
    def __init__(self, num=None):
        super(Accuracy, self).__init__('accuracy', num)

    def update(self, labels, preds):
        pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
        label = labels[0].asnumpy().astype('int32')

        mx.metric.check_label_shapes(label, pred_label)

        self.sum_metric += (pred_label.flat == label.flat).sum()
        self.num_inst += len(pred_label.flat)

在这个自定义类中主要包含一个__init__ 函数和一个update函数,前者是用来初始化,后者是用来更新metric,比如说你定义每计算一个batch的样本,就更新一次metric。updata函数的输入中,labels和preds都是一个NDArray的列表。在update函数中有个mx.metric.check_label_shaoes函数,这个函数的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py脚本中,是用来判断labels和preds的shape是否一致,因为labels和preds都是list,而且一般这个list中只包含一个NDArray,比如说你的batch size是16,类别数是1000,那么labels中的NDArray就是16*1,preds中的NDArray就是16*1000。

以上这些代码可以写在类似名字为my_metric.py脚本中,那么怎么调用了?在你的训练代码中,可以通过mxnet.metric.CompositeEvalMetric类来调用,这个类是用来管理你的evaluation metrics的,这个类的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py脚本中。先来看看官方的关于使用这个类的例子:

>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([0, 1, 1])]
>>> eval_metrics_1 = mx.metric.Accuracy()
>>> eval_metrics_2 = mx.metric.F1()
>>> eval_metrics = mx.metric.CompositeEvalMetric()
>>> for child_metric in [eval_metrics_1, eval_metrics_2]:
>>>     eval_metrics.add(child_metric)
>>> eval_metrics.update(labels = labels, preds = predicts) >>> print eval_metrics.get()
    (['accuracy', 'f1'], [0.6666666666666666, 0.8])

这里mx.metric.Accuracy()和mx.metric.F1()就是定义的metric类的实现,当你要调用自定义的metric时,主要用到这个·CompositeEvalMetric·类的add方法,用一次add方法就能够增加一种评价方式,从而在训练界面显示,而最后的update方法是手动更新计算结果,一般我们仅需下面两行定义一个eval_metric即可,然后把eval_metric作为训练的fit函数的一个参数即可:

eval_metric = mx.metric.CompositeEvalMetric()
eval_metric.add(Accuracy())

在代码运行到要更新metric的时候(更具体的是base_model.py中的self.metric(eval_metric, data_batch.label))这一行,然后会先调用mxnet.metric.CompositeEvalMetric()类的update_dict()函数,然后调用基类mxnet.metric.EvalMetricupdate_dict()函数,

上面说完了怎么自定义一个evaluation metric类以及如何调用自定义的metric类,那么下面就看看官方的/incubator-mxnet/python/mxnet/metric.py脚本中默认的一些metric类的写法是什么样的吧。非常喜欢MXNet代码的一个原因就是其代码注释和文档做得很好,看起来一目了然。

class Accuracy(EvalMetric): # 这里也是继承的mxnet.metric.EvalMetric这个类,只不过因为这两个类在同一个脚本中定义,因此就可以直接写EvalMetric
    """Computes accuracy classification score.

    The accuracy score is defined as

    .. math::
        \\text{accuracy}(y, \\hat{y}) = \\frac{1}{n} \\sum_{i=0}^{n-1}
        \\text{1}(\\hat{y_i} == y_i)

    Parameters
    ----------
    axis : int, default=1
        The axis that represents classes
    name : str
        Name of this metric instance for display.
    output_names : list of str, or None
        Name of predictions that should be used when updating with update_dict.
        By default include all predictions.
    label_names : list of str, or None
        Name of labels that should be used when updating with update_dict.
        By default include all labels.

    Examples
    --------
    >>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
    >>> labels   = [mx.nd.array([0, 1, 1])]
    >>> acc = mx.metric.Accuracy()
    >>> acc.update(preds = predicts, labels = labels)
    >>> print acc.get()
    ('accuracy', 0.6666666666666666)
    """
    def __init__(self, axis=1, name='accuracy',
                 output_names=None, label_names=None):
 # super这个函数是调用基类mx.metric.EvalMetric的__init__函数,__init__函数括号中的变量是要传递给基类的__init__函数的变量。super()括号中的Accuracy表示类名称。
        super(Accuracy, self).__init__(
            name, axis=axis,
            output_names=output_names, label_names=label_names)
        self.axis = axis

    def update(self, labels, preds):
        """Updates the internal evaluation result.

        Parameters
        ----------
        labels : list of `NDArray`
            The labels of the data.

        preds : list of `NDArray`
            Predicted values.
        """
        check_label_shapes(labels, preds)

        for label, pred_label in zip(labels, preds): # zip函数可以将输入的两个list的对应位置的值变成一个元组(tuple),这样每个tuple就包含两个值,这两个值在这里都是NDArray格式。又因为pred_label的shape和label的shape是不一样的,所以都会进入下面这个if语句,也就是先将pred_label按行求出最大值所在的index,然后pred_label就和label是相同shape的NDArray了。
            if pred_label.shape != label.shape:
                pred_label = ndarray.argmax(pred_label, axis=self.axis)

# 先用asnumpy()方法将NDArray转换成numpy.ndarray,然后把数值转成32位int型,原来是浮点型
            pred_label = pred_label.asnumpy().astype('int32')
            label = label.asnumpy().astype('int32')

            check_label_shapes(label, pred_label)

# (pred_label.flat == label.flat) 会返回numpy.ndarray格式,内容是false或者true,表示相等或不相等,最后求一个和
            self.sum_metric += (pred_label.flat == label.flat).sum()
            self.num_inst += len(pred_label.flat)

再看一个复杂一点的例子,也是官方的用于计算multi label的accuracy的类:

class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""

    def __init__(self, num=None):
        self.num = num
# super调用基类的__init__函数进行初始化,__init__其中还会调用reset方法,如果你没有重写reset方法,那就默认调用基类的reset方法
        super(Multi_Accuracy, self).__init__('multi-accuracy')


# 这里重写了基类的reset方法,因为原来基类的reset方法并不考虑self.num_inst和self.sum_metric是list的情况,因此基类里面的reset方法只是两行:self.num_inst = 0 和 self.sum_metric = 0.0。
    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.num_inst = 0 if self.num is None else [0] * self.num
        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num


# update方法就是你要如何计算metric的实现
    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        if self.num is not None:
            assert len(labels) == self.num

        for i in range(len(labels)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[i].asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            if self.num is None:
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)


# get方法是计算当前的evaluation result。比如说你设定每计算20个batch后就计算一次准确率等结果,那么这个结果的计算就是调用这个get方法
    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        names : list of str
           Name of the metrics.
        values : list of float
           Value of the evaluations.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get()
        else:
            return zip(*(('%s-task%d'%(self.name, i), float('nan') if self.num_inst[i] == 0
                                                      else self.sum_metric[i] / self.num_inst[i])
                       for i in range(self.num)))

    def get_name_value(self):
        """Returns zipped name and value pairs.

        Returns
        -------
        list of tuples
            A (name, value) tuple list.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get_name_value()
        name, value = self.get()
        return list(zip(name, value))
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值