Contrastive Learning: Representation Learning Techniques in PyTorch

[Free download] pytorch-deep-learning: Materials for the Learn PyTorch for Deep Learning: Zero to Mastery course. Project page: https://gitcode.com/GitHub_Trending/py/pytorch-deep-learning

Introduction: Why Contrastive Learning?

Traditional supervised learning needs large amounts of labeled data, yet most real-world data is unlabeled. Contrastive learning, a self-supervised learning method, learns meaningful representations from unlabeled data alone, and it has fundamentally changed how representation learning is done.

In this article you will learn:

  • The core principles and mathematical foundations of contrastive learning
  • PyTorch implementations of the mainstream contrastive learning algorithms
  • Practical tips and best practices for real projects
  • Performance optimization and debugging strategies

Contrastive Learning Fundamentals

What Is Representation Learning?

Representation learning aims to learn low-dimensional, dense representations of data that capture its underlying structure and semantics. Contrastive learning learns such representations by comparing the similarities and differences between data samples.

The Core Idea of Contrastive Learning

The core idea of contrastive learning is to "pull similar samples together and push dissimilar samples apart." This is expressed by the InfoNCE (NT-Xent) loss:

$$ \mathcal{L} = -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k)/\tau)} $$

where:

  • $z_i, z_j$ are the embeddings of a positive pair
  • $\text{sim}$ is a similarity function (usually cosine similarity)
  • $\tau$ is the temperature parameter
  • $N$ is the batch size, so one batch yields $2N$ augmented views
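
The loss above can be computed directly for a batch of 2N embeddings. A minimal sketch (the function name `info_nce_loss` is ours), where rows i and i + N hold the two views of the same image:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z, temperature=0.5):
    """NT-Xent loss for 2N embeddings; z[i] and z[i + N] are the two
    augmented views of the same image (minimal sketch)."""
    z = F.normalize(z, dim=1)
    n = z.shape[0] // 2
    sim = z @ z.t() / temperature          # all pairwise cosine similarities
    sim.fill_diagonal_(float('-inf'))      # enforce k != i in the denominator
    # the positive for row i is its other view at row i + N (or i - N)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)

# sanity check: identical views with orthogonal negatives
e0, e1 = torch.tensor([1., 0., 0., 0.]), torch.tensor([0., 1., 0., 0.])
loss = info_nce_loss(torch.stack([e0, e1, e0, e1]))  # -log(e^2 / (e^2 + 2)) ≈ 0.2395
```

With only two images, each sample sees one positive (logit 1/τ = 2) and two orthogonal negatives (logit 0), which makes the value easy to verify by hand.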

Mainstream Contrastive Learning Algorithms

1. SimCLR (Simple Framework for Contrastive Learning)

SimCLR, proposed by Google Research, is a classic contrastive learning framework: each image is augmented into two views, both views pass through a shared encoder and projection head, and a contrastive loss pulls the two views' projections together while pushing apart projections of other images.

Core PyTorch implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimCLR(nn.Module):
    def __init__(self, base_encoder, projection_dim=128):
        super(SimCLR, self).__init__()
        self.encoder = base_encoder(pretrained=False)
        self.feature_dim = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()  # drop the classification head
        
        # projection head (2-layer MLP)
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.ReLU(),
            nn.Linear(self.feature_dim, projection_dim)
        )
    
    def forward(self, x):
        features = self.encoder(x)
        projections = self.projector(features)
        return F.normalize(projections, dim=1)
    
    def contrastive_loss(self, projections, temperature=0.5):
        # projections holds 2N rows; row i and row i + N are the two
        # augmented views of the same image
        batch_size = projections.shape[0] // 2
        device = projections.device
        
        # pairwise cosine similarity between all 2N projections
        similarity_matrix = F.cosine_similarity(
            projections.unsqueeze(1), projections.unsqueeze(0), dim=2
        )
        
        # mask out self-similarity so each softmax runs over the
        # remaining 2N - 1 samples
        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=device)
        similarity_matrix.masked_fill_(mask, float('-inf'))
        
        # the positive for sample i is its other view at index i ± N
        labels = torch.cat([
            torch.arange(batch_size, 2 * batch_size),
            torch.arange(0, batch_size)
        ]).to(device)
        
        loss = F.cross_entropy(similarity_matrix / temperature, labels)
        return loss

2. MoCo (Momentum Contrast)

MoCo maintains a dynamic queue of negative samples, decoupling the number of negatives from the batch size:

class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super(MoCo, self).__init__()
        self.K = K
        self.m = m
        self.T = T
        
        # query and key encoders
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)
        
        # initialize the key encoder from the query encoder;
        # it is updated by momentum, not by gradients
        for param_q, param_k in zip(self.encoder_q.parameters(), 
                                   self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        
        # negative-sample queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                   self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity, as in the official MoCo code
        
        # overwrite the oldest keys in the queue (ring buffer)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr
    
    def forward(self, im_q, im_k):
        # query features
        q = self.encoder_q(im_q)
        q = nn.functional.normalize(q, dim=1)
        
        # key features (no gradient; key encoder updated by momentum)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = nn.functional.normalize(k, dim=1)
        
        # positive logits: one per query
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        
        # negative logits against the queue
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        
        # combined logits, scaled by temperature
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T
        
        # the positive is always at index 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(q.device)
        
        # update the queue with the new keys
        self._dequeue_and_enqueue(k)
        
        return logits, labels
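
The queue logic is the subtle part of MoCo. Isolated from the model, `_dequeue_and_enqueue` is just a ring buffer: new keys overwrite the oldest slots and the pointer wraps modulo K. A standalone sketch (the free-standing names here are ours):

```python
import torch

K, dim = 8, 4
queue = torch.zeros(dim, K)   # columns are stored keys
ptr = 0

def dequeue_and_enqueue(keys):
    """Overwrite the oldest slots with the new batch of keys (ring buffer)."""
    global ptr
    batch_size = keys.shape[0]
    assert K % batch_size == 0   # MoCo assumes K is divisible by the batch size
    queue[:, ptr:ptr + batch_size] = keys.t()
    ptr = (ptr + batch_size) % K

dequeue_and_enqueue(torch.ones(4, dim))        # fills slots 0-3, ptr -> 4
dequeue_and_enqueue(2 * torch.ones(4, dim))    # fills slots 4-7, ptr wraps to 0
```

Because old keys come from slightly stale encoders, the momentum update (m close to 1) keeps the key encoder changing slowly enough for the queue to stay consistent.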

Data Augmentation Strategies

The success of contrastive learning depends heavily on effective data augmentation. Commonly used combinations:

| Augmentation type | Operations | Purpose |
|---|---|---|
| Spatial | random crop, rotation, flip | spatial invariance |
| Color | color jitter, grayscale conversion | color invariance |
| Noise | Gaussian noise, random erasing | robustness |
| Geometric | perspective transform, elastic deformation | geometric invariance |

PyTorch implementation of the augmentation pipeline:

from torchvision import transforms

def get_simclr_transform(input_size=224):
    return transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([
            transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        # GaussianBlur requires an odd kernel size
        transforms.GaussianBlur(kernel_size=int(0.1 * input_size) // 2 * 2 + 1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
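
The training loop in the next section expects `images[0]` and `images[1]` to be two augmented views of the same batch. A small wrapper produces them by applying the transform twice independently (the class name `TwoCropsTransform` follows the convention in the official MoCo code):

```python
class TwoCropsTransform:
    """Apply the same random transform twice to get two views of one image."""
    def __init__(self, base_transform):
        self.base_transform = base_transform
    
    def __call__(self, x):
        # each call re-samples the random parameters, so the views differ
        return [self.base_transform(x), self.base_transform(x)]

# usage with any torchvision dataset, e.g.:
# dataset = CIFAR10(root='data', transform=TwoCropsTransform(get_simclr_transform(32)))
```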

Training Loop and Hyperparameter Tuning

Designing the Training Loop

def train_contrastive(model, train_loader, optimizer, scheduler, epochs, device):
    model.train()
    scaler = torch.cuda.amp.GradScaler()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (images, _) in enumerate(train_loader):
            # the two augmented views of each image form a positive pair
            x1, x2 = images[0].to(device), images[1].to(device)
            
            with torch.cuda.amp.autocast():
                # forward pass for both views
                z1 = model(x1)
                z2 = model(x2)
                
                # contrastive loss over the concatenated projections
                loss = model.contrastive_loss(torch.cat([z1, z2]), temperature=0.5)
            
            # backward pass with gradient scaling
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            total_loss += loss.item()
        
        # step the learning-rate schedule once per epoch
        scheduler.step()
        
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}')

Hyperparameter Tuning Strategies

| Hyperparameter | Recommended range | Tuning notes |
|---|---|---|
| Batch size | 256–1024 | larger is generally better, limited by GPU memory |
| Learning rate | 0.03–0.3 | use cosine annealing with linear warmup |
| Temperature τ | 0.05–0.2 | controls how sharp the similarity distribution is |
| Projection dimension | 128–512 | depends on downstream task complexity |
| Epochs | 100–1000 | scale with dataset size |
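
The effect of the temperature τ is easy to see numerically: dividing similarities by a small τ sharpens the softmax, so the loss concentrates on the hardest samples. A quick illustration (the similarity values are made up):

```python
import torch
import torch.nn.functional as F

sims = torch.tensor([0.9, 0.5, 0.1, -0.3])   # one positive, three negatives
for tau in (0.05, 0.5):
    probs = F.softmax(sims / tau, dim=0)
    print(f"tau={tau}: p(positive)={probs[0]:.3f}")
```

At τ = 0.05 almost all probability mass sits on the top similarity, while at τ = 0.5 the distribution stays much flatter; very small temperatures can therefore make training unstable.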

Evaluation and Transfer to Downstream Tasks

Linear Evaluation Protocol

def linear_evaluation(encoder, train_loader, test_loader, num_classes, device):
    # freeze the encoder parameters
    for param in encoder.parameters():
        param.requires_grad = False
    
    # attach a linear classification head
    classifier = nn.Linear(encoder.feature_dim, num_classes).to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    
    # train the classifier on frozen features
    for epoch in range(100):
        encoder.eval()
        classifier.train()
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            with torch.no_grad():
                features = encoder(images)
            
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    # evaluate accuracy on the test set
    encoder.eval()
    classifier.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            features = encoder(images)
            outputs = classifier(features)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

Performance Comparison

| Method | ImageNet Top-1 Acc | Params | Training time | Best suited for |
|---|---|---|---|---|
| SimCLR | 76.5% | 25M | medium | general vision tasks |
| MoCo v2 | 71.1% | 25M | longer | training without large batches (queue-based negatives) |
| BYOL | 74.3% | 28M | longer | no negative pairs required |
| SwAV | 75.3% | 29M | shorter | clustering-based pretraining |

Practical Applications

Case 1: Medical Image Analysis

class MedicalContrastiveLearning(nn.Module):
    def __init__(self, backbone, modality='CT'):
        super().__init__()
        self.backbone = backbone
        self.modality = modality
        
        # augmentations tailored to medical images
        self.medical_transform = transforms.Compose([
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            transforms.ElasticTransform(alpha=2.0),  # torchvision >= 0.13
            transforms.RandomAdjustSharpness(sharpness_factor=2),
            transforms.Lambda(self.medical_specific_augment)
        ])
    
    def medical_specific_augment(self, x):
        # modality-specific augmentation logic
        if self.modality == 'CT':
            # CT: window center/width adjustment (helper not shown here)
            x = self.adjust_window(x, center=40, width=80)
        elif self.modality == 'MRI':
            # MRI: contrast enhancement (helper not shown here)
            x = self.enhance_contrast(x)
        return x
    
    def forward(self, x):
        features = self.backbone(x)
        return features
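
The class above leaves `adjust_window` undefined. For reference, CT windowing clips Hounsfield units to an interval set by a window center and width and rescales to [0, 1]; a sketch of such a helper, written as a standalone function here:

```python
import torch

def adjust_window(x, center=40, width=80):
    """Clip intensities to [center - width/2, center + width/2], rescale to [0, 1]."""
    lo, hi = center - width / 2, center + width / 2
    return (x.clamp(lo, hi) - lo) / (hi - lo)

# a soft-tissue window (center 40 HU, width 80 HU) maps 0 HU -> 0.0 and 80 HU -> 1.0
```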

Case 2: Text-Image Multimodal Contrastive Learning

class CLIPLikeModel(nn.Module):
    def __init__(self, image_encoder, text_encoder, embedding_dim=512):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_proj = nn.Linear(image_encoder.feature_dim, embedding_dim)
        self.text_proj = nn.Linear(text_encoder.feature_dim, embedding_dim)
    
    def forward(self, images, texts):
        # encode and project images
        image_features = self.image_encoder(images)
        image_embeddings = F.normalize(self.image_proj(image_features), dim=-1)
        
        # encode and project texts
        text_features = self.text_encoder(texts)
        text_embeddings = F.normalize(self.text_proj(text_features), dim=-1)
        
        return image_embeddings, text_embeddings
    
    def contrastive_loss(self, image_embeddings, text_embeddings, temperature=0.07):
        # image-text similarity matrix
        logits = torch.matmul(image_embeddings, text_embeddings.t()) / temperature
        labels = torch.arange(len(image_embeddings)).to(image_embeddings.device)
        
        # symmetric contrastive loss (image-to-text and text-to-image)
        loss_i = F.cross_entropy(logits, labels)
        loss_t = F.cross_entropy(logits.t(), labels)
        loss = (loss_i + loss_t) / 2
        
        return loss
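
The symmetric loss can be checked on toy embeddings: when every image embedding exactly matches its paired text embedding and all other pairs are orthogonal, the loss is near zero. A standalone restatement of the method above:

```python
import torch
import torch.nn.functional as F

def symmetric_clip_loss(image_emb, text_emb, temperature=0.07):
    # matched pairs sit on the diagonal of the similarity matrix
    logits = image_emb @ text_emb.t() / temperature
    labels = torch.arange(len(image_emb), device=image_emb.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

pairs = torch.eye(4)                                          # 4 perfectly aligned pairs
aligned = symmetric_clip_loss(pairs, pairs)                   # near zero
shuffled = symmetric_clip_loss(pairs, pairs.roll(1, dims=0))  # every pair mismatched
```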

Performance Optimization Tips

1. Mixed-Precision Training

from torch.cuda.amp import autocast, GradScaler

def train_with_amp(model, dataloader, optimizer):
    scaler = GradScaler()
    
    for images, _ in dataloader:
        # images: a batch of 2N augmented views (both views concatenated)
        optimizer.zero_grad()
        
        with autocast():
            projections = model(images)
            loss = model.contrastive_loss(projections)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

2. Gradient Accumulation

def train_with_gradient_accumulation(model, dataloader, optimizer, accumulation_steps=4):
    model.train()
    total_loss = 0
    
    for i, (images, _) in enumerate(dataloader):
        projections = model(images)
        # scale the loss so accumulated gradients match a large-batch update
        loss = model.contrastive_loss(projections) / accumulation_steps
        loss.backward()
        total_loss += loss.item()
        
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    
    return total_loss

3. Learning-Rate Scheduling

PyTorch has no built-in LinearWarmup scheduler, so linear warmup followed by cosine decay is easiest to express as a single LambdaLR multiplier:

import math
from torch.optim.lr_scheduler import LambdaLR

def get_scheduler(optimizer, warmup_epochs, total_epochs):
    def lr_lambda(epoch):
        # linear warmup for the first warmup_epochs epochs...
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        # ...then cosine decay to zero over the remaining epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    
    return LambdaLR(optimizer, lr_lambda)


Authorship note: parts of this article were produced with AI assistance (AIGC) and are provided for reference only.
