Reproducing Papers with BasicSR: A Hands-On Guide to the SwinIR Super-Resolution Model
Are you tired of blurry, low-resolution pictures? Would you like old photos to regain their fine detail? This article walks you through implementing the SwinIR super-resolution model with the BasicSR toolbox from scratch, so you can turn ordinary images into high-definition versions. By the end, you will understand SwinIR's core principles, how to implement the model, and how to train and test it, and you will be ready to put super-resolution into practice.
Overview of the SwinIR Model
SwinIR (Swin Transformer for Image Restoration) is an image restoration model built on the Swin Transformer that performs strongly on super-resolution, denoising, and deblurring. Compared with conventional convolutional neural networks (CNNs), SwinIR uses window-based self-attention, which captures long-range dependencies in the image more effectively and therefore produces sharper, more natural high-resolution results.
SwinIR's core innovations include:
- Window attention: the image is split into non-overlapping windows and attention is computed within each window, keeping the computational cost low (the window splitting/merging helpers are sketched right after this list)
- Shifted windows: shifting the window grid between consecutive blocks lets information flow across window boundaries
- Residual Swin Transformer Block (RSTB): residual connections improve training stability and performance
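Before diving into the model itself, it helps to look at the two small helper functions that perform the window splitting and merging. They are defined in basicsr/archs/swinir_arch.py and are used by the code later in this article; a minimal sketch (assuming torch is imported):
import torch

def window_partition(x, window_size):
    """Split a (b, h, w, c) feature map into (num_windows*b, window_size, window_size, c) windows."""
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
    return windows

def window_reverse(windows, window_size, h, w):
    """Merge windows back into a (b, h, w, c) feature map."""
    b = int(windows.shape[0] / (h * w / window_size / window_size))
    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x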
(Figure: SwinIR model architecture)
Environment Setup and Project Structure
Before implementing SwinIR, we need to set up the development environment and get familiar with the BasicSR project layout.
Setting Up the Environment
First, clone the BasicSR repository:
git clone https://gitcode.com/gh_mirrors/ba/BasicSR.git
cd BasicSR
Then install the required dependencies:
pip install -r requirements.txt
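If you want to import basicsr from outside the repository root (as the inference example near the end of this article does), it is also worth installing the package in editable/development mode. Assuming the standard setuptools layout that upstream BasicSR ships:
pip install -e .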
Project Structure
The main directories of BasicSR are:
- basicsr/archs/: model architectures, including SwinIR
- basicsr/models/: training and testing logic for each model
- options/: YAML configuration files for training and testing
- inference/: inference scripts
- scripts/: helper scripts for data preparation, model conversion, and so on
The core files related to SwinIR are:
- basicsr/archs/swinir_arch.py: the SwinIR network architecture
- basicsr/models/swinir_model.py: the SwinIR training/testing model class
- options/train/SwinIR/train_SwinIR_SRx4_scratch.yml: the SwinIR training configuration
- inference/inference_swinir.py: the SwinIR inference script
SwinIR Core Code Implementation
Window Attention
Window attention is the core component of SwinIR: the feature map is split into fixed-size windows and attention is computed inside each window. The implementation looks as follows:
class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # relative position bias table
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index', relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        b_, n, c = x.shape
        qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (b_, num_heads, n, head_dim)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (b_, num_heads, n, n)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nw = mask.shape[0]
            attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b_, n, c)  # (b_, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
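A quick shape check helps confirm the attention module behaves as expected. The snippet below is purely illustrative and assumes the class above (together with torch and the trunc_normal_ helper it uses) is already defined in the current session:
attn = WindowAttention(dim=96, window_size=(8, 8), num_heads=6)
x = torch.randn(4, 8 * 8, 96)  # 4 windows, 64 tokens per window, 96 channels
out = attn(x)
print(out.shape)  # torch.Size([4, 64, 96])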
Swin Transformer Block
The Swin Transformer block is SwinIR's basic building unit. It combines window attention with an MLP and wraps both in residual connections:
class SwinTransformerBlock(nn.Module):

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # if the window size is larger than the input resolution, do not partition windows
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # pre-compute the attention mask for SW-MSA when shift_size > 0
        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None
        self.register_buffer('attn_mask', attn_mask)

    def calculate_mask(self, x_size):
        # calculate the attention mask for SW-MSA
        h, w = x_size
        img_mask = torch.zeros((1, h, w, 1))  # 1 h w 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nw, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, x_size):
        h, w = x_size
        b, _, c = x.shape
        shortcut = x

        x = self.norm1(x)
        x = x.view(b, h, w, c)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nw*b, window_size, window_size, c
        x_windows = x_windows.view(-1, self.window_size * self.window_size, c)  # nw*b, window_size*window_size, c

        # W-MSA/SW-MSA
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nw*b, window_size*window_size, c
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
        shifted_x = window_reverse(attn_windows, self.window_size, h, w)  # b h w c

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(b, h * w, c)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
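The block can be sanity-checked the same way. The example below is a sketch that assumes the remaining helpers from basicsr/archs/swinir_arch.py (Mlp, DropPath, to_2tuple, window_partition, window_reverse) are in scope:
block = SwinTransformerBlock(dim=96, input_resolution=(64, 64), num_heads=6,
                             window_size=8, shift_size=4)
x = torch.randn(1, 64 * 64, 96)  # (batch, h*w, channels)
out = block(x, x_size=(64, 64))
print(out.shape)  # torch.Size([1, 4096, 96])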
Residual Swin Transformer Block (RSTB)
An RSTB groups several Swin Transformer blocks together, adds a convolution at the end, and wraps the whole group in a residual connection:
class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB)."""

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
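Because RSTB is defined in basicsr/archs/swinir_arch.py, you can also import it directly and verify that its input and output shapes match, which is what makes stacking several RSTBs straightforward. A small sketch (parameter values are illustrative only):
import torch
from basicsr.archs.swinir_arch import RSTB

rstb = RSTB(dim=96, input_resolution=(64, 64), depth=6, num_heads=6,
            window_size=8, img_size=64, patch_size=1)
x = torch.randn(1, 64 * 64, 96)  # token sequence of a 64x64 feature map
out = rstb(x, x_size=(64, 64))
print(out.shape)  # torch.Size([1, 4096, 96]) -- same shape in and out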
The Complete SwinIR Model
The complete SwinIR model consists of shallow feature extraction, deep feature extraction, and an upsampling/reconstruction module. The code below is a slightly simplified version; the official implementation additionally pads inputs to a multiple of the window size and adds a global residual connection around the deep feature extraction:
@ARCH_REGISTRY.register()
class SwinIR(nn.Module):

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffle',
                 resi_connection='1conv'):
        super(SwinIR, self).__init__()
        self.img_range = img_range
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size

        # shallow feature extraction
        self.conv_first = nn.Conv2d(in_chans, embed_dim, 3, 1, 1)

        # deep feature extraction
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split the image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer Blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # reconstruction / upsampling
        if self.upsampler == 'pixelshuffle':
            # pixel-shuffle upsampling
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
            self.upsample = Upsample(upscale, embed_dim)
            self.conv_last = nn.Conv2d(embed_dim, in_chans, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # one-step pixel-shuffle upsampling
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim * upscale * upscale, 3, 1, 1)
            self.pixel_shuffle = nn.PixelShuffle(upscale)
            self.conv_last = nn.Conv2d(embed_dim, in_chans, 3, 1, 1)
        elif self.upsampler == 'nearest+conv':
            # nearest-neighbor upsampling + convolution
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
            self.upsample = nn.Upsample(scale_factor=upscale, mode='nearest')
            self.conv_last = nn.Conv2d(embed_dim, in_chans, 3, 1, 1)
        else:
            # no upsampler (e.g. for denoising or deblurring)
            self.conv_after_body = nn.Conv2d(embed_dim, in_chans, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # (B, N, C)
        x = x.transpose(1, 2).view(x.shape[0], self.num_features,
                                   x_size[0] // self.patch_embed.patch_size[0],
                                   x_size[1] // self.patch_embed.patch_size[1])
        return x

    def forward(self, x):
        H, W = x.shape[2], x.shape[3]
        x = self.conv_first(x)
        x = self.forward_features(x)

        if self.upsampler == 'pixelshuffle':
            x = self.conv_after_body(x)
            x = self.upsample(x)
            x = self.conv_last(x)
        elif self.upsampler == 'pixelshuffledirect':
            x = self.conv_after_body(x)
            x = self.pixel_shuffle(x)
            x = self.conv_last(x)
        elif self.upsampler == 'nearest+conv':
            x = self.conv_after_body(x)
            x = self.upsample(x)
            x = self.conv_last(x)
        else:
            x = self.conv_after_body(x)

        # crop to the target output size
        x = x[:, :, :H * self.upscale, :W * self.upscale]
        return x
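Before training the full-size network, it is worth running a forward pass on random data to confirm the output resolution. The following sketch uses the SwinIR class shipped with BasicSR and a deliberately lightweight configuration (not the paper's setting):
import torch
from basicsr.archs.swinir_arch import SwinIR

net = SwinIR(upscale=4, img_size=64, window_size=8, img_range=1.,
             depths=[2, 2], embed_dim=60, num_heads=[6, 6],
             mlp_ratio=2, upsampler='pixelshuffledirect')
lq = torch.randn(1, 3, 64, 64)  # a random 64x64 "low-resolution" input
with torch.no_grad():
    sr = net(lq)
print(sr.shape)  # torch.Size([1, 3, 256, 256])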
Model Training and Testing
Training Configuration
In BasicSR, training is driven by a YAML configuration file. The SwinIR training configuration lives at options/train/SwinIR/train_SwinIR_SRx4_scratch.yml; its main settings, shown here in simplified form, include:
# general settings
exp_name: SwinIR_SRx4_scratch
scale: 4
num_gpu: 1  # number of GPUs

# dataset settings
datasets:
  train:
    name: DIV2K
    type: PairedImageDataset
    dataroot_gt: datasets/DIV2K/DIV2K_train_HR
    dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4
    filename_tmpl: '{}x4'
    io_backend:
      type: disk
    # data augmentation
    use_flip: true
    use_rot: true

  val:
    name: Set5
    type: PairedImageDataset
    dataroot_gt: datasets/Set5/GTmod12
    dataroot_lq: datasets/Set5/LRbicx4
    io_backend:
      type: disk

# network structure
network_g:
  type: SwinIR
  upscale: 4
  in_chans: 3
  img_size: 64
  window_size: 8
  img_range: 1.
  depths: [6, 6, 6, 6, 6, 6]
  embed_dim: 180
  num_heads: [6, 6, 6, 6, 6, 6]
  mlp_ratio: 2
  upsampler: 'pixelshuffle'
  resi_connection: '1conv'

# training settings
train:
  batch_size: 16
  num_iter: 300000
  val_interval: 5000

  # optimizer
  optimizer:
    type: Adam
    lr: !!float 2e-4
    weight_decay: 0
    betas: [0.9, 0.999]

  # learning rate scheduler
  scheduler:
    type: CosineAnnealingRestartLR
    periods: [50000, 50000, 50000, 50000, 50000, 50000]
    restart_weights: [1, 1, 1, 1, 1, 1]
    eta_min: !!float 1e-7

  # loss function
  loss:
    type: L1Loss
Data Preparation
BasicSR ships data-preparation scripts that download and preprocess the datasets. Prepare the DIV2K dataset with:
python scripts/data_preparation/download_datasets.py --dataset DIV2K
If you need to generate the low-resolution images yourself, BasicSR's Matlab scripts under scripts/ perform bicubic downsampling; the command below shows the Vimeo90K example (for DIV2K, call the corresponding bicubic downsampling script instead, or use the Python sketch that follows):
matlab -nodesktop -nosplash -r "generate_LR_Vimeo90K; exit"
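If Matlab is not available, a rough Python alternative for generating X4 bicubic LR images is sketched below (note that OpenCV's bicubic resampling differs slightly from Matlab's imresize, so PSNR on standard benchmarks may not match published numbers exactly):
import os
import glob
import cv2

hr_dir = 'datasets/DIV2K/DIV2K_train_HR'
lr_dir = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
os.makedirs(lr_dir, exist_ok=True)

for path in glob.glob(os.path.join(hr_dir, '*.png')):
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    h, w = img.shape[:2]
    h, w = h - h % 4, w - w % 4  # crop so the size is divisible by 4
    lr = cv2.resize(img[:h, :w], (w // 4, h // 4), interpolation=cv2.INTER_CUBIC)
    name = os.path.splitext(os.path.basename(path))[0]
    cv2.imwrite(os.path.join(lr_dir, f'{name}x4.png'), lr)  # matches filename_tmpl '{}x4'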
Model Training
Start SwinIR training with:
python basicsr/train.py -opt options/train/SwinIR/train_SwinIR_SRx4_scratch.yml
During training, model checkpoints and logs are saved under experiments/SwinIR_SRx4_scratch.
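For multi-GPU training, BasicSR's train.py can also be launched through PyTorch's distributed launcher; for example, with 4 GPUs (adjust the port and config path to your setup):
python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/SwinIR/train_SwinIR_SRx4_scratch.yml --launcher pytorch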
Model Testing
Once training has finished, evaluate the model with BasicSR's config-driven test entry point:
python basicsr/test.py -opt options/test/SwinIR/test_SwinIR_SRx4.yml
The main contents of the test configuration options/test/SwinIR/test_SwinIR_SRx4.yml, again shown in simplified form, are:
# general settings
exp_name: test_SwinIR_SRx4
scale: 4
num_gpu: 1  # number of GPUs

# dataset settings
datasets:
  test_1:
    name: Set5
    type: PairedImageDataset
    dataroot_gt: datasets/Set5/GTmod12
    dataroot_lq: datasets/Set5/LRbicx4
    io_backend:
      type: disk

# network structure
network_g:
  type: SwinIR
  upscale: 4
  in_chans: 3
  img_size: 64
  window_size: 8
  img_range: 1.
  depths: [6, 6, 6, 6, 6, 6]
  embed_dim: 180
  num_heads: [6, 6, 6, 6, 6, 6]
  mlp_ratio: 2
  upsampler: 'pixelshuffle'
  resi_connection: '1conv'

# path settings
path:
  pretrain_network_g: experiments/SwinIR_SRx4_scratch/models/net_g_latest.pth
  strict_load: true

# test settings
test:
  tile: 64
  tile_overlap: 32
Model Inference
Besides config-driven testing, BasicSR also ships a standalone inference script (inference/inference_swinir.py). The example below performs the equivalent steps manually in Python to upscale a single image with SwinIR:
import cv2
import numpy as np
import torch

from basicsr.archs.swinir_arch import SwinIR
from basicsr.utils.img_util import img2tensor, tensor2img

# load the model
model = SwinIR(
    upscale=4,
    in_chans=3,
    img_size=64,
    window_size=8,
    img_range=1.,
    depths=[6, 6, 6, 6, 6, 6],
    embed_dim=180,
    num_heads=[6, 6, 6, 6, 6, 6],
    mlp_ratio=2,
    upsampler='pixelshuffle',
    resi_connection='1conv'
)
model.load_state_dict(torch.load('experiments/SwinIR_SRx4_scratch/models/net_g_latest.pth')['params'])
model.eval()
model = model.to('cuda')

# load the low-resolution image
img_lq = cv2.imread('test_lr.png', cv2.IMREAD_COLOR)
img_lq = img_lq.astype(np.float32) / 255.
img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True).unsqueeze(0).to('cuda')

# inference
with torch.no_grad():
    output = model(img_lq)

# save the high-resolution result
output = tensor2img(output, rgb2bgr=True, min_max=(0, 1))
cv2.imwrite('test_sr.png', output)
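One practical caveat: the simplified SwinIR implementation shown in this article does not pad its input, so the height and width of the low-resolution image must be multiples of the window size (8 here). For arbitrarily sized images, you can pad the tensor before the model call and crop the result afterwards, for example (a sketch that would replace the inference step above):
import torch.nn.functional as F

window_size = 8
_, _, h, w = img_lq.shape
pad_h = (window_size - h % window_size) % window_size
pad_w = (window_size - w % window_size) % window_size
img_padded = F.pad(img_lq, (0, pad_w, 0, pad_h), mode='reflect')

with torch.no_grad():
    output = model(img_padded)
output = output[:, :, :h * 4, :w * 4]  # crop back to 4x the original size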
Performance Evaluation
To evaluate the trained SwinIR model, you can use the metric script provided by BasicSR:
python scripts/metrics/calculate_psnr_ssim.py --gt_path datasets/Set5/GTmod12 --restored_path results/SwinIR_SRx4_scratch/Set5 --crop_border 4
The script computes PSNR and SSIM to quantify the quality of the super-resolved results. On the Set5 benchmark, a well-trained x4 SwinIR model typically reaches a PSNR above 32 dB, clearly ahead of classical interpolation-based and early CNN methods.
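The same metrics can also be computed directly in Python with BasicSR's metric functions (a sketch, assuming calculate_psnr and calculate_ssim are exposed by basicsr.metrics as in recent BasicSR versions; the filenames are illustrative):
import cv2
from basicsr.metrics import calculate_psnr, calculate_ssim

gt = cv2.imread('datasets/Set5/GTmod12/baby.png')
sr = cv2.imread('results/SwinIR_SRx4_scratch/Set5/baby.png')

psnr = calculate_psnr(sr, gt, crop_border=4, test_y_channel=True)
ssim = calculate_ssim(sr, gt, crop_border=4, test_y_channel=True)
print(f'PSNR: {psnr:.2f} dB  SSIM: {ssim:.4f}')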
Summary and Outlook
This article covered how to implement the SwinIR super-resolution model with the BasicSR toolbox: the model's principles, the core code, training and testing, and inference. With these steps you should be able to apply SwinIR to your own projects.
As a Transformer-based super-resolution model, SwinIR outperforms many traditional CNN models. Possible directions for further work include:
- combining attention with convolution to push performance further
- exploring more effective pre-training strategies to reduce the amount of training data required
- extending SwinIR to more complex tasks such as video super-resolution
I hope this article helps you get started with super-resolution research and applications. Questions and suggestions are welcome in the comments; if you found it useful, feel free to like, bookmark, and follow for more hands-on super-resolution tutorials.