对应论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
直接看代码
首先看Transformer 类
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
#ModuleList是一个存储不同module,并自动将每个模块的参数添加到网络之中的容器
#与sequential的区别是,它的模块之间并没有先后顺序,运行时可以改
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
self.layers中有多个类定义的对象,按照执行顺序,逐一解释。
Attention类
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads#
self.heads = heads
self.scale = dim ** -0.5
#dim是线性变换后输出张量的最后维度
self.to_qkv