ViT结构详解(附pytorch代码)

ViT模型详解
本文详细介绍Vision Transformer (ViT) 模型的构建过程,包括图像分割、Patch Embedding、多头注意力机制、残差连接及最终分类头的设计原理。

参考这篇文章,本文会加一些注解。

源自paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

ViT把tranformer用在了图像上, transformer的文章: Attention is all you need

ViT的结构如下:
在这里插入图片描述
可以看到是把图像分割成小块,像NLP的句子那样按顺序进入transformer,经过MLP后,输出类别。
每个小块是16x16,进入Linear Projection of Flattened Patches, 在每个的开头加上cls token位置信息,
也就是position embedding。

从下而上实现,position embedding, Transformer, Head, Vit的顺序。
首先import

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

image输入要是224x224x3, 所以先reshape一下

# resize to imagenet size 
transform = Compose([Resize((224, 224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0) # add batch dim
x.shape

这是shape是[1, 3, 224, 224]

把图片分成小块
在这里插入图片描述

patch_size = 16 # 16 pixels
pathes = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)

rearrange里面的(h s1)表示hxs1,而s1是patch_size=16, 那通过hx16=224可以算出height里面包含了h个patch_size,
同理算出weight里面包含了w个patch_size。
然后输出是b (h w) (s1 s2 c),这相当于把每个patch(16x16x3)拉成一个向量,每个batch里面有hxw个这样的向量。
就相当于上图一字排开有hxw个小块。

然后把这些小块放进Linear layer改变每条向量的维度。
在这里插入图片描述
上面这些可以写成一个class,用conv2代替linear layer提高计算效率,把拉成的一条向量维度变成e

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # break-down the image in s1 x s2 patches and flat them
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )
                
    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
PatchEmbedding()(x).shape

torch.Size([1, 196, 768])

CLS token

要在刚刚的patch向量中加入cls token和每个patch所在的位置信息,也就是position embedding。
cls token就是每个sequence开头的一个数字。
一张图片的一串patch是一个sequence, 所以cls token就加在它们前面,embedding_size的向量copy batch_size次。

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        self.patch_size = patch_size
        super().__init__()
        self.proj = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        
        self.cls_token = nn.Parameter(torch.ran
评论 9
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

蓝羽飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值