逐步解释Swin-Transformer源码
![]()
这一行没什么好说的,就是设置参数值,并读取yaml里的模型参数。

是否采用混合精度训练,混合精度训练可以给合适的参数设置合适的位数,而不是全部为float32

分布式训练。

这一段就是什么梯度累加,学习率缩放等一些优化策略。跳到main函数中

数据预处理,logger写日志的模块,暂时不用管。
![]()
跳到构建模型中,详细讲解这一部分。

再跳到SwinTransformer中,再看模型的构建代码时,先整体看一下整体一个架构

这个图想必大家都看到过很多次了,首先输入图片(H*W*3),输入到一个Patch Partition中,这个模块SwinTrasfomer中与VIT相同但又略微不同,他将图片按照4*4划分为一个token,送到Linear Embedding中,进行维度变化,利用一个卷积,将通道数变为初始设置的C,源码中为96.
再送到Swintransformer Block中运算,需要注意的是,Block并不会改变输入图像的维度。每一个模块下面的*2,*2,*6,*2就是叠加了多少个Block,patch Merging就是将4*4变为8*8,再将通道数变为2C,以此类推。
Block中主要是一个W-MSA和一个SW-MSA,这就是Swintra

本文详细解析Swin-Transformer的源码,重点介绍了其核心的窗口注意力机制(W-MSA)和窗口间注意力机制(SW-MSA),以及相对位置编码的实现。通过将图像划分为小块并应用线性嵌入,再通过SwinTransformerBlock和PatchMerging进行多层次的特征提取。模型的构建包括多个Stage,每个Stage由多个Block组成,Block内部结合W-MSA和SW-MSA,实现高效且具有上下文关联的特征学习。此外,还探讨了模型训练过程中的优化策略和数据预处理步骤。
最低0.47元/天 解锁文章
1372

被折叠的 条评论
为什么被折叠?



