基于 Python 的自然语言处理系列文章 (76):CLIP 模型原理与实现

原始论文:Learning Transferable Visual Models From Natural Language Supervision

GitHub:https://github.com/openai/CLIP


一、CLIP 简介

        CLIP(Contrastive Language–Image Pre-training)由 OpenAI 提出,通过对比学习方式,将自然语言监督引入视觉模型中训练。与传统视觉分类模型不同,CLIP 使用完整的自然语言描述作为监督信号,使模型可以理解“语言–图像”的语义对齐能力。

        CLIP 模型由两个主要部分组成:

  • 图像编码器(如 ResNet)

  • 文本编码器(如 BERT 或 DistilBERT)

        其训练目标是最大化图像与其匹配文本之间的相似度,并最小化与不匹配文本的相似度。

二、数据准备:Flickr8K 图文对

        我们使用 Flickr8K 数据集(包含 8000 张图像,每张图像有 5 条英文描述)作为示例,训练一个简化版本的 CLIP。为了与文本模型(如 DistilBERT)兼容,我们将图像转为 RGB 三通道,大小统一为 224x224。文本则通过 tokenizer 编码为 token 序列。

        首先自定义数据集类 CLIPDataset

class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=200
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]
        return item

    def __len__(self):
        return len(self.captions)

        我们使用 DistilBERT 的 tokenizer 对文本进行处理,同时对图像进行标准化与尺寸统一:

def get_transforms():
    return A.Compose([
        A.Resize(CFG.size, CFG.size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ])

        构建 dataloader:

def build_loaders(df, tokenizer, mode):
    dataset = CLIPDataset(
        df["image"].values,
        df["caption"].values,
        tokenizer=tokenizer,
        transforms=get_transforms(),
    )
    return DataLoader(dataset, batch_size=CFG.batch_size, shuffle=(mode == "train"), num_workers=CFG.num_workers)

三、模型结构解析

        CLIP 模型由两个主要模块组成:图像编码器和文本编码器,以及两个投影头(Projection Head)将它们映射到同一个对比学习空间。

1. 图像编码器

        我们使用 timm 中的 ResNet50 作为图像编码器,输出维度为 2048。

class ImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool="avg")
    def forward(self, x):
        return self.model(x)

2. 文本编码器

        使用 HuggingFace 的 DistilBertModel 获取文本表示,CLS token 的输出用于表示整个句子的语义,输出维度为 768。

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return output.last_hidden_state[:, self.target_token_idx, :]

3. 投影模块(Projection Head)

        将图像和文本表示分别映射到统一的维度(如 256),便于计算对比相似度。

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

四、对比学习目标函数

        将图像和文本经过各自编码器与投影层后,计算两者之间的相似度矩阵(batch_size × batch_size)。理想情况下,该相似度矩阵应为单位矩阵。

        采用交叉熵损失(CrossEntropyLoss)结合 softmax,实现对比目标。

class CLIPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=2048)
        self.text_projection = ProjectionHead(embedding_dim=768)
        self.temperature = 1.0

    def forward(self, batch):
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T

        targets = F.softmax((images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')

        loss = (images_loss + texts_loss) / 2.0
        return loss.mean()

        其中交叉熵实现如下:

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

五、训练过程

        训练过程包括:

  • train_epoch:训练一轮

  • valid_epoch:验证一轮

  • main():主训练入口

        支持 AdamW 优化器与学习率调度器。

训练过程核心代码

def train_epoch(model, train_loader, optimizer):
    model.train()
    for batch in tqdm(train_loader):
        batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def valid_epoch(model, valid_loader):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for batch in tqdm(valid_loader):
            batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
            loss = model(batch)
            total_loss += loss.item()
    return total_loss / len(valid_loader)

        主函数如下:

def main():
    train_df, valid_df = make_dfs()
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    best_loss = float('inf')
    for epoch in range(4):
        train_epoch(model, train_loader, optimizer)
        val_loss = valid_epoch(model, valid_loader)
        print(f"Epoch {epoch + 1}, Val Loss: {val_loss:.4f}")
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), "best.pt")
            print("✅ Best model saved")

六、推理与图像检索

        推理阶段,我们输入一段文本,CLIP 模型输出其向量表示,再与验证集中所有图片的向量做点积,得到最相似的图像。

检索匹配图片核心代码

def get_image_embeddings(valid_df, model_path):
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
    model = CLIPModel().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(device))
            image_embeddings.append(model.image_projection(image_features))
    return model, torch.cat(image_embeddings)

        图文检索函数:

def find_matches(model, image_embeddings, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    encoded_query = tokenizer([query])
    batch = {k: torch.tensor(v).to(device) for k, v in encoded_query.items()}
    with torch.no_grad():
        text_features = model.text_encoder(**batch)
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, dim=-1)
    similarity = text_embeddings_n @ image_embeddings_n.T

    _, indices = similarity[0].topk(n)
    matches = [image_filenames[i] for i in indices]

    plt.figure(figsize=(12, 8))
    for i, fname in enumerate(matches):
        plt.subplot(3, 3, i+1)
        image = cv2.imread(f"{CFG.image_path}/{fname}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        plt.imshow(image)
        plt.axis("off")
        plt.title(f"{i+1}")
    plt.tight_layout()
    plt.show()

        测试图文匹配效果:

find_matches(
    model, 
    image_embeddings, 
    query="a man riding a surfboard on a wave", 
    image_filenames=valid_df["image"].values, 
    n=9
)

七、结语

        CLIP 模型开创性地提出了一种高效的图文对比学习框架,仅使用自然语言描述就可以让模型学到可迁移的视觉语义表示。相比传统图像分类模型,CLIP 的优点主要包括:

  • ✅ 利用大规模网络语料进行图像监督训练;

  • ✅ 同时支持图像→文本和文本→图像的相似度检索;

  • ✅ 零样本迁移能力强,可适用于未见过的新任务或类别;

  • ✅ 结构简单、部署灵活,可轻松搭建轻量版训练版本。

        当然,CLIP 也并非完美,例如训练成本仍高,依赖高质量大规模图文对,但它所代表的“多模态对齐”范式已成为当前 AI 发展的关键路径。


        📌 下一篇为第 77 篇 —— Flamingo模型原理与实现,敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会飞的Anthony

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值