Vision Transformer(ViT)PyTorch代码全解析

本文详细解析了基于原始VisionTransformer论文和PyTorch的实现,介绍了ViT的基本结构、关键组件(如注意力机制、前馈网络等)、以及如何在代码中构建Transformer模型。适合对Transformer不熟悉的读者入门学习。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Vision Transformer(ViT)PyTorch代码全解析
最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文,及其PyTorch实现,将整个ViT的代码做一个全面的解析。

对原Transformer还不熟悉的读者可以看一下Attention is All You Need原文,中文讲解推荐李宏毅老师的视频 YouTube,BiliBili 个人觉得讲的很明白。

话不多说,直接开始。

下图是ViT的整体框架图,我们在解析代码时会参照此图:


以下是文中给出的符号公式,也是我们解析的重要参照:
z = [ x c l a s s ; x p 1 E , x p 2 E , …   ; x p N E ] + E p o s ,     E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D               ( 1 ) \mathbf{z}=[\mathbf{x}_{class};\mathbf{x}^1_p\mathbf{E},\mathbf{x}^2_p\mathbf{E},\dots;\mathbf{x}^N_p\mathbf{E}]+\mathbf{E}_{pos},\ \ \ \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\mathbf{E}_{pos}\in \mathbb{R}^{(N+1)\times D} \ \ \ \ \ \ \ \ \ \ \ \ \ (1)
z=[x 
class

 ;x 
p
1

 E,x 
p
2

 E,…;x 
p
N

 E]+E 
pos

 ,   E∈R 
(P 
2
 ⋅C)×D
 ,E 
pos

 ∈R 
(N+1)×D
              (1)

z ℓ ′ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1                                        ( 2 ) \mathbf{z'_\ell}=MSA(LN(\mathbf{z}_{\ell-1}))+\mathbf{z}_{\ell-1}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2)




 =MSA(LN(z 
ℓ−1

 ))+z 
ℓ−1

                                       (2)

z ℓ = M L P ( L N ( z ′ ℓ ) ) + z ′ ℓ                                   ( 3 ) \mathbf{z}_{\ell}=MLP(LN(\mathbf{z'}_{\ell}))+\mathbf{z'}_{\ell}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (3)



 =MLP(LN(z 

  


 ))+z 

  


                                  (3)

y = L N ( z L 0 )                                    ( 4 ) \mathbf{y}=LN(\mathbf{z}_L^0)\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (4)
y=LN(z 
L
0

 )                                  (4)

导入需要的包
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
1
2
3
4
5
6
都是搭建网络时常用的PyTorch包,其中在卷积神经网络的搭建中并不常用的einops和einsum,还不熟悉的读者可以参考博客:einops和einsum:直接操作张量的利器。

pair函数
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

1
2
3
作用是:判断t是否是元组,如果是,直接返回t;如果不是,则将t复制为元组(t, t)再返回。
用来处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI周红伟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值