基于 Python 的自然语言处理系列文章 (75):BEiT (图像片段的揭码器预训练)

原始论文:BEIT: BERT Pre-Training of Image Transformers

GitHub:https://github.com/microsoft/unilm/tree/master/beit


一、BEiT 简介

        BEiT(BERT Pre-Training of Image Transformers)是一种借鉴 BERT 预训练范式的视觉模型,它将图像切分为若干 Patch,并在这些 Patch 上进行遮盖和预测任务(类似于 MLM)以进行表征学习。BEiT 主要用于图像分类、分割等下游任务,在 ImageNet 等多个数据集上表现优异。

        BEiT 的核心是使用 ViT 架构,将图像转化为 Patch 的 Token 序列,再使用 BERT 风格的 Masked Patch Prediction 进行预训练。

        架构如下图所示:


二、环境准备与数据加载

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Grayscale
from torch.utils.data import DataLoader, random_split

# 设置设备
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print("Using device", device)

# 加载 MNIST,并转换为 3 通道图像
transform = Grayscale(num_output_channels=3)
full_dataset = MNIST(root='./data/', train=True, download=True, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = random_split(full_dataset, [train_size, val_size])
test_set = MNIST(root='./data/', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)
test_loader = DataLoader(test_set, batch_size=16)

三、BEiT 模型构建

        完整模型代码较长,核心结构包括:

  • Patch Embedding:使用卷积或线性层将图像划分为 Patch

  • Transformer Encoder:堆叠多个 Block,每个 Block 包含 Attention 和 Feedforward

  • Classification Head:最终取出 [CLS] token 送入线性层进行分类

Patch Embedding 模块

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, N, D]
        return x

BEiT 模型主体结构

class BEiT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=10,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=0.1)
        self.blocks = nn.Sequential(*[
            Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return self.head(x[:, 0])

四、训练与评估流程

import time
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

训练与验证函数

def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss, correct = 0, 0
    for x, y in tqdm(dataloader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        preds = model(x)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (preds.argmax(1) == y).sum().item()
    return total_loss / len(dataloader), correct / len(dataloader.dataset)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            x, y = x.to(device), y.to(device)
            preds = model(x)
            loss = criterion(preds, y)
            total_loss += loss.item()
            correct += (preds.argmax(1) == y).sum().item()
    return total_loss / len(dataloader), correct / len(dataloader.dataset)

训练主循环

best_valid_loss = float('inf')
train_losses, valid_losses = [], []
num_epochs = 5

for epoch in range(num_epochs):
    start_time = time.time()

    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model, val_loader, criterion, device)

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    end_time = time.time()
    mins, secs = divmod(int(end_time - start_time), 60)

    print(f"Epoch {epoch+1:02} | Time: {mins}m {secs}s")
    print(f"  Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Valid Loss: {valid_loss:.3f} | Valid Acc: {valid_acc*100:.2f}%")

五、可视化训练结果

import matplotlib.pyplot as plt
plt.plot(train_losses, label='Train Loss')
plt.plot(valid_losses, label='Valid Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

六、最终测试评估

test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%")

七、对比 BERT 与 BEiT

模型领域输入单位训练目标
BERTNLPToken(词片段)Masked Language Modeling
BEiTCVPatch(图像块)Masked Image Modeling

        BEiT 本质上是把 NLP 的 BERT 迁移到了视觉世界,借助掩码学习、上下文对齐等机制,极大提升了图像编码器在预训练阶段的泛化能力。


八、小结

        本篇文章基于 PyTorch 实现了 BEiT 模型在 MNIST 数据集上的训练与测试,并详细分析了其结构、训练流程及和 BERT 的异同。

        你可以基于该实现,扩展到 CIFAR10、ImageNet 等更复杂的数据集,进一步发挥 BEiT 的能力。

        敬请期待下一篇:《基于 Python 的自然语言处理系列(76):CLIP》。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会飞的Anthony

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值