SwinIR代码精读:窗口注意力机制的实现原理与工程实践

SwinIR代码精读:窗口注意力机制的实现原理与工程实践

【免费下载链接】SwinIR SwinIR: Image Restoration Using Swin Transformer (official repository) 【免费下载链接】SwinIR 项目地址: https://gitcode.com/gh_mirrors/sw/SwinIR

引言:超越传统CNN的图像超分范式

在图像处理领域,卷积神经网络(Convolutional Neural Network, CNN)长期占据主导地位,但受限于局部感受野,难以建模长距离依赖关系。2021年提出的SwinIR(Swin Transformer for Image Restoration)首次将Swin Transformer架构应用于图像恢复任务,通过创新性的窗口注意力(Window-based Multi-Head Self-Attention, W-MSA)机制,在图像超分辨率(Super-Resolution, SR)、降噪和压缩 artifact 去除等任务上取得突破。本文将深入剖析network_swinir.py中窗口注意力机制的实现细节,揭示其如何在保持计算效率的同时捕获图像全局信息。

窗口注意力机制的核心设计

1. 模块概览:从输入到输出的数据流

窗口注意力机制在SwinIR中通过WindowAttention类实现,其核心功能是在局部窗口内计算自注意力,同时引入相对位置偏置增强空间感知能力。以下是该模块的关键组件:

mermaid

2. 窗口划分与合并:降低计算复杂度的关键

SwinIR通过将特征图划分为非重叠窗口,将全局注意力计算转化为局部窗口内的注意力计算,使复杂度从$O(HW)^2$降至$O((HW/M^2)M^4) = O(HWM^2)$(其中$M$为窗口大小)。这一过程通过window_partitionwindow_reverse函数实现:

def window_partition(x, window_size):
    """将特征图划分为窗口
    Args:
        x: (B, H, W, C) 输入特征图
        window_size (int): 窗口大小
    Returns:
        windows: (num_windows*B, window_size, window_size, C) 划分后的窗口
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """将窗口合并回特征图
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): 窗口大小
        H (int): 原始图像高度
        W (int): 原始图像宽度
    Returns:
        x: (B, H, W, C) 合并后的特征图
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

示例: 对于$64 \times 64$的特征图和$8 \times 8$的窗口大小,将生成$8 \times 8 = 64$个窗口,每个窗口包含$64$个像素,总计算量从$4096^2 = 16,777,216$降至$64 \times 64^2 = 262,144$,复杂度降低64倍。

3. 相对位置偏置:突破绝对位置编码的局限性

传统Transformer采用绝对位置编码,难以适应不同分辨率输入。SwinIR通过相对位置偏置(Relative Position Bias)解决这一问题,其核心思想是:两个像素的注意力权重不仅取决于内容相似度,还取决于它们的相对位置

3.1 相对位置索引计算
# 生成窗口内坐标
coords_h = torch.arange(window_size[0])  # [0, 1, ..., M-1]
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, M, M
coords_flatten = torch.flatten(coords, 1)  # 2, M*M

# 计算相对坐标 (M*M, M*M, 2)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (M*M, M*M, 2)

# 坐标偏移与编码
relative_coords[:, :, 0] += window_size[0] - 1  # 偏移至非负
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # (M*M, M*M)

对于$M \times M$窗口,相对位置索引范围为$[0, (2M-1)(2M-1)-1]$,共$(2M-1)^2$种可能的相对位置。

3.2 偏置参数表与注意力计算
# 定义相对位置偏置参数表 ( (2M-1)(2M-1), num_heads )
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
)

# 在注意力分数中加入相对位置偏置
attn = (q @ k.transpose(-2, -1))  # (B, num_heads, N, N)
relative_position_bias = self.relative_position_bias_table[relative_position_index.view(-1)].view(
    window_size[0]*window_size[1], window_size[0]*window_size[1], -1
)  # (N, N, num_heads)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # (num_heads, N, N)
attn = attn + relative_position_bias.unsqueeze(0)  # (B, num_heads, N, N)

窗口注意力的前向传播流程

1. 完整前向计算步骤

mermaid

2. 关键代码解析

def forward(self, x, mask=None):
    B_, N, C = x.shape  # B_ = B * num_windows, N = window_size^2
    
    # 1. QKV线性变换与拆分
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # (B_, num_heads, N, C//num_heads)
    
    # 2. 缩放Q与计算注意力分数
    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))  # (B_, num_heads, N, N)
    
    # 3. 加入相对位置偏置
    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        N, N, -1
    ).permute(2, 0, 1).contiguous()  # (num_heads, N, N)
    attn = attn + relative_position_bias.unsqueeze(0)  # (B_, num_heads, N, N)
    
    # 4. 应用掩码 (用于移位窗口注意力)
    if mask is not None:
        nW = mask.shape[0]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
    
    # 5. 注意力 dropout 与输出投影
    attn = self.attn_drop(attn)
    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (B_, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

移位窗口注意力:解决窗口边界问题

1. 移位窗口的必要性

固定窗口划分会导致窗口间信息隔离。SwinIR通过交替使用非移位窗口(non-shifted window)和移位窗口(shifted window)解决此问题,形成Swin Transformer Block的基本单元。

mermaid

2. 掩码计算:处理移位窗口的重叠问题

当窗口移位后,部分窗口会跨越特征图边界,需要通过掩码(mask)确保注意力计算仅在有效区域内进行:

def calculate_mask(self, x_size):
    H, W = x_size
    img_mask = torch.zeros((1, H, W, 1))  # (1, H, W, 1)
    
    # 将特征图分为3×3网格
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    
    # 为每个网格分配唯一ID
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    
    # 划分窗口并生成掩码
    mask_windows = window_partition(img_mask, self.window_size)  # (nW, M, M, 1)
    mask_windows = mask_windows.view(-1, self.window_size*self.window_size)  # (nW, N)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # (nW, N, N)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask  # (nW, N, N)

掩码效果:值为$-100$的区域在softmax后注意力权重趋近于0,实现有效区域隔离。

计算复杂度分析

1. 理论计算量(FLOPs)

WindowAttention类的flops方法提供了精确的计算量评估:

def flops(self, N):
    flops = 0
    # QKV线性层: N × C × 3C
    flops += N * self.dim * 3 * self.dim
    # Q×K^T: num_heads × N × (C/num_heads) × N
    flops += self.num_heads * N * (self.dim // self.num_heads) * N
    # 与V相乘: num_heads × N × N × (C/num_heads)
    flops += self.num_heads * N * N * (self.dim // self.num_heads)
    # Proj线性层: N × C × C
    flops += N * self.dim * self.dim
    return flops

示例:对于$C=96$、$num_heads=6$、$N=49$($7 \times 7$窗口)的配置,单窗口FLOPs为:

  • QKV: $49 \times 96 \times 3 \times 96 = 13,547,520$
  • Q×K^T: $6 \times 49 \times 16 \times 49 = 225,792$
  • 与V相乘: $6 \times 49 \times 49 \times 16 = 225,792$
  • Proj: $49 \times 96 \times 96 = 451,584$
  • 总计: ~14.45 MFLOPs

2. 工程优化:Checkpoint机制

为节省内存,SwinIR引入梯度检查点(gradient checkpointing)技术,通过牺牲少量计算时间换取内存占用的显著降低:

if self.use_checkpoint:
    x = checkpoint.checkpoint(blk, x, x_size)  # 仅保存前向传播的部分中间结果
else:
    x = blk(x, x_size)

窗口注意力在SwinIR中的应用

1. 残差Swin Transformer块(RSTB)

窗口注意力机制通过Residual Swin Transformer Block(RSTB) 整合到SwinIR的整体架构中,每个RSTB包含多个Swin Transformer Block和残差连接:

class RSTB(nn.Module):
    def forward(self, x, x_size):
        return self.patch_embed(
            self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
        ) + x  # 残差连接

2. 图像超分任务的输入输出处理

SwinIR针对超分任务设计了完整的输入预处理和输出重建流程:

class SwinIR(nn.Module):
    def forward_features(self, x):
        # 浅层特征提取
        x = self.conv_first(x)
        
        # 深层特征提取(含窗口注意力)
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x, self.patches_resolution)
        x = self.norm(x)
        x = self.patch_unembed(x, self.patches_resolution)
        
        # 上采样重建
        if self.upsampler == 'pixelshuffle':
            x = self.conv_after_body(x)
            x = self.upsample(x)
            x = self.conv_last(x)
        return x

性能对比与工程实践建议

1. 与传统CNN的对比优势

指标SwinIR (窗口注意力)传统CNN (如EDSR)
感受野全局 (通过多层窗口堆叠)局部 (受限于网络深度)
长距离依赖建模强 (通过自注意力机制)弱 (需通过堆叠卷积层)
计算复杂度$O(HWM^2)$$O(HWC^2k^2)$ (k为卷积核大小)
参数效率高 (注意力权重共享)中 (固定卷积核参数)
并行性中 (窗口内可并行)高 (卷积操作高度并行)

2. 工程调优建议

  1. 窗口大小选择:对于高分辨率图像(如$256 \times 256$),建议使用$7 \times 7$或$11 \times 11$窗口;低分辨率图像(如$64 \times 64$)可使用$5 \times 5$窗口平衡精度与速度。

  2. 硬件适配:在GPU内存有限时(如12GB以下),建议:

    • use_checkpoint设为True
    • 减小embed_dim(如从96降至64)
    • 使用较小num_heads(如从6降至4)
  3. 推理加速

    • 启用TensorRT或ONNX Runtime的FP16推理
    • 合并连续的nn.Linear层和nn.LayerNorm
    • 对静态输入分辨率,预计算并固化relative_position_index

总结与未来展望

窗口注意力机制作为SwinIR的核心创新点,通过局部窗口划分、相对位置偏置和移位窗口技术,成功将Transformer的全局建模能力引入图像恢复领域。其实现细节对理解现代视觉Transformer架构具有重要参考价值。未来可进一步探索:

  • 动态窗口大小(根据图像内容自适应调整)
  • 注意力权重的稀疏化(降低计算复杂度)
  • 与CNN的混合架构(结合局部特征提取优势)

通过本文的代码解析,读者可深入理解窗口注意力的工程实现,并将其应用于自定义图像恢复模型开发中。建议结合network_swinir.py完整代码和官方预训练权重进行实验,进一步验证理论分析。

附录:关键参数配置表

参数超分任务 (x4) 默认值降噪任务 默认值作用说明
embed_dim9664嵌入维度
depths[6, 6, 6, 6][6, 6, 6]各层RSTB的深度
num_heads[6, 6, 6, 6][4, 4, 4]各层注意力头数
window_size77窗口大小
mlp_ratio4.04.0MLP隐藏层维度比例
resi_connection'1conv''3conv'残差连接中的卷积块配置
use_checkpointFalseTrue是否启用梯度检查点

通过调整上述参数,可在不同硬件环境和任务需求下取得精度与效率的平衡。

【免费下载链接】SwinIR SwinIR: Image Restoration Using Swin Transformer (official repository) 【免费下载链接】SwinIR 项目地址: https://gitcode.com/gh_mirrors/sw/SwinIR

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值