详解常用的对比学习损失

对比学习损失函数用于在无监督或半监督的情况下学习数据表示,使得相似的数据样本在表示空间中更加接近,而不相似的样本更远离。以下是几种常见的对比学习损失函数及其详细说明:

一、对比损失(Contrastive Loss)

对比损失用于使得正样本对(相似样本对)在表示空间中接近,而负样本对(不相似样本对)远离。

1、公式

\[ L = \frac{1}{2N} \sum_{i=1}^{N} \left( y_i \cdot D_i^2 + (1 - y_i) \cdot \max(margin - D_i, 0)^2 \right) \]

其中:
\( y_i \) 是标签,1 表示正样本对,0 表示负样本对。
\( D_i \) 是样本对的欧氏距离。
\( margin \) 是一个超参数,表示负样本对之间的最小距离。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                          label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss

二、三元组损失(Triplet Loss)

三元组损失用于训练模型使得锚点样本(Anchor)和正样本(Positive)之间的距离小于锚点样本和负样本(Negative)之间的距离。

1、公式

\[ L = \sum_{i=1}^{N} \left[ \|f(x_i^a) - f(x_i^p)\|_2^2 - \|f(x_i^a) - f(x_i^n)\|_2^2 + \alpha \right]_+ \]

其中:
- \( x_i^a \) 是锚点样本。
- \( x_i^p \) 是正样本。
- \( x_i^n \) 是负样本。
- \( \alpha \) 是一个超参数,表示正负样本对之间的最小距离差。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_distance = F.pairwise_distance(anchor, positive)
        neg_distance = F.pairwise_distance(anchor, negative)
        loss = torch.mean(F.relu(pos_distance - neg_distance + self.margin))
        return loss

三、信息论对比损失(InfoNCE Loss)

InfoNCE 损失常用于自监督学习,通过最大化正样本对之间的相似度,同时最小化正样本对和负样本对之间的相似度。

1、公式

\[ L = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(f(x_i) \cdot f(x_i^+))}{\exp(f(x_i) \cdot f(x_i^+)) + \sum_{j=1}^{K} \exp(f(x_i) \cdot f(x_j^-))} \]

其中:
- \( f(x_i) \) 是样本 \( x_i \) 的表示。
- \( x_i^+ \) 是正样本。
- \( x_j^- \) 是负样本。
- \( K \) 是负样本的数量。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        batch_size = features.size(0)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float()

        contrast_feature = torch.cat(torch.unbind(features, dim=0), dim=0)
        anchor_feature = contrast_feature

        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        mask = mask.repeat(batch_size, 1)
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * 2).view(-1, 1).cuda(),
            0
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = - (self.temperature / 0.07) * mean_log_prob_pos
        loss = loss.view(batch_size, 2).mean()

        return loss

四、 互信息最大化损失(Mutual Information Maximization Loss)

这种损失用于最大化全局表示和局部表示之间的互信息,常用于图像或图数据。

1、公式

\[ L = - \frac{1}{N} \sum_{i=1}^{N} \left[ \log \frac{\exp(f(x_i) \cdot f(g_i))}{\sum_{j=1}^{N} \exp(f(x_i) \cdot f(g_j))} \right] \]

其中:
- \( f(x_i) \) 是样本 \( x_i \) 的局部表示。
- \( f(g_i) \) 是样本 \( x_i \) 的全局表示。

2、代码实现(PyTorch)

import torch
import torch.nn.functional as F

class MutualInformationLoss(nn.Module):
    def __init__(self):
        super(MutualInformationLoss, self).__init__()

    def forward(self, local_features, global_features):
        batch_size = local_features.size(0)
        scores = torch.matmul(local_features, global_features.T)

        labels = torch.arange(batch_size).cuda()
        loss = F.cross_entropy(scores, labels)

        return loss

这些对比学习损失函数在不同的任务和数据集上有不同的效果,可以根据具体需求进行选择和调整。

### 三元组的概念及应用场景 #### 1. 细粒度情感分析中的三元组 在细粒度情感分析(ABSA, Aspect-Based Sentiment Analysis)领域,三元组通常由 **aspect**、**opinion** 和它们之间的关系构成。具体来说,aspect 是指评论的目标实体或属性,opinion 则是对该 aspect 的主观评价[^1]。这种三元组形式可以帮助理解文本的情感倾向以及具体的关注点。 例如,在一句话 “The food was delicious but the service was slow.” 中,“food-delicious-positive” 构成一个正面的三元组,而 “service-slow-negative” 构成一个负面的三元组。这类三元组的应用场景主要包括产品评论挖掘、品牌声誉监控等领域。 #### 2. 知识图谱中的三元组 知识图谱的核心构建单元也是三元组,其一般表示为 **(头节点, 关系, 尾节点)** 或者更通俗地说 **(主体, 谓词, 客体)**。例如,`(北京, 首都_of_, 中国)` 表示一种事实性的语义关系[^2]。知识图谱通过这些三元组来表达复杂的现实世界信息,并支持诸如智能问答、个性化推荐等功能。 相比于传统的数据库表结构,知识图谱利用图模型的优势在于它可以灵活地扩展新的实体和关系类型,从而适应更加动态的数据环境。 #### 3. 矩阵压缩存储中的稀疏矩阵三元组 在计算机科学中,特别是涉及大规模数据集时,为了节省内存空间并提高计算效率,常采用稀疏矩阵的形式保存那些大部分元素为零的大规模数组。其中一种常见的方法就是使用 COO (Coordinate Format) 来记录非零项的位置及其对应的数值作为三元组 `(row_index, col_index, value)`[^3]。这种方法特别适用于解决高维特征向量或者大型机器学习训练过程中的资源优化问题。 --- ### 应用场景对比 | 特性/应用 | 情感分析三元组 | 知识图谱三元组 | 稀疏矩阵三元组 | |-----------|-----------------------------------------|---------------------------------------|-------------------------------------| | 数据性质 | 主观意见 | 实际存在的客观事实 | 数字化数值 | | 存储方式 | 文本标注 | 图形链接 | 坐标索引 | | 使用目的 | 提取用户偏好与反馈 | 描述事物间逻辑关联 | 减少冗余提升运算速度 | 上述三种类型的三元组虽然表面上看似相似,但实际上服务于完全不同的需求背景和技术框架之下: - 对于情感分析而言,它更多聚焦的是自然语言处理层面的任务; - 知识图谱则强调跨域的知识融合能力; - 而稀疏矩阵则是纯粹从算法性能角度出发的一种技巧手段。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值