由于文章篇幅太长了,这里分成两篇文章“上”和“下”进行学习,这是“上”篇!
本文来源公众号“江大白”,仅用于学术分享,侵权删,干货满满。
原文链接:2W字长文,带你深入浅出视觉Transformer
以下文章来源于知乎:深度眸@知乎
作者:深度眸
链接:https://zhuanlan.zhihu.com/p/308301901
本文仅用于学术分享,如有侵权,请联系后台作删文处理
导读
Transformer整个网络结构完全由Attention机制组成,其出色的性能在多个任务上都取得了非常好的效果。本文从Transformer的结构出发,结合视觉中的成果进行了分析,能够帮助大家快速入门。
0 摘要
transformer结构是google在17年的Attention Is All You Need论文中提出,在NLP的多个任务上取得了非常好的效果,可以说目前NLP发展都离不开transformer。最大特点是抛弃了传统的CNN和RNN,整个网络结构完全是由Attention机制组成。由于其出色性能以及对下游任务的友好性或者说下游任务仅仅微调即可得到不错效果,在计算机视觉领域不断有人尝试将transformer引入,近期也出现了一些效果不错的尝试,典型的如目标检测领域的detr和可变形detr,分类领域的vision transformer等等。本文从transformer结构出发,结合视觉中的transformer成果(具体是vision transformer和detr)进行分析,希望能够帮助cv领域想了解transformer的初学者快速入门。由于本人接触transformer时间也不长,也算初学者,故如果有描述或者理解错误的地方欢迎指正。
本文的大部分图来自论文、国外博客和国内翻译博客,在此一并感谢前人工作,具体链接见参考资料。本文特别长,大概有3w字,请先点赞收藏然后慢慢看....
1 transformer介绍
一般讲解transformer都会以机器翻译任务为例子讲解,机器翻译任务是指将一种语言转换得到另一种语言,例如英语翻译为中文任务。从最上层来看,如下所示:
1.1 早期seq2seq
机器翻译是一个历史悠久的问题,本质可以理解为序列转序列问题,也就是我们常说的seq2seq结构,也可以称为encoder-decoder结构,如下所示:
encoder和decoder在早期一般是RNN模块(因为其可以捕获时序信息),后来引入了LSTM或者GRU模块,不管内部组件是啥,其核心思想都是通过Encoder编码成一个表示向量,即上下文编码向量,然后交给Decoder来进行解码,翻译成目标语言。一个采用典型RNN进行编码码翻译的可视化图如下:
可以看出,其解码过程是顺序进行,每次仅解码出一个单词。对于CV领域初学者来说,RNN模块构建的seq2seq算法,理解到这个程度就可以了,不需要深入探讨如何进行训练。但是上述结构其实有缺陷,具体来说是:
-
不论输入和输出的语句长度是什么,中间的上下文向量长度都是固定的,一旦长度过长,仅仅靠一个固定长度的上下文向量明显不合理
-
仅仅利用上下文向量解码,会有信息瓶颈,长度过长时候信息可能会丢失
通俗理解是编码器与解码器的连接点仅仅是编码单元输出的隐含向量,其包含的信息有限,对于一些复杂任务可能信息不够,如要翻译的句子较长时,一个上下文向量可能存不下那么多信息,就会造成翻译精度的下降。
1.2 基于attention的seq2seq
基于上述缺陷进而提出带有注意力机制Attention的seq2seq,同样可以应用于RNN、LSTM或者GRU模块中。注意力机制Attention对人类来说非常好理解,假设给定一张图片,我们会自动聚焦到一些关键信息位置,而不需要逐行扫描全图。此处的attention也是同一个意思,其本质是对输入的自适应加权,结合cv领域的senet中的se模块就能够理解了。
se模块最终是学习出一个1x1xc的向量,然后逐通道乘以原始输入,从而对特征图的每个通道进行加权即通道注意力,对attention进行抽象,不管啥领域其机制都可以归纳为下图:
将Query(通常是向量)和4个Key(和Q长度相同的向量)分别计算相似性,然后经过softmax得到q和4个key相似性的概率权重分布,然后对应权重乘以Value(和Q长度相同的向量),最后相加即可得到包含注意力的attention值输出,理解上应该不难。举个简单例子说明:
-
假设世界上所有小吃都可以被标签化,例如微辣、特辣、变态辣、微甜、有嚼劲....,总共有1000个标签,现在我想要吃的小吃是[微辣、微甜、有嚼劲],这三个单词就是我的Query
-
来到东门老街一共100家小吃点,每个店铺卖的东西不一样,但是肯定可以被标签化,例如第一家小吃被标签化后是[微辣、微咸],第二家小吃被标签化后是[特辣、微臭、特咸],第二家小吃被标签化后是[特辣、微甜、特咸、有嚼劲],其余店铺都可以被标签化,每个店铺的标签就是Keys,但是每家店铺由于卖的东西不一样,单品种类也不一样,所以被标签化后每一家的标签List不一样长
-
Values就是每家店铺对应的单品,例如第一家小吃的Values是[烤羊肉串、炒花生]
-
将Query和所有的Keys进行一一比对,相当于计算相似性,此时就可以知道我想买的小吃和每一家店铺的匹配情况,最后有了匹配列表,就可以去店铺里面买东西了(Values和相似性加权求和)。最终的情况可能是,我在第一家店铺买了烤羊肉串,然后在第10家店铺买了个玉米,最后在第15家店铺买了个烤面筋
以上就是完整的注意力机制,采用我心中的标准Query去和被标签化的所有店铺Keys一一比对,此时就可以得到我的Query在每个店铺中的匹配情况,最终去不同店铺买不同东西的过程就是权重和Values加权求和过程。简要代码如下:
# 假设q是(1,N,512),N就是最大标签化后的list长度,k是(1,M,512),M可以等于N,也可以不相等
# (1,N,512) x (1,512,M)-->(1,N,M)
attn = torch.matmul(q, k.transpose(2, 3))
# softmax转化为概率,输出(1,N,M),表示q中每个n和每个m的相关性
attn=F.softmax(attn, dim=-1)
# (1,N,M) x (1,M,512)-->(1,N,512),V和k的shape相同
output = torch.matmul(attn, v)
带有attention的RNN模块组成的ser2seq,解码时候可视化如下:
在没有attention时候,不同解码阶段都仅仅利用了同一个编码层的最后一个隐含输出,加入attention后可以通过在每个解码时间步输入的都是不同的上下文向量,以上图为例,解码阶段会将第一个开启解码标志<START>(也就是Q)与编码器的每一个时间步的隐含状态(一系列Key和Value)进行点乘计算相似性得到每一时间步的相似性分数,然后通过softmax转化为概率分布,然后将概率分布和对应位置向量进行加权求和得到新的上下文向量,最后输入解码器中进行解码输出,其详细解码可视化如下:
通过上述简单的attention引入,可以将机器翻译性能大幅提升,引入att