Vision Transformer(ViT)PyTorch代码全解析
最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文,及其PyTorch实现,将整个ViT的代码做一个全面的解析。
对原Transformer还不熟悉的读者可以看一下Attention is All You Need原文,中文讲解推荐李宏毅老师的视频 YouTube,BiliBili 个人觉得讲的很明白。
话不多说,直接开始。
下图是ViT的整体框架图,我们在解析代码时会参照此图:
以下是文中给出的符号公式,也是我们解析的重要参照:
z = [ x c l a s s ; x p 1 E , x p 2 E , … ; x p N E ] + E p o s , E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D ( 1 ) \mathbf{z}=[\mathbf{x}_{class};\mathbf{x}^1_p\mathbf{E},\mathbf{x}^2_p\mathbf{E},\dots;\mathbf{x}^N_p\mathbf{E}]+\mathbf{E}_{pos},\ \ \ \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\mathbf{E}_{pos}\in \mathbb{R}^{(N+1)\times D} \ \ \ \ \ \ \ \ \ \ \ \ \ (1)
z=[x
class
;x
p
1
E,x
p
2
E,…;x
p
N
E]+E
pos
, E∈R
(P
2
⋅C)×D
,E
pos
∈R
(N+1)×D
(1)
z ℓ ′ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 ( 2 ) \mathbf{z'_\ell}=MSA(LN(\mathbf{z}_{\ell-1}))+\mathbf{z}_{\ell-1}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (2)
z
ℓ
′
=MSA(LN(z
ℓ−1
))+z
ℓ−1
(2)
z ℓ = M L P ( L N ( z ′ ℓ ) ) + z ′ ℓ ( 3 ) \mathbf{z}_{\ell}=MLP(LN(\mathbf{z'}_{\ell}))+\mathbf{z'}_{\ell}\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (3)
z
ℓ
=MLP(LN(z
′
ℓ
))+z
′
ℓ
(3)
y = L N ( z L 0 ) ( 4 ) \mathbf{y}=LN(\mathbf{z}_L^0)\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ (4)
y=LN(z
L
0
) (4)
导入需要的包
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
1
2
3
4
5
6
都是搭建网络时常用的PyTorch包,其中在卷积神经网络的搭建中并不常用的einops和einsum,还不熟悉的读者可以参考博客:einops和einsum:直接操作张量的利器。
pair函数
def pair(t):
return t if isinstance(t, tuple) else (t, t)
1
2
3
作用是:判断t是否是元组,如果是,直接返回t;如果不是,则将t复制为元组(t, t)再返回。
用来处理