Corresponding paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
Let's go straight to the code.
First, the Transformer class:
from torch import nn

# Residual, PreNorm, Attention and FeedForward are helper modules defined
# elsewhere in this article.
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
        # ModuleList is a container that stores modules and automatically
        # registers each module's parameters with the network.
        # Unlike Sequential, it imposes no fixed order on its modules; the
        # execution order is defined in forward and can be changed at runtime.

    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)  # multi-head self-attention block
            x = ff(x)                 # feed-forward (MLP) block
        return x
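As a quick usage sketch (the sizes below are hypothetical values, not from the article, and it assumes Residual, PreNorm, Attention and FeedForward are already defined as in the rest of the walkthrough):

import torch

# Hypothetical sizes: batch 2, 65 tokens (64 patches + 1 class token), dim 128.
model = Transformer(dim = 128, depth = 6, heads = 8, dim_head = 64,
                    mlp_dim = 256, dropout = 0.1)
x = torch.randn(2, 65, 128)
out = model(x)   # same shape as the input: (2, 65, 128)

Each layer maps dim to dim, so the token sequence keeps its shape all the way through the stack.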
self.layers holds objects of several different classes; below, each is explained in the order it executes.
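Since every entry in self.layers is wrapped in Residual(PreNorm(...)), it helps to see these two thin wrappers first. This is a minimal sketch matching the common vit-pytorch-style implementation; the article's own definitions (not shown in this excerpt) may differ in detail:

from torch import nn

class Residual(nn.Module):
    # Skip connection: returns x + fn(x).
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    # Applies LayerNorm before the wrapped function: fn(LayerNorm(x)),
    # the "pre-norm" Transformer variant.
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

The **kwargs pass-through is what lets the Transformer's forward hand mask down to the Attention module through both wrappers.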
The Attention class
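The original listing is not included in this excerpt, so below is a minimal sketch of the Attention class in the same vit-pytorch style, assuming the signature the Transformer above expects (dim, heads, dim_head, dropout, and an optional mask); treat it as an assumption about the original rather than a verbatim copy:

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads    # total width across all heads
        self.heads = heads
        self.scale = dim_head ** -0.5   # 1/sqrt(d) scaling from the paper

        # A single linear layer produces q, k and v in one pass.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        h = self.heads
        # Project to q, k, v and split out the heads: (b, n, h*d) -> (b, h, n, d).
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # Scaled dot-product scores, shape (b, h, n, n).
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # mask: (b, num_patches) of bools; pad a True for the class token,
            # then keep only pairs where both tokens are valid.
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            mask = mask[:, None, :, None] * mask[:, None, None, :]  # (b, 1, n, n)
            dots = dots.masked_fill(~mask, float('-inf'))

        attn = dots.softmax(dim = -1)

        # Attention-weighted sum of values, then merge heads back: (b, n, h*d).
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

Every token attends to every other token, including the class token, which is how ViT mixes information across patches.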
