pytorch代码实现之SAConv卷积

SAConv卷积

SAConv卷积模块是一种精度更高、速度更快的“即插即用”卷积,目前很多方法被提出用于降低模型冗余、加速模型推理速度,然而这些方法往往关注于消除不重要的滤波器或构建高效计算单元,反而忽略了特征内部的模式冗余。
原文地址:Split to Be Slim: An Overlooked Redundancy in Vanilla Convolution

由于同一层内的许多特征具有相似却不平等的表现模式。然而,这类具有相似模式的特征却难以判断是否存在冗余或包含重要的细节信息。因此,不同于直接移除不确定的冗余特征方案,提出了一种基于Split的卷积计算单元(称之为SPConv),它运训存在相似模型冗余且仅需非常少的计算量。

SPConv结构图

首先,将输入特征拆分为representative部分与uncertain部分;然后,对于representative部分特征采用相对多的计算复杂度操作提取重要信息,对于uncertain部分采用轻量型操作提取隐含信息;最后,为重新校准与融合两组特征,作者采用了无参特征融合模块。该文所提SPConv是一种“即插即用”型模块,可用于替换现有网络中的常规卷积。

​无需任何技巧,在GPU端的精度与推理速度方面,基于SPConv的网络均可取得SOTA性能。该文主要贡献包含下面几个方面:
(1)重新对常规卷积中的特征冗余问题进行了再思考,提出了将输入分成两部分:representative与uncertain,分别针对两部分进行不同的信息提取;
(2)设计了一种“即插即用”型SPConv模块,它可以无缝替换现有网络中的常规卷积,且在精度与GPU推理速度上均可能优于SOTA性能,同时具有更少的FLOPs和参数量。

代码实现

class ConvAWS2d(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = weight / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return super()._conv_forward(x, weight, None)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.weight_gamma.data.fill_(-1)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            return
        weight = self.weight.data
        weight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,
                                       keepdim=True).mean(dim=3, keepdim=True)
        self.weight_beta.data.copy_(weight_mean)
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_gamma.data.copy_(std)
    
class SAConv2d(ConvAWS2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 s=1,
                 p=None,
                 g=1,
                 d=1,
                 act=True,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=s,
            padding=autopad(kernel_size, p),
            dilation=d,
            groups=g,
            bias=bias)
        self.switch = torch.nn.Conv2d(
            self.in_channels,
            1,
            kernel_size=1,
            stride=s,
            bias=True)
        self.switch.weight.data.fill_(0)
        self.switch.bias.data.fill_(1)
        self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))
        self.weight_diff.data.zero_()
        self.pre_context = torch.nn.Conv2d(
            self.in_channels,
            self.in_channels,
            kernel_size=1,
            bias=True)
        self.pre_context.weight.data.fill_(0)
        self.pre_context.bias.data.fill_(0)
        self.post_context = torch.nn.Conv2d(
            self.out_channels,
            self.out_channels,
            kernel_size=1,
            bias=True)
        self.post_context.weight.data.fill_(0)
        self.post_context.bias.data.fill_(0)
        
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        # pre-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch
        avg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")
        avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        out_s = super()._conv_forward(x, weight, None)
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        weight = weight + self.weight_diff
        out_l = super()._conv_forward(x, weight, None)
        out = switch * out_s + (1 - switch) * out_l
        self.padding = ori_p
        self.dilation = ori_d
        # post-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return self.act(self.bn(out))
<think>嗯,用户想知道YOLOv11中BN层的具体名称或实现方式。首先,我需要回忆一下YOLO系列模型的结构。YOLOv11虽然可能不是官方的版本,但根据用户提供的引用[1],它可能是基于Ultralytics的YOLO架构改进的。通常,YOLO模型中的卷积块会包含卷积层、批归一化(Batch Normalization,BN)层和激活函数,比如ReLU或SiLU。 根据引用[2]中的PyTorch BN层实现建议,用户可能需要通过遍历模型的子模块来找到BN层。在PyTorch中,BN层的类名是`torch.nn.BatchNorm2d`,所以YOLOv11中的BN层应该也是这个类。名称方面,通常这些层会被命名为类似“bn1”、“bn2”等,但具体名称可能取决于模型配置文件的结构。 用户提供的引用[1]中提到了使用SAConv模块改进YOLOv11的配置文件yolo11n_SAconv.yaml,这提示BN层的名称可能与模块结构相关,比如在SAConv模块中的BN层可能被命名为“bn”或“norm”。需要查看该配置文件的定义才能确定具体名称。 实现方式的话,YOLOv11应该遵循标准的PyTorch BN实现,即使用`BatchNorm2d`,包括对输入数据的归一化处理、可学习的缩放和偏移参数。训练时,BN层会计算当前批次的均值和方差,并更新运行时的统计量;推理时则使用这些统计量进行归一化。 用户可能需要通过代码示例来查找BN层,比如使用`model.named_modules()`来遍历所有模块,并筛选出`BatchNorm2d`类型的层。此外,引用[2]中的`model.apply(set_bn_eval)`方法可以用来设置BN层在推理模式,这也说明BN层的实现是标准的PyTorch方式。 总结来说,YOLOv11中的BN层名称可能根据模块不同而有所变化,但实现方式应基于PyTorch的`BatchNorm2d`。用户若想具体确认名称,需要查看模型配置文件或打印模型结构。</think>### YOLOv11中Batch Normalization层实现解析 根据PyTorch标准实现和YOLO系列网络结构特征,YOLOv11中的BN层实现遵循以下规律: 1. **BN层命名规则** 在YOLOv11的卷积模块中,典型的BN层命名包含以下特征: - 基础命名模式:`model*.m*.bn` - 在SAConv模块中可能命名为:`model*.SAConv.bn` - 具体示例:`model.22.m.0.bn`(第22层第一个模块的BN层) 2. **实现代码定位** 可通过以下代码片段遍历模型中的BN层: ```python import torch from ultralytics import YOLO model = YOLO("yolo11n_SAconv.yaml").model for name, module in model.named_modules(): if isinstance(module, torch.nn.BatchNorm2d): print(f"BN层名称:{name}") print(f"参数维度:{module.weight.shape}") ``` 3. **核心实现逻辑** YOLOv11的BN层继承自PyTorch的`BatchNorm2d`,其计算过程遵循标准公式: $$ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \times \gamma + \beta $$ 其中$\gamma$和$\beta$是可训练参数,对应代码中的`weight`和`bias`属性[^2] 4. **配置文件示例** 在`yolo11n_SAconv.yaml`中,典型模块定义包含: ```yaml [[-1, 1, Conv, [64, 3, 1]], # 卷积层 [-1, 1, BatchNorm2d, [None]], # BN层 [-1, 1, nn.SiLU, []]] # 激活函数 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我悟了-

你的激励是我肝下去的动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值