掩码图像建模 (MIM) 中的对数似然与交叉熵
1. 问题背景
在掩码图像建模(MIM)任务中,模型需要预测被遮蔽的图像块对应的视觉词元(可以理解为图像块的离散类别标签)。
具体来说:
- 每个被遮蔽的图像块 i∈Mi \in Mi∈M 的真实标签是 ziz_izi(即它原本的视觉词元类别)。
- 模型通过 Transformer 编码器生成隐藏向量 hLih_L^ihLi,然后通过一个分类器(参数为 Wc,bcW_c, b_cWc,bc)预测该位置的概率分布 pMIM(z′∣xM)p_{\text{MIM}}(z' | x^M)pMIM(z′∣xM)。
2. Softmax 分类器的作用
分类器的公式是:
pMIM(z′∣xM)=softmaxz(WchLi+bc)p_{\text{MIM}}(z' | x^M) = \text{softmax}_z(W_c h_L^i + b_c)pMIM(z′∣xM)=softmaxz(WchLi+bc)
- 输入:隐藏向量 hLi∈RDh_L^i \in \mathbb{R}^DhLi∈RD(来自 Transformer 的输出)。
- 参数:权重矩阵 Wc∈R∣V∣×DW_c \in \mathbb{R}^{|\mathcal{V}| \times D}Wc∈R∣V∣×D 和偏置 bc∈R∣V∣b_c \in \mathbb{R}^{|\mathcal{V}|}bc∈R∣V∣,其中 ∣V∣|\mathcal{V}|∣V∣ 是视觉词元的总类别数。
- 输出:一个概率分布,表示模型认为被遮蔽块 iii 属于每个视觉词元类别的概率。
具体计算步骤:
- 对每个被遮蔽位置 iii,计算线性变换:WchLi+bcW_c h_L^i + b_cWchLi+bc,得到一个长度为 ∣V∣|\mathcal{V}|∣V∣ 的向量(称为logits)。
- 对 logits 应用 softmax 函数,将其转换为概率分布:
p(z′)=exp(logits[z′])∑k=1∣V∣exp(logits[k])p(z') = \frac{\exp(\text{logits}[z'])}{\sum_{k=1}^{|\mathcal{V}|} \exp(\text{logits}[k])}p(z′)=∑k=1∣V∣exp(logits[k])exp(logits[z′])
其中 z′z'z′ 是某个可能的视觉词元类别。
3. 最大化对数似然(Maximize Log-Likelihood)
目标:让模型对真实标签 ziz_izi 的预测概率尽可能高。
数学表达:
maxθEx∼D[∑i∈MlogpMIM(zi∣xM)]\max_{\theta} \mathbb{E}_{x \sim \mathcal{D}} \left[ \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right]θmaxEx∼D[i∈M∑logpMIM(zi∣xM)]
- 解释:
- 对每个被遮蔽位置 iii,计算真实标签 ziz_izi 的对数概率 logpMIM(zi∣xM)\log p_{\text{MIM}}(z_i | x^M)logpMIM(zi∣xM)。
- 对所有被遮蔽位置求和,再对所有训练样本 xxx 求期望。
- 目标是最大化这个总和,即让模型对真实标签的预测概率尽可能大。
4. 交叉熵损失(Cross-Entropy Loss)
交叉熵损失是分类任务中常用的损失函数,定义为:
LCE=−∑i∈MlogpMIM(zi∣xM)\mathcal{L}_{\text{CE}} = - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M)LCE=−i∈M∑logpMIM(zi∣xM)
- 解释:
- 对每个被遮蔽位置 iii,计算真实标签 ziz_izi 的负对数概率。
- 对所有被遮蔽位置求和,得到总损失。
- 目标是最小化这个损失,即让真实标签的预测概率尽可能高。
5. 最大化对数似然 vs. 最小化交叉熵
关键结论:
最大化对数似然和最小化交叉熵损失是完全等价的!
具体来说:
maxθ∑i∈MlogpMIM(zi∣xM) ⟺ minθ(−∑i∈MlogpMIM(zi∣xM))\max_{\theta} \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \quad \iff \quad \min_{\theta} \left( - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right)θmaxi∈M∑logpMIM(zi∣xM)⟺θmin(−i∈M∑logpMIM(zi∣xM))
- 左边是最大化对数似然(使正确标签的概率最大化)。
- 右边是最小化交叉熵损失(使正确标签的负对数概率最小化)。
6. 为什么等价?
- 数学本质:交叉熵损失是负的对数似然。
- 对数似然是 ∑logp\sum \log p∑logp,交叉熵是 −∑logp-\sum \log p−∑logp。
- 最大化 AAA 等价于最小化 −A-A−A。
- 直观理解:
- 如果模型对真实标签的预测概率 p(zi)p(z_i)p(zi) 越大,对数似然 logp(zi)\log p(z_i)logp(zi) 越大,交叉熵损失 −logp(zi)-\log p(z_i)−logp(zi) 越小。
- 例如,若真实标签的概率 p(zi)=0.9p(z_i) = 0.9p(zi)=0.9,则交叉熵损失为 −log(0.9)≈0.11-\log(0.9) \approx 0.11−log(0.9)≈0.11;
若概率 p(zi)=0.1p(z_i) = 0.1p(zi)=0.1,则损失为 −log(0.1)≈2.30-\log(0.1) \approx 2.30−log(0.1)≈2.30。
显然,概率越大,损失越小。
7. 实际训练中的计算
在代码中,通常直接使用交叉熵损失函数(如 PyTorch 的 CrossEntropyLoss
):
# 假设 logits 是模型的输出(未经过 softmax)
# targets 是被遮蔽位置的真实视觉词元标签
loss = F.cross_entropy(logits, targets)
- 内部过程:
- 对 logits 应用 softmax,得到概率分布。
- 计算真实标签的负对数概率。
- 对所有样本和位置求平均,得到最终损失。
总结
- 目标:让模型对真实标签的预测概率尽可能高。
- 数学实现:通过最大化对数似然(等价于最小化交叉熵损失)。
- 代码实现:直接使用交叉熵损失函数,无需手动计算对数似然。