InfoNce损失

系列博客目录



InfoNCE 损失(Information Noise Contrastive Estimation)是一种对比学习的损失函数,用于最大化正样本(匹配样本)之间的互信息,同时将负样本(不匹配样本)拉开距离。它最早由 Oord 等人在 2018 年提出,用于无监督学习的自监督表示学习。InfoNCE 损失通过以下方式运作:

1. 基本原理

InfoNCE 损失的目标是通过最大化互信息来区分匹配和不匹配的样本对。假设我们有一个输入样本(如图像)以及其对应的正样本(如文本描述),InfoNCE 会拉近匹配的图像-文本对的嵌入距离,并将不匹配对的嵌入拉远。为了实现这一点,InfoNCE 损失函数对一对正样本(即匹配对)和多个负样本进行对比,通过多分类交叉熵的形式来估计它们之间的相似度。

2. 数学定义

假设我们有一个图像 ( I ) 和其对应的文本 ( T ),InfoNCE 损失的公式为:

L InfoNCE = − log ⁡ exp ⁡ ( sim ( I , T ) / τ ) ∑ k = 1 K exp ⁡ ( sim ( I , T ~ k ) / τ ) \mathcal{L}_{\text{InfoNCE}} = - \log \frac{\exp(\text{sim}(I, T)/\tau)}{\sum_{k=1}^{K} \exp(\text{sim}(I, \tilde{T}_k)/\tau)} LInfoNCE=log

### infoNCE Loss 的背景与实现 InfoNCE(Information Noise Contrastive Estimation)是一种用于对比学习的损失函数,广泛应用于自监督学习领域。它通过最大化正样本对之间的相似度并最小化负样本对之间的相似度来训练模型[^3]。 #### InfoNCE 损失函数定义 InfoNCE 损失可以表示为以下形式: \[ L_{\text{infoNCE}} = - \mathbb{E}_{(i,j) \sim p(i,j)} \left[ \log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{K}\exp(\text{sim}(z_i, z_k)/\tau)} \right] \] 其中 \(z_i\) 和 \(z_j\) 是两个嵌入向量,\(p(i,j)\) 表示正样本对的概率分布,\(\tau\) 是温度参数控制分布的锐利程度,\(\text{sim}(a,b)\) 通常指余弦相似度或欧几里得距离[^4]。 以下是基于 PyTorch 实现的一个简单版本的 InfoNCE 损失函数: ```python import torch import torch.nn.functional as F def infonce_loss(embedding_a, embedding_b, temperature=0.5): """ Computes the InfoNCE loss between two sets of embeddings. Args: embedding_a (torch.Tensor): First set of embeddings with shape (batch_size, dim). embedding_b (torch.Tensor): Second set of embeddings with shape (batch_size, dim). temperature (float): Temperature parameter for contrastive loss. Returns: torch.Tensor: Scalar value representing the mean InfoNCE loss over batch. """ # Normalize embeddings along dimension 1 (feature space normalization) embedding_a_norm = F.normalize(embedding_a, dim=1) embedding_b_norm = F.normalize(embedding_b, dim=1) # Compute similarity matrix using cosine similarity sim_matrix = torch.matmul(embedding_a_norm, embedding_b_norm.T) / temperature # Construct positive pairs' indices n = embedding_a.size(0) labels = torch.arange(n).to(sim_matrix.device) # Calculate cross-entropy loss treating each row as logits loss = F.cross_entropy(sim_matrix, labels) return loss ``` 上述代码实现了 InfoNCE 损失的核心逻辑,具体解释如下: - 使用 `F.normalize` 对输入特征进行 L2 归一化处理以计算标准化后的余弦相似度矩阵[^5]。 - 构造了一个对角线上的正样本索引作为目标标签,并利用交叉熵损失函数完成优化过程[^6]。 #### 转移学习中的应用 当结合转移学习时,可以通过预训练好的网络提取高级语义特征再送入对比学习框架中进一步微调权重参数从而提升目标任务性能表现[^7]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值