原始论文:Learning Transferable Visual Models From Natural Language Supervision
一、CLIP 简介
CLIP(Contrastive Language–Image Pre-training)由 OpenAI 提出,通过对比学习方式,将自然语言监督引入视觉模型中训练。与传统视觉分类模型不同,CLIP 使用完整的自然语言描述作为监督信号,使模型可以理解“语言–图像”的语义对齐能力。
CLIP 模型由两个主要部分组成:
-
图像编码器(如 ResNet)
-
文本编码器(如 BERT 或 DistilBERT)
其训练目标是最大化图像与其匹配文本之间的相似度,并最小化与不匹配文本的相似度。
二、数据准备:Flickr8K 图文对
我们使用 Flickr8K 数据集(包含 8000 张图像,每张图像有 5 条英文描述)作为示例,训练一个简化版本的 CLIP。为了与文本模型(如 DistilBERT)兼容,我们将图像转为 RGB 三通道,大小统一为 224x224。文本则通过 tokenizer 编码为 token 序列。
首先自定义数据集类 CLIPDataset
:
class CLIPDataset(torch.utils.data.Dataset):
def __init__(self, image_filenames, captions, tokenizer, transforms):
self.image_filenames = image_filenames
self.captions = list(captions)
self.encoded_captions = tokenizer(
self.captions, padding=True, truncation=True, max_length=200
)
self.transforms = transforms
def __getitem__(self, idx):
item = {
key: torch.tensor(values[idx])
for key, values in self.encoded_captions.items()
}
image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = self.transforms(image=image)['image']
item['image'] = torch.tensor(image).permute(2, 0, 1).float()
item['caption'] = self.captions[idx]
return item
def __len__(self):
return len(self.captions)
我们使用 DistilBERT 的 tokenizer 对文本进行处理,同时对图像进行标准化与尺寸统一:
def get_transforms():
return A.Compose([
A.Resize(CFG.size, CFG.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
])
构建 dataloader:
def build_loaders(df, tokenizer, mode):
dataset = CLIPDataset(
df["image"].values,
df["caption"].values,
tokenizer=tokenizer,
transforms=get_transforms(),
)
return DataLoader(dataset, batch_size=CFG.batch_size, shuffle=(mode == "train"), num_workers=CFG.num_workers)
三、模型结构解析
CLIP 模型由两个主要模块组成:图像编码器和文本编码器,以及两个投影头(Projection Head)将它们映射到同一个对比学习空间。
1. 图像编码器
我们使用 timm
中的 ResNet50 作为图像编码器,输出维度为 2048。
class ImageEncoder(nn.Module):
def __init__(self):
super().__init__()
self.model = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool="avg")
def forward(self, x):
return self.model(x)
2. 文本编码器
使用 HuggingFace 的 DistilBertModel
获取文本表示,CLS token 的输出用于表示整个句子的语义,输出维度为 768。
class TextEncoder(nn.Module):
def __init__(self):
super().__init__()
self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")
self.target_token_idx = 0
def forward(self, input_ids, attention_mask):
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return output.last_hidden_state[:, self.target_token_idx, :]
3. 投影模块(Projection Head)
将图像和文本表示分别映射到统一的维度(如 256),便于计算对比相似度。
class ProjectionHead(nn.Module):
def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
super().__init__()
self.projection = nn.Linear(embedding_dim, projection_dim)
self.gelu = nn.GELU()
self.fc = nn.Linear(projection_dim, projection_dim)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(projection_dim)
def forward(self, x):
projected = self.projection(x)
x = self.gelu(projected)
x = self.fc(x)
x = self.dropout(x)
x = x + projected
x = self.layer_norm(x)
return x
四、对比学习目标函数
将图像和文本经过各自编码器与投影层后,计算两者之间的相似度矩阵(batch_size × batch_size)。理想情况下,该相似度矩阵应为单位矩阵。
采用交叉熵损失(CrossEntropyLoss)结合 softmax,实现对比目标。
class CLIPModel(nn.Module):
def __init__(self):
super().__init__()
self.image_encoder = ImageEncoder()
self.text_encoder = TextEncoder()
self.image_projection = ProjectionHead(embedding_dim=2048)
self.text_projection = ProjectionHead(embedding_dim=768)
self.temperature = 1.0
def forward(self, batch):
image_features = self.image_encoder(batch["image"])
text_features = self.text_encoder(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
image_embeddings = self.image_projection(image_features)
text_embeddings = self.text_projection(text_features)
logits = (text_embeddings @ image_embeddings.T) / self.temperature
images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax((images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)
texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')
loss = (images_loss + texts_loss) / 2.0
return loss.mean()
其中交叉熵实现如下:
def cross_entropy(preds, targets, reduction='none'):
log_softmax = nn.LogSoftmax(dim=-1)
loss = (-targets * log_softmax(preds)).sum(1)
if reduction == "none":
return loss
elif reduction == "mean":
return loss.mean()
五、训练过程
训练过程包括:
-
train_epoch
:训练一轮 -
valid_epoch
:验证一轮 -
main()
:主训练入口
支持 AdamW
优化器与学习率调度器。
训练过程核心代码
def train_epoch(model, train_loader, optimizer):
model.train()
for batch in tqdm(train_loader):
batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
loss = model(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
def valid_epoch(model, valid_loader):
model.eval()
with torch.no_grad():
total_loss = 0
for batch in tqdm(valid_loader):
batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
loss = model(batch)
total_loss += loss.item()
return total_loss / len(valid_loader)
主函数如下:
def main():
train_df, valid_df = make_dfs()
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
train_loader = build_loaders(train_df, tokenizer, mode="train")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
model = CLIPModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
best_loss = float('inf')
for epoch in range(4):
train_epoch(model, train_loader, optimizer)
val_loss = valid_epoch(model, valid_loader)
print(f"Epoch {epoch + 1}, Val Loss: {val_loss:.4f}")
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), "best.pt")
print("✅ Best model saved")
六、推理与图像检索
推理阶段,我们输入一段文本,CLIP 模型输出其向量表示,再与验证集中所有图片的向量做点积,得到最相似的图像。
检索匹配图片核心代码
def get_image_embeddings(valid_df, model_path):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
model = CLIPModel().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
image_embeddings = []
with torch.no_grad():
for batch in tqdm(valid_loader):
image_features = model.image_encoder(batch["image"].to(device))
image_embeddings.append(model.image_projection(image_features))
return model, torch.cat(image_embeddings)
图文检索函数:
def find_matches(model, image_embeddings, query, image_filenames, n=9):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
encoded_query = tokenizer([query])
batch = {k: torch.tensor(v).to(device) for k, v in encoded_query.items()}
with torch.no_grad():
text_features = model.text_encoder(**batch)
text_embeddings = model.text_projection(text_features)
image_embeddings_n = F.normalize(image_embeddings, dim=-1)
text_embeddings_n = F.normalize(text_embeddings, dim=-1)
similarity = text_embeddings_n @ image_embeddings_n.T
_, indices = similarity[0].topk(n)
matches = [image_filenames[i] for i in indices]
plt.figure(figsize=(12, 8))
for i, fname in enumerate(matches):
plt.subplot(3, 3, i+1)
image = cv2.imread(f"{CFG.image_path}/{fname}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)
plt.axis("off")
plt.title(f"{i+1}")
plt.tight_layout()
plt.show()
测试图文匹配效果:
find_matches(
model,
image_embeddings,
query="a man riding a surfboard on a wave",
image_filenames=valid_df["image"].values,
n=9
)
七、结语
CLIP 模型开创性地提出了一种高效的图文对比学习框架,仅使用自然语言描述就可以让模型学到可迁移的视觉语义表示。相比传统图像分类模型,CLIP 的优点主要包括:
-
✅ 利用大规模网络语料进行图像监督训练;
-
✅ 同时支持图像→文本和文本→图像的相似度检索;
-
✅ 零样本迁移能力强,可适用于未见过的新任务或类别;
-
✅ 结构简单、部署灵活,可轻松搭建轻量版训练版本。
当然,CLIP 也并非完美,例如训练成本仍高,依赖高质量大规模图文对,但它所代表的“多模态对齐”范式已成为当前 AI 发展的关键路径。
📌 下一篇为第 77 篇 —— Flamingo模型原理与实现,敬请期待!
如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!
欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。
谢谢大家的支持!