交叉熵损失函数(Cross-Entropy Loss) 是深度学习中常用的损失函数之一,尤其在分类任务中广泛应用。它的作用是衡量模型预测的概率分布与真实标签分布之间的差异,从而指导模型优化。以下是对交叉熵损失函数的详细解释:
1. 交叉熵损失函数的定义
对于分类问题,假设:
-
真实标签是一个 one-hot 向量(例如
[0, 0, 1, 0]
)。 -
模型预测的是一个 概率分布(例如
[0.1, 0.2, 0.6, 0.1]
)。
交叉熵损失函数的公式为:
CrossEntropy(y,y^)=−∑iyilog(y^i)CrossEntropy(y,y^)=−i∑yilog(y^i)
其中:
-
yiyi 是真实标签的第 ii 个值(0 或 1)。
-
y^iy^i 是模型预测的第 ii 个类别的概率。
2. 交叉熵损失函数的作用
(1)衡量预测与真实的差异
-
交叉熵损失函数的值越小,表示模型预测的概率分布与真实标签分布越接近。
-
当预测完全正确时(即预测概率分布与真实标签分布完全一致),交叉熵损失为 0。
(2)指导模型优化
-
在训练过程中,模型通过最小化交叉熵损失来调整参数,使得预测结果更接近真实标签。
-
通过反向传播算法,交叉熵损失的梯度会指导模型更新权重。
(3)处理多分类问题
-
交叉熵损失函数天然适合多分类任务,因为它可以同时考虑多个类别的概率分布。
4. 具体例子
假设批次大小为 3,labels = [0, 1, 2]
,表示:
-
第 0 个
query
的正样本是第 0 个title
。 -
第 1 个
query
的正样本是第 1 个title
。 -
第 2 个
query
的正样本是第 2 个title
。
余弦相似度矩阵经过 Softmax 后可能如下:
复制
[[0.7, 0.2, 0.1], # 第 0 个 query 与 3 个 title 的相似度 [0.1, 0.6, 0.3], # 第 1 个 query 与 3 个 title 的相似度 [0.2, 0.3, 0.5]] # 第 2 个 query 与 3 个 title 的相似度
-
对角线元素(
0.7, 0.6, 0.5
)是正样本对的概率。 -
非对角线元素是负样本对的概率。
交叉熵损失函数会计算:
CrossEntropy=−(log(0.7)+log(0.6)+log(0.5))CrossEntropy=−(log(0.7)+log(0.6)+log(0.5))
模型的目标是最小化这个损失值,即提高正样本对的概率,降低负样本对的概率。
5. 总结
交叉熵损失函数的作用是:
-
衡量模型预测的概率分布与真实标签分布之间的差异。
-
指导模型通过最小化损失值来优化参数。
-
在匹配任务中,最大化正样本对的相似度,同时最小化负样本对的相似度。