Swin Transformer实战图像分类(Windows下,无需用到Conda,亲测有效)

本文详细介绍了如何从官网获取SwinTransformer源码,配置环境,处理数据集,下载预训练权重,以及如何修改参数以适应自己的任务,包括解决可能出现的accuracy函数问题。

目录

前言

一、从官网拿到源码,然后配置自己缺少的环境。

针对可能遇到的错误:

二、数据集获取与处理

2.1 数据集下载

2.2 数据集处理

三、下载预训练权重

四、修改部分参数配置

4.1 修改config.py

4.2 修改build.py

4.3 修改units.py

4.4 修改main.py

4.5 修改其他的地方

4.6 将最后结果以折线图的形式呈现出来

五、其他可修改的地方

六、运行代码


前言

关于Swin Transformer的讲解和实战,实际上网络上已经有很多了。不过有一些代码跑起来可能有一些问题(有一些确实有点问题,或者没头没尾的)。

最初的时候,我通过一些调研,参照网上的一些教程,跑的时候也遇到了一些问题,但是最后确实是成功了。下面我就详细地来讲述应该怎么做。

关于Swin Transformer的基础知识就不再赘述了。相信想到用Swin Transformer来实战的同学肯定已经多多少少对其有一定了解了。

在此,我说一下我的实战的思路:

从官网拿到代码,然后改改,换成自己的数据集,加载它的预训练权重,然后让代码跑起来。

如果你的coding能力确实比较强,那么你完全可以从官网上找到部分Swin Transformer的Model部分核心代码,然后数据处理部分、跑模型的部分都自己来写,这样做也完全OK。但是对能力要求较高,并且对模型的理解要求也比较高。比如,参考某B站Up主的视频,它的代码就是这样子的(在别人的文章里应该都能看的到,我就不重复了)。

而我们在这里就介绍傻瓜式的操作。

我的环境:

  • Win 10
  • Python 3.8
  • Pytorch/torchvision 1.13.1+cu116
  • NVIDIA GeForce RTX 3060(CUDA 11.7.102)
  • Pycharm Community;

OK,开始。

一、从官网拿到源码,然后配置自己缺少的环境。

论文: https://arxiv.org/abs/2103.14030

代码: https://github.com/microsoft/Swin-Transformer

注意下,这里的fused window process、还有apex等不安装也是可以跑通的。它们的安装不影响代码的运行。如果你最后对性能有很高的要求,那你再去下载,我们这里主要是学习,然后让它先跑起来。

你要最起码确保在data文件夹下、models文件夹下和最外层的所有py文件都没有依赖报错(就是导入包的报错)

就是如上图那些的一些包,你给它都下载好不报错就行了。或者你新搞个虚拟环境,然后重新装一下就行,怎么搞虚拟环境可以参考这篇文章【正在更新中...】。

针对可能遇到的错误:

注意,这个错误只是可能遇到,不是一定会遇到。它和你下载的Pytorch的版本有关系。并且你的分类数如果大于5,应该是不会报错的。

那我们需要做什么呢?就是你可能需要改一下你的accuracy函数。

说一下这个函数的入口在哪找,因为这个函数并不是Swin-transformer的函数,它是Pytorch内置的文件函数,所以它原本是只读的(只是有写保护,并非不可更改)。那么我们从哪里找呢?

找到main.py->函数validate,有一行

acc1, acc5 = accuracy(output, target, topk=(1, 5))

鼠标点击accuracy函数,然后按ctrl B就可以了(转到定义)

然后在metrics.py(这是个只读文件)里的这个函数accuracy,它应该是这个样子的(如下图):

但有些小伙伴该函数的第一行是这个样子的:

maxk = max(topk)

然后就导致如果你训练过程中,如果分类数比较少(比如二分类),那它就会报类似于“索引k超出范围”这样的错误。你把它这一行给改成上面截图的那个样子就行:

maxk = min(max(topk), output.size()[1])  //应该是这个样子的,否则二分类会报错

OK了家人们,现在我们环境配置就这样说完了。拿到项目和配置环境相信都是最基本的,没有什么好说的了哈。

二、数据集获取与处理

注意,我们是要用预训练权重去跑我们自己是数据集,所以不要傻乎乎的去下载ImgNet 1K,更不要傻乎乎地去下载ImgNet 22K哈哈哈,这些都是官方在一开始训练swin transformer的时候所用到的数据集,如果我们用预训练权重来去训练自己的训练集的话,是不需要下载这些东西的了。

我们这里,就以猫狗数据集为例来为大家说下数据集怎么整。

我们就以Kaggle猫狗大战数据集为例。

2.1 数据集下载<

<think>好的,用户想用Swin Transformer实现图像超分辨率,需要复现代码,使用PyTorch和GitHub资源。首先,我应该回忆一下Swin Transformer的基本结构,它在图像分类中的应用,以及如何调整用于超分辨率任务。 用户提供的引用[1]和[2]是关于图像分类的,但超分辨率可能需要不同的处理方式。我需要先确认Swin Transformer在超分辨率中的常见应用方法。通常,超分辨率模型如SwinIR会使用Swin Transformer块来提取特征,然后结合上采样模块来增加分辨率。 接下来,复现步骤可能需要参考现有的GitHub项目,比如SwinIR。我需要检查是否有官方或社区实现的版本,并确保代码是基于PyTorch的。此外,数据准备方面,用户可能需要使用像DIV2K这样的数据集进行训练,并注意数据预处理步骤,比如将高分辨率图像裁剪成小块作为输入。 模型结构方面,可能需要将Swin Transformer作为主干网络,后面接像素洗牌等上采样方法。损失函数通常使用L1或L2损失,以及可能的感知损失或对抗损失来提升视觉效果。训练时,学习率调整和优化器的选择也很重要,比如使用Adam优化器,并设置适当的学习率衰减策略。 用户可能遇到的问题包括如何修改原有的Swin Transformer分类代码以适应超分辨率任务,如何处理输入输出尺寸的变化,以及如何集成上采样层。因此,在回答时需要提供具体的代码示例,比如定义Swin Transformer层,然后添加上采样模块。 此外,引用现有的资源和代码库是关键,比如推荐用户查阅SwinIR的官方实现,或者社区中优秀的复现项目。同时,数据集的准备和训练技巧也需要详细说明,帮助用户避免常见的陷阱,比如内存不足时调整批大小或使用分布式训练。 最后,需要确保回答结构清晰,分步骤指导用户从环境配置到训练试的整个流程,并附上相关参考文献和代码链接,确保用户能够顺利复现。</think>### 使用Swin Transformer实现图像超分辨率的方法与复现步骤 #### 1. **核心思路** Swin Transformer在图像超分辨率(SR)中主要通过以下方式应用: - **分窗口自注意力**:将图像划分为局部窗口,在窗口内计算自注意力,降低计算复杂度[^1]。 - **层级特征提取**:通过多阶段下采样和上采样模块实现多尺度特征融合。 - **残差连接**:结合残差学习避免梯度消失问题。 #### 2. **代码复现步骤** ##### (1) 环境配置 ```bash # 创建虚拟环境 conda create -n swin_sr python=3.8 conda activate swin_sr # 安装PyTorch和依赖 pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 pip install timm opencv-python ``` ##### (2) 模型结构示例(基于SwinIR实现) ```python import torch import torch.nn as nn from swin_transformer import SwinTransformerBlock class SwinSR(nn.Module): def __init__(self, upscale=4): super().__init__() # 浅层特征提取 self.conv_head = nn.Conv2d(3, 64, 3, padding=1) # Swin Transformer模块 self.swin_blocks = nn.Sequential( *[SwinTransformerBlock(dim=64, num_heads=4, window_size=8) for _ in range(6)] ) # 上采样模块 self.upsample = nn.Sequential( nn.Conv2d(64, 64*(upscale**2), 3, padding=1), nn.PixelShuffle(upscale), nn.Conv2d(64, 3, 3, padding=1) ) def forward(self, x): x = self.conv_head(x) residual = x x = self.swin_blocks(x) x = residual + x # 残差连接 return self.upsample(x) ``` ##### (3) 关键组件说明 - **SwinTransformerBlock**:核心注意力模块,需实现窗口划分与移位窗口机制 - **PixelShuffle**:亚像素卷积上采样,将通道数转换为空间分辨率 - **残差连接**:保留低频信息,加速模型收敛 #### 3. **数据集准备** 推荐使用标准SR数据集: ```python from torch.utils.data import Dataset class SRDataset(Dataset): def __init__(self, hr_paths, scale=4): self.hr_imgs = [load_image(p) for p in hr_paths] self.scale = scale def __getitem__(self, idx): hr = self.hr_imgs[idx] lr = F.interpolate(hr, scale_factor=1/self.scale, mode='bicubic') return lr, hr # 返回低分辨率输入和高分辨率标签 ``` #### 4. **训练配置** ```python model = SwinSR(upscale=4) criterion = nn.L1Loss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) for epoch in range(100): for lr, hr in dataloader: pred = model(lr) loss = criterion(pred, hr) optimizer.zero_grad() loss.backward() optimizer.step() ``` #### 5. **预训练模型与代码资源** 推荐参考项目: 1. **SwinIR官方实现**:https://github.com/JingyunLiang/SwinIR 2. **EDSR基准改进**:https://github.com/sanghyun-son/EDSR-PyTorch 3. **BasicSR框架集成**:https://github.com/xinntao/BasicSR #### 6. **常见问题解决** - **显存不足**:减小`window_size`或`batch_size` - **边缘伪影**:采用反射填充(reflection padding) - **纹理模糊**:添加感知损失(VGG特征匹配) - **训练不稳定**:使用梯度裁剪(gradient clipping) ###
评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

jxwd

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值