Vision Transformer (ViT): A Code Walkthrough

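This post explains how Vision Transformer (ViT) works, covering the Transformer structure, multi-head attention, and the FeedForward network. It walks through a concise, readable PyTorch reimplementation and gives a usage example. ViT splits an image into patches, processes them with a Transformer encoder, and makes its prediction through a classification head. The post is aimed at beginners who want to understand the ViT model.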


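Official code: https://github.com/google-research/vision_transformer
A popular reimplementation: https://github.com/lucidrains/vit-pytorch
Sorry, I'm not good enough to follow the official code, so I read the second version instead. It is very popular and easy to use; when I read it, the Git repo already had 5.7k stars. You can simply run pip install vit-pytorch.
So if this is your first contact with ViT, I recommend reading the second version: its structure is clear and easy to follow.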

1. Usage example from the reimplementation

import torch
from vit_pytorch import ViT

v = ViT(
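    image_size = 256,    # input image size
    patch_size = 32,     # patch size (side length of each block)
    num_classes = 1000,  # 1000 classes, as in ImageNet
    dim = 1024,          # embedding dimension of each token (patch and position embeddings)
    depth = 6,           # number of Transformer encoder blocks (ViT has no decoder)
    heads = 16,          # number of heads in multi-head attention
    mlp_dim = 2048,      # hidden dimension of the FeedForward (MLP) layer
    dropout = 0.1,       # dropout rate inside the Transformer
    emb_dropout = 0.1    # dropout rate applied to the embeddings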
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

You can copy and paste this code into PyCharm and step through it with the debugger to watch each operation ViT performs.
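Before stepping through the model, it can help to work out the sizes this configuration implies. The following is a minimal sketch in plain Python (no model needed); the variable names are my own, and `dim_head = 64` is assumed from the default in the `ViT` constructor shown later in this post.

```python
# Shape bookkeeping for the example configuration (illustrative names).
image_size, patch_size, channels = 256, 32, 3
dim, heads = 1024, 16
dim_head = 64  # assumption: the default dim_head of the ViT constructor

patches_per_side = image_size // patch_size      # patches along one side
num_patches = patches_per_side ** 2              # total number of patches
patch_dim = channels * patch_size * patch_size   # values in one flattened patch
seq_len = num_patches + 1                        # sequence length including the cls token
inner_dim = heads * dim_head                     # width of q, k and v across all heads

print(num_patches, patch_dim, seq_len, inner_dim)  # 64 3072 65 1024
```

These are the numbers that show up in the shape comments throughout the rest of the walkthrough.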

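2. Transformer structure

The for loop runs `depth` times (6 in the example above), giving 6 encoder layers. Inside the loop, multi-head attention and FeedForward are applied in turn, corresponding to the multi-head self-attention module and the MLP module of the Transformer Encoder. A residual connection follows both the self-attention and the feed forward.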

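# imports needed by the code in sections 2 to 5
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class Transformer(nn.Module):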
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

The PreNorm class is shown below. Before multi-head attention or FeedForward is applied, the input is first passed through LayerNorm.

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

For reference, see the Transformer Encoder diagram in the ViT paper.
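To make the pre-norm residual pattern concrete, here is a small sketch. `PreNormResidual` is a hypothetical helper of my own (not part of vit-pytorch) that folds `PreNorm` and the `+ x` from `Transformer.forward` into one module. With a sub-layer that outputs zeros, the block returns its input unchanged, which shows that the skip path bypasses LayerNorm.

```python
import torch
from torch import nn

class PreNormResidual(nn.Module):
    """Hypothetical helper: y = x + fn(LayerNorm(x))."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        # Only the sub-layer input is normalized; the residual uses the raw x.
        return x + self.fn(self.norm(x))

torch.manual_seed(0)
dim = 8
zero_fn = nn.Linear(dim, dim)
nn.init.zeros_(zero_fn.weight)   # make the sub-layer output exactly zero
nn.init.zeros_(zero_fn.bias)

block = PreNormResidual(dim, zero_fn)
x = torch.randn(2, 5, dim)
y = block(x)                     # same shape as x; equals x because fn outputs zeros
```

In the actual code the sub-layer is `Attention` or `FeedForward`, so the output is the input plus a learned update of the same shape.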

3. Attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)    # first generate q, k and v
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

What torch.chunk(tensor, chunk_num, dim) does: it is the opposite of torch.cat(). It splits the tensor along dimension dim into chunk_num tensor chunks and returns them as a tuple.
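A tiny sketch of that relationship, using a made-up 2 x 6 tensor: chunking along the last dimension and concatenating the pieces again recovers the original.

```python
import torch

x = torch.arange(12).reshape(2, 6)

# Split the last dimension into 3 equal chunks, each of shape (2, 2).
parts = torch.chunk(x, 3, dim=-1)

# Concatenating the chunks along the same dimension undoes the split.
restored = torch.cat(parts, dim=-1)
```

In `Attention.forward`, the same call splits the output of `to_qkv` into the q, k and v tensors.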
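Overall flow of the attention operation:

  1. First generate query, key and value from the input. The "input" may be the input of the whole network or the output of a hidden layer. Here, qkv is a tuple of length 3, and each element has shape (1, 65, 1024).
  2. Rearrange the dimensions of qkv to split out the heads; q, k and v each have shape (1, 16, 65, 64).
  3. Multiply q by the transpose of k (matrix multiplication) and scale by dim_head ** -0.5; the resulting dots has shape (1, 16, 65, 65).
  4. Apply softmax over the last dimension of dots to get each token's attention weights over all tokens (the 64 patches plus the cls token).
  5. Multiply the attention weights by value (matrix multiplication), giving shape (1, 16, 65, 64).
  6. Rearrange the dimensions to merge the heads and apply the output projection, giving an output with the same shape as the input: (1, 65, 1024).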
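The steps above can be traced with plain tensor operations. This is a minimal sketch under the example configuration (65 tokens, dim 1024, 16 heads of size 64); it replaces the `rearrange` calls with `reshape`/`transpose`, which I believe is equivalent for these patterns, and omits dropout.

```python
import torch
from torch import nn

torch.manual_seed(0)
b, n, dim, heads, dim_head = 1, 65, 1024, 16, 64
inner_dim = heads * dim_head

x = torch.randn(b, n, dim)
to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
to_out = nn.Linear(inner_dim, dim)

# Step 1: one linear layer produces q, k and v, each (1, 65, 1024).
qkv = to_qkv(x).chunk(3, dim=-1)

# Step 2: 'b n (h d) -> b h n d', i.e. split out the heads: (1, 16, 65, 64).
q, k, v = [t.reshape(b, n, heads, dim_head).transpose(1, 2) for t in qkv]

# Step 3: scaled dot products between every pair of tokens: (1, 16, 65, 65).
dots = (q @ k.transpose(-1, -2)) * (dim_head ** -0.5)

# Step 4: softmax over the last dimension, so each row sums to 1.
attn = dots.softmax(dim=-1)

# Step 5: weighted sum of the values: (1, 16, 65, 64).
out = attn @ v

# Step 6: 'b h n d -> b n (h d)', merge the heads, then project: (1, 65, 1024).
out = to_out(out.transpose(1, 2).reshape(b, n, inner_dim))
```

Because `inner_dim` equals `dim` in this configuration, the shapes before and after the output projection happen to match; in general the projection maps `inner_dim` back to `dim`.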

4. FeedForward

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),   # dim=1024, hidden_dim=2048
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

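The FeedForward module has 2 fully connected layers. The overall structure is:

  1. First, a fully connected layer
  2. The GELU() activation function
  3. nn.Dropout(), which randomly zeroes some activations with a given probability to prevent overfitting
  4. A second fully connected layer
  5. nn.Dropout()

Note: GELU(x) = x * Φ(x), where Φ(x) is the cumulative distribution function of the standard Gaussian distribution.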
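The GELU formula in the note can be checked numerically. The sketch below writes Φ(x) with the error function, Φ(x) = 0.5 * (1 + erf(x / √2)), and compares the result against PyTorch's built-in exact GELU; `gelu_exact` is an illustrative name of my own.

```python
import math
import torch
import torch.nn.functional as F

def gelu_exact(x):
    # Phi(x): CDF of the standard normal distribution, via the error function.
    phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * phi

x = torch.linspace(-3.0, 3.0, steps=13)

# Largest absolute difference from PyTorch's GELU (default, non-approximate mode).
max_diff = (gelu_exact(x) - F.gelu(x)).abs().max().item()
```

`max_diff` should be at the level of floating-point rounding error.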

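5. ViT workflow

All of ViT's components are defined in __init__(), so I won't go through them one by one. Instead, we follow forward() to see ViT's whole forward pass (its workflow).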

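# helper used in __init__: turn a single int into an (height, width) tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class ViT(nn.Module):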
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # (1,65,1024)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):   # img: (1, 3, 256, 256)
        x = self.to_patch_embedding(img)     # (1, 64, 1024)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)    # (1, 1, 1024)
        x = torch.cat((cls_tokens, x), dim=1)  # (1, 65, 1024)
        x += self.pos_embedding[:, :(n + 1)]   # (1, 65, 1024)
        x = self.dropout(x)                    # (1, 65, 1024)
        x = self.transformer(x)                # (1, 65, 1024)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]      # (1, 1024)
        x = self.to_latent(x)
        return self.mlp_head(x)

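Overall flow:

  1. Split the input img (256*256) into patches of size 32*32, giving 8*8 patches, and convert each patch into an embedding. (The `self.to_patch_embedding(img)` line in forward().)
  2. Generate cls_tokens. (The `repeat(self.cls_token, ...)` line.)
  3. Concatenate cls_tokens with x along dim=1. (The `torch.cat` line.)
  4. Add the position embedding, a learnable parameter that is randomly initialized; each embedding is 1024-dimensional. (`self.pos_embedding` is defined in __init__() and added to x in forward().)
  5. Encode the input with the Transformer. (The `self.transformer(x)` line.)
  6. With pool = 'cls' (the default), take the first token, the learnable class embedding; with pool = 'mean', average over all tokens instead.
  7. Finally, pass the result through an MLP Head for classification.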
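Steps 1 to 4 and step 6 can be sketched with plain tensor operations. This is an illustration only: the einops pattern 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' is rewritten with `reshape`/`permute` (my own translation of the pattern), the parameters are freshly initialized stand-ins rather than the model's trained ones, and the Transformer itself is skipped.

```python
import torch
from torch import nn

torch.manual_seed(0)
b, c, size, p, dim = 1, 3, 256, 32, 1024
h = w = size // p                                   # 8 patches per side

img = torch.randn(b, c, size, size)

# Step 1a: cut the image into an 8 x 8 grid of flattened patches: (1, 64, 3072).
patches = (img.reshape(b, c, h, p, w, p)
              .permute(0, 2, 4, 3, 5, 1)            # -> (b, h, w, p1, p2, c)
              .reshape(b, h * w, p * p * c))

# Stand-ins for the model's learnable parameters (illustrative only).
to_embedding = nn.Linear(p * p * c, dim)
cls_token = nn.Parameter(torch.randn(1, 1, dim))
pos_embedding = nn.Parameter(torch.randn(1, h * w + 1, dim))

x = to_embedding(patches)                                # step 1b: (1, 64, 1024)
x = torch.cat((cls_token.expand(b, -1, -1), x), dim=1)   # steps 2-3: (1, 65, 1024)
x = x + pos_embedding                                    # step 4: (1, 65, 1024)

# Step 6: the two pooling options, each giving (1, 1024).
cls_out = x[:, 0]          # pool == 'cls'
mean_out = x.mean(dim=1)   # pool == 'mean'
```

Patch 0 corresponds to the top-left 32 x 32 region of the image, and patches are ordered row by row.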
### About the Vision Transformer (ViT) code implementation

When exploring how Vision Transformer (ViT) is implemented, a simplified PyTorch implementation is a good way to understand how it works[^2]. A simplified ViT model structure is shown below:

```python
import torch
from torch import nn, optim

class PatchEmbedding(nn.Module):
    """Split the image into patches and embed them."""
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class Block(nn.Module):
    """A single Transformer encoder layer."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop_rate, batch_first=True)
        # ... remaining initialization omitted ...

    def forward(self, x):
        # Self-attention: query, key and value all use the normalized input.
        h = self.norm1(x)
        attn_output, _ = self.attn(query=h, key=h, value=h)
        x = x + attn_output
        # remaining forward logic ...
        return x

class VisionTransformer(nn.Module):
    """The complete Vision Transformer model."""
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., class_token=True):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        if class_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, drop_rate=drop_rate
            ) for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.blocks(x)
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        return x
```

This code shows how the `PatchEmbedding` class converts an input image into a sequence of patches, how the `Block` class represents the core component of each Transformer encoder unit, and how `VisionTransformer` ties the whole architecture together.

So that the class token takes part in training, a learnable parameter `cls_token` is created at instantiation and concatenated with the patch embeddings before being passed to the standard Transformer Encoder[^3].