掩码图像建模 (MIM) 中的对数似然与交叉熵

掩码图像建模 (MIM) 中的对数似然与交叉熵

1. 问题背景

在掩码图像建模(MIM)任务中,模型需要预测被遮蔽的图像块对应的视觉词元(可以理解为图像块的离散类别标签)。

具体来说:

  • 每个被遮蔽的图像块 i∈Mi \in MiM 的真实标签是 ziz_izi(即它原本的视觉词元类别)。
  • 模型通过 Transformer 编码器生成隐藏向量 hLih_L^ihLi,然后通过一个分类器(参数为 Wc,bcW_c, b_cWc,bc)预测该位置的概率分布 pMIM(z′∣xM)p_{\text{MIM}}(z' | x^M)pMIM(zxM)

2. Softmax 分类器的作用

分类器的公式是:
pMIM(z′∣xM)=softmaxz(WchLi+bc)p_{\text{MIM}}(z' | x^M) = \text{softmax}_z(W_c h_L^i + b_c)pMIM(zxM)=softmaxz(WchLi+bc)

  • 输入:隐藏向量 hLi∈RDh_L^i \in \mathbb{R}^DhLiRD(来自 Transformer 的输出)。
  • 参数:权重矩阵 Wc∈R∣V∣×DW_c \in \mathbb{R}^{|\mathcal{V}| \times D}WcRV×D 和偏置 bc∈R∣V∣b_c \in \mathbb{R}^{|\mathcal{V}|}bcRV,其中 ∣V∣|\mathcal{V}|V 是视觉词元的总类别数。
  • 输出:一个概率分布,表示模型认为被遮蔽块 iii 属于每个视觉词元类别的概率。

具体计算步骤

  1. 对每个被遮蔽位置 iii,计算线性变换:WchLi+bcW_c h_L^i + b_cWchLi+bc,得到一个长度为 ∣V∣|\mathcal{V}|V 的向量(称为logits)。
  2. 对 logits 应用 softmax 函数,将其转换为概率分布:
    p(z′)=exp⁡(logits[z′])∑k=1∣V∣exp⁡(logits[k])p(z') = \frac{\exp(\text{logits}[z'])}{\sum_{k=1}^{|\mathcal{V}|} \exp(\text{logits}[k])}p(z)=k=1Vexp(logits[k])exp(logits[z])
    其中 z′z'z 是某个可能的视觉词元类别。

3. 最大化对数似然(Maximize Log-Likelihood)

目标:让模型对真实标签 ziz_izi 的预测概率尽可能高。

数学表达:
max⁡θEx∼D[∑i∈Mlog⁡pMIM(zi∣xM)]\max_{\theta} \mathbb{E}_{x \sim \mathcal{D}} \left[ \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right]θmaxExD[iMlogpMIM(zixM)]

  • 解释
    • 对每个被遮蔽位置 iii,计算真实标签 ziz_izi 的对数概率 log⁡pMIM(zi∣xM)\log p_{\text{MIM}}(z_i | x^M)logpMIM(zixM)
    • 对所有被遮蔽位置求和,再对所有训练样本 xxx 求期望。
    • 目标是最大化这个总和,即让模型对真实标签的预测概率尽可能大。

4. 交叉熵损失(Cross-Entropy Loss)

交叉熵损失是分类任务中常用的损失函数,定义为:
LCE=−∑i∈Mlog⁡pMIM(zi∣xM)\mathcal{L}_{\text{CE}} = - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M)LCE=iMlogpMIM(zixM)

  • 解释
    • 对每个被遮蔽位置 iii,计算真实标签 ziz_izi 的负对数概率。
    • 对所有被遮蔽位置求和,得到总损失。
    • 目标是最小化这个损失,即让真实标签的预测概率尽可能高。

5. 最大化对数似然 vs. 最小化交叉熵

关键结论
最大化对数似然最小化交叉熵损失完全等价的!

具体来说:
max⁡θ∑i∈Mlog⁡pMIM(zi∣xM)  ⟺  min⁡θ(−∑i∈Mlog⁡pMIM(zi∣xM))\max_{\theta} \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \quad \iff \quad \min_{\theta} \left( - \sum_{i \in M} \log p_{\text{MIM}}(z_i | x^M) \right)θmaxiMlogpMIM(zixM)θmin(iMlogpMIM(zixM))

  • 左边是最大化对数似然(使正确标签的概率最大化)。
  • 右边是最小化交叉熵损失(使正确标签的负对数概率最小化)。

6. 为什么等价?

  • 数学本质:交叉熵损失是负的对数似然。
    • 对数似然是 ∑log⁡p\sum \log plogp,交叉熵是 −∑log⁡p-\sum \log plogp
    • 最大化 AAA 等价于最小化 −A-AA
  • 直观理解
    • 如果模型对真实标签的预测概率 p(zi)p(z_i)p(zi) 越大,对数似然 log⁡p(zi)\log p(z_i)logp(zi) 越大,交叉熵损失 −log⁡p(zi)-\log p(z_i)logp(zi) 越小。
    • 例如,若真实标签的概率 p(zi)=0.9p(z_i) = 0.9p(zi)=0.9,则交叉熵损失为 −log⁡(0.9)≈0.11-\log(0.9) \approx 0.11log(0.9)0.11
      若概率 p(zi)=0.1p(z_i) = 0.1p(zi)=0.1,则损失为 −log⁡(0.1)≈2.30-\log(0.1) \approx 2.30log(0.1)2.30
      显然,概率越大,损失越小。

7. 实际训练中的计算

在代码中,通常直接使用交叉熵损失函数(如 PyTorch 的 CrossEntropyLoss):

# 假设 logits 是模型的输出(未经过 softmax)
# targets 是被遮蔽位置的真实视觉词元标签
loss = F.cross_entropy(logits, targets)
  • 内部过程
    1. 对 logits 应用 softmax,得到概率分布。
    2. 计算真实标签的负对数概率。
    3. 对所有样本和位置求平均,得到最终损失。

总结

  • 目标:让模型对真实标签的预测概率尽可能高。
  • 数学实现:通过最大化对数似然(等价于最小化交叉熵损失)。
  • 代码实现:直接使用交叉熵损失函数,无需手动计算对数似然。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

frostmelody

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值