SwinIR代码精读:窗口注意力机制的实现原理与工程实践
引言:超越传统CNN的图像超分范式
在图像处理领域,卷积神经网络(Convolutional Neural Network, CNN)长期占据主导地位,但受限于局部感受野,难以建模长距离依赖关系。2021年提出的SwinIR(Swin Transformer for Image Restoration)首次将Swin Transformer架构应用于图像恢复任务,通过创新性的窗口注意力(Window-based Multi-Head Self-Attention, W-MSA)机制,在图像超分辨率(Super-Resolution, SR)、降噪和压缩 artifact 去除等任务上取得突破。本文将深入剖析network_swinir.py中窗口注意力机制的实现细节,揭示其如何在保持计算效率的同时捕获图像全局信息。
窗口注意力机制的核心设计
1. 模块概览:从输入到输出的数据流
窗口注意力机制在SwinIR中通过WindowAttention类实现,其核心功能是在局部窗口内计算自注意力,同时引入相对位置偏置增强空间感知能力。以下是该模块的关键组件:
2. 窗口划分与合并:降低计算复杂度的关键
SwinIR通过将特征图划分为非重叠窗口,将全局注意力计算转化为局部窗口内的注意力计算,使复杂度从$O(HW)^2$降至$O((HW/M^2)M^4) = O(HWM^2)$(其中$M$为窗口大小)。这一过程通过window_partition和window_reverse函数实现:
def window_partition(x, window_size):
"""将特征图划分为窗口
Args:
x: (B, H, W, C) 输入特征图
window_size (int): 窗口大小
Returns:
windows: (num_windows*B, window_size, window_size, C) 划分后的窗口
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""将窗口合并回特征图
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): 窗口大小
H (int): 原始图像高度
W (int): 原始图像宽度
Returns:
x: (B, H, W, C) 合并后的特征图
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
示例: 对于$64 \times 64$的特征图和$8 \times 8$的窗口大小,将生成$8 \times 8 = 64$个窗口,每个窗口包含$64$个像素,总计算量从$4096^2 = 16,777,216$降至$64 \times 64^2 = 262,144$,复杂度降低64倍。
3. 相对位置偏置:突破绝对位置编码的局限性
传统Transformer采用绝对位置编码,难以适应不同分辨率输入。SwinIR通过相对位置偏置(Relative Position Bias)解决这一问题,其核心思想是:两个像素的注意力权重不仅取决于内容相似度,还取决于它们的相对位置。
3.1 相对位置索引计算
# 生成窗口内坐标
coords_h = torch.arange(window_size[0]) # [0, 1, ..., M-1]
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, M, M
coords_flatten = torch.flatten(coords, 1) # 2, M*M
# 计算相对坐标 (M*M, M*M, 2)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # (M*M, M*M, 2)
# 坐标偏移与编码
relative_coords[:, :, 0] += window_size[0] - 1 # 偏移至非负
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # (M*M, M*M)
对于$M \times M$窗口,相对位置索引范围为$[0, (2M-1)(2M-1)-1]$,共$(2M-1)^2$种可能的相对位置。
3.2 偏置参数表与注意力计算
# 定义相对位置偏置参数表 ( (2M-1)(2M-1), num_heads )
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
)
# 在注意力分数中加入相对位置偏置
attn = (q @ k.transpose(-2, -1)) # (B, num_heads, N, N)
relative_position_bias = self.relative_position_bias_table[relative_position_index.view(-1)].view(
window_size[0]*window_size[1], window_size[0]*window_size[1], -1
) # (N, N, num_heads)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # (num_heads, N, N)
attn = attn + relative_position_bias.unsqueeze(0) # (B, num_heads, N, N)
窗口注意力的前向传播流程
1. 完整前向计算步骤
2. 关键代码解析
def forward(self, x, mask=None):
B_, N, C = x.shape # B_ = B * num_windows, N = window_size^2
# 1. QKV线性变换与拆分
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # (B_, num_heads, N, C//num_heads)
# 2. 缩放Q与计算注意力分数
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B_, num_heads, N, N)
# 3. 加入相对位置偏置
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
N, N, -1
).permute(2, 0, 1).contiguous() # (num_heads, N, N)
attn = attn + relative_position_bias.unsqueeze(0) # (B_, num_heads, N, N)
# 4. 应用掩码 (用于移位窗口注意力)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
# 5. 注意力 dropout 与输出投影
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # (B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
移位窗口注意力:解决窗口边界问题
1. 移位窗口的必要性
固定窗口划分会导致窗口间信息隔离。SwinIR通过交替使用非移位窗口(non-shifted window)和移位窗口(shifted window)解决此问题,形成Swin Transformer Block的基本单元。
2. 掩码计算:处理移位窗口的重叠问题
当窗口移位后,部分窗口会跨越特征图边界,需要通过掩码(mask)确保注意力计算仅在有效区域内进行:
def calculate_mask(self, x_size):
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # (1, H, W, 1)
# 将特征图分为3×3网格
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
# 为每个网格分配唯一ID
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# 划分窗口并生成掩码
mask_windows = window_partition(img_mask, self.window_size) # (nW, M, M, 1)
mask_windows = mask_windows.view(-1, self.window_size*self.window_size) # (nW, N)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # (nW, N, N)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask # (nW, N, N)
掩码效果:值为$-100$的区域在softmax后注意力权重趋近于0,实现有效区域隔离。
计算复杂度分析
1. 理论计算量(FLOPs)
WindowAttention类的flops方法提供了精确的计算量评估:
def flops(self, N):
flops = 0
# QKV线性层: N × C × 3C
flops += N * self.dim * 3 * self.dim
# Q×K^T: num_heads × N × (C/num_heads) × N
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# 与V相乘: num_heads × N × N × (C/num_heads)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# Proj线性层: N × C × C
flops += N * self.dim * self.dim
return flops
示例:对于$C=96$、$num_heads=6$、$N=49$($7 \times 7$窗口)的配置,单窗口FLOPs为:
- QKV: $49 \times 96 \times 3 \times 96 = 13,547,520$
- Q×K^T: $6 \times 49 \times 16 \times 49 = 225,792$
- 与V相乘: $6 \times 49 \times 49 \times 16 = 225,792$
- Proj: $49 \times 96 \times 96 = 451,584$
- 总计: ~14.45 MFLOPs
2. 工程优化:Checkpoint机制
为节省内存,SwinIR引入梯度检查点(gradient checkpointing)技术,通过牺牲少量计算时间换取内存占用的显著降低:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size) # 仅保存前向传播的部分中间结果
else:
x = blk(x, x_size)
窗口注意力在SwinIR中的应用
1. 残差Swin Transformer块(RSTB)
窗口注意力机制通过Residual Swin Transformer Block(RSTB) 整合到SwinIR的整体架构中,每个RSTB包含多个Swin Transformer Block和残差连接:
class RSTB(nn.Module):
def forward(self, x, x_size):
return self.patch_embed(
self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
) + x # 残差连接
2. 图像超分任务的输入输出处理
SwinIR针对超分任务设计了完整的输入预处理和输出重建流程:
class SwinIR(nn.Module):
def forward_features(self, x):
# 浅层特征提取
x = self.conv_first(x)
# 深层特征提取(含窗口注意力)
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, self.patches_resolution)
x = self.norm(x)
x = self.patch_unembed(x, self.patches_resolution)
# 上采样重建
if self.upsampler == 'pixelshuffle':
x = self.conv_after_body(x)
x = self.upsample(x)
x = self.conv_last(x)
return x
性能对比与工程实践建议
1. 与传统CNN的对比优势
| 指标 | SwinIR (窗口注意力) | 传统CNN (如EDSR) |
|---|---|---|
| 感受野 | 全局 (通过多层窗口堆叠) | 局部 (受限于网络深度) |
| 长距离依赖建模 | 强 (通过自注意力机制) | 弱 (需通过堆叠卷积层) |
| 计算复杂度 | $O(HWM^2)$ | $O(HWC^2k^2)$ (k为卷积核大小) |
| 参数效率 | 高 (注意力权重共享) | 中 (固定卷积核参数) |
| 并行性 | 中 (窗口内可并行) | 高 (卷积操作高度并行) |
2. 工程调优建议
-
窗口大小选择:对于高分辨率图像(如$256 \times 256$),建议使用$7 \times 7$或$11 \times 11$窗口;低分辨率图像(如$64 \times 64$)可使用$5 \times 5$窗口平衡精度与速度。
-
硬件适配:在GPU内存有限时(如12GB以下),建议:
- 将
use_checkpoint设为True - 减小
embed_dim(如从96降至64) - 使用较小
num_heads(如从6降至4)
- 将
-
推理加速:
- 启用TensorRT或ONNX Runtime的FP16推理
- 合并连续的
nn.Linear层和nn.LayerNorm层 - 对静态输入分辨率,预计算并固化
relative_position_index
总结与未来展望
窗口注意力机制作为SwinIR的核心创新点,通过局部窗口划分、相对位置偏置和移位窗口技术,成功将Transformer的全局建模能力引入图像恢复领域。其实现细节对理解现代视觉Transformer架构具有重要参考价值。未来可进一步探索:
- 动态窗口大小(根据图像内容自适应调整)
- 注意力权重的稀疏化(降低计算复杂度)
- 与CNN的混合架构(结合局部特征提取优势)
通过本文的代码解析,读者可深入理解窗口注意力的工程实现,并将其应用于自定义图像恢复模型开发中。建议结合network_swinir.py完整代码和官方预训练权重进行实验,进一步验证理论分析。
附录:关键参数配置表
| 参数 | 超分任务 (x4) 默认值 | 降噪任务 默认值 | 作用说明 |
|---|---|---|---|
embed_dim | 96 | 64 | 嵌入维度 |
depths | [6, 6, 6, 6] | [6, 6, 6] | 各层RSTB的深度 |
num_heads | [6, 6, 6, 6] | [4, 4, 4] | 各层注意力头数 |
window_size | 7 | 7 | 窗口大小 |
mlp_ratio | 4.0 | 4.0 | MLP隐藏层维度比例 |
resi_connection | '1conv' | '3conv' | 残差连接中的卷积块配置 |
use_checkpoint | False | True | 是否启用梯度检查点 |
通过调整上述参数,可在不同硬件环境和任务需求下取得精度与效率的平衡。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



