本博客内容来自B站up主【王木头学科学】的视频内容
习惯看视频的小伙伴可移至视频链接[待补充]:~~~
首先通俗地解释一下极大似然估计(Maximum Likelihood Estimation,MLE)的思想:通过结果寻找使该结果发生的最可能的原因。
对于训练模型来说,模型结构已经预先定义好,我们具体训练的是模型的参数。例如我们构建了一个ResNet-50,想基于手写数字数据集训练一个能识别手写数字的模型。在ResNet-50这个架构下,对应的模型参数有很多种可能,现在我们有手写数字图像以及对应的真实分类标签(这是我们人类的判断结果),训练的目标就是找到能使得手写数字分类结果正确的那一套模型参数(原因)。
似然值:真实情况已经发生,在某个模型下这种情况发生的可能性。所谓极大似然估计就是最大化似然值。
那极大似然估计是如何跟损失函数联系起来的呢?
以二分类任务——以给定图像判断是不是猫为例,对于每一张图像,我们人类会有一个判断(即标签)
,其中
表示第
张图片,
表示该图像是猫,0表示不是猫。
如果我们想用极大似然估计的话,似然值可通过如下公式计算:
其中表示模型的参数。公式具体的含义是
参数下模型能正确判断是不是猫的可能性,我们要找到使这个可能性最大的模型参数。
为什么可以写成相乘的形式,是因为我们假设每个样本是独立同分布的。
给定一张图像,模型会有一个预测结果,这个预测结果是基于模型参数做出的,因此一定程度上隐含了模型参数,上式可进一步写为:
由于是二分类,只有0和1两种情况,所以符合伯努利分布。
伯努利分布知识补充
其中表示
的概率。
根据伯努利分布重写上述公式得到
其中表示模型预测当前图像是猫的概率。
使用log将连乘操作转变为连加操作:
回顾一下我们的目的,我们是要求极大似然值,也就是上式的最大值(log不改变单调性):
但是为了方便优化,我们更倾向于求最小值,所以加一个符号变成如下公式:
看到这里,有的小伙伴可能就发现了:这不就是交叉熵损失么!!!
确实,形式上看这就是交叉熵损失,但我们是通过最大化似然一步一步推导出来了。事实上,交叉熵是极大似然估计在二分类场景中的特例。
后续我们会从熵和信息的角度推导交叉熵损失~