目录
前言
关于Swin Transformer的讲解和实战,实际上网络上已经有很多了。不过有一些代码跑起来可能有一些问题(有一些确实有点问题,或者没头没尾的)。
最初的时候,我通过一些调研,参照网上的一些教程,跑的时候也遇到了一些问题,但是最后确实是成功了。下面我就详细地来讲述应该怎么做。
关于Swin Transformer的基础知识就不再赘述了。相信想到用Swin Transformer来实战的同学肯定已经多多少少对其有一定了解了。
在此,我说一下我的实战的思路:
从官网拿到代码,然后改改,换成自己的数据集,加载它的预训练权重,然后让代码跑起来。
如果你的coding能力确实比较强,那么你完全可以从官网上找到部分Swin Transformer的Model部分核心代码,然后数据处理部分、跑模型的部分都自己来写,这样做也完全OK。但是对能力要求较高,并且对模型的理解要求也比较高。比如,参考某B站Up主的视频,它的代码就是这样子的(在别人的文章里应该都能看的到,我就不重复了)。
而我们在这里就介绍傻瓜式的操作。
我的环境:
- Win 10
- Python 3.8
- Pytorch/torchvision 1.13.1+cu116
- NVIDIA GeForce RTX 3060(CUDA 11.7.102)
- Pycharm Community;
OK,开始。
一、从官网拿到源码,然后配置自己缺少的环境。
论文: https://arxiv.org/abs/2103.14030
代码: https://github.com/microsoft/Swin-Transformer
注意下,这里的fused window process、还有apex等不安装也是可以跑通的。它们的安装不影响代码的运行。如果你最后对性能有很高的要求,那你再去下载,我们这里主要是学习,然后让它先跑起来。
你要最起码确保在data文件夹下、models文件夹下和最外层的所有py文件都没有依赖报错(就是导入包的报错)

就是如上图那些的一些包,你给它都下载好不报错就行了。或者你新搞个虚拟环境,然后重新装一下就行,怎么搞虚拟环境可以参考这篇文章【正在更新中...】。
针对可能遇到的错误:
注意,这个错误只是可能遇到,不是一定会遇到。它和你下载的Pytorch的版本有关系。并且你的分类数如果大于5,应该是不会报错的。
那我们需要做什么呢?就是你可能需要改一下你的accuracy函数。
说一下这个函数的入口在哪找,因为这个函数并不是Swin-transformer的函数,它是Pytorch内置的文件函数,所以它原本是只读的(只是有写保护,并非不可更改)。那么我们从哪里找呢?
找到main.py->函数validate,有一行
acc1, acc5 = accuracy(output, target, topk=(1, 5))

鼠标点击accuracy函数,然后按ctrl B就可以了(转到定义)
然后在metrics.py(这是个只读文件)里的这个函数accuracy,它应该是这个样子的(如下图):

但有些小伙伴该函数的第一行是这个样子的:
maxk = max(topk)
然后就导致如果你训练过程中,如果分类数比较少(比如二分类),那它就会报类似于“索引k超出范围”这样的错误。你把它这一行给改成上面截图的那个样子就行:
maxk = min(max(topk), output.size()[1]) //应该是这个样子的,否则二分类会报错
OK了家人们,现在我们环境配置就这样说完了。拿到项目和配置环境相信都是最基本的,没有什么好说的了哈。
二、数据集获取与处理
注意,我们是要用预训练权重去跑我们自己是数据集,所以不要傻乎乎的去下载ImgNet 1K,更不要傻乎乎地去下载ImgNet 22K哈哈哈,这些都是官方在一开始训练swin transformer的时候所用到的数据集,如果我们用预训练权重来去训练自己的训练集的话,是不需要下载这些东西的了。
我们这里,就以猫狗数据集为例来为大家说下数据集怎么整。
我们就以Kaggle猫狗大战数据集为例。

本文详细介绍了如何从官网获取SwinTransformer源码,配置环境,处理数据集,下载预训练权重,以及如何修改参数以适应自己的任务,包括解决可能出现的accuracy函数问题。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



