Daily Attention Study 24: Strip Convolution Block

Module Source

[TIP 21] CoANet: Connectivity Attention Network for Road Extraction From Satellite Imagery


Module Name

Strip Convolution Block (SCB)


Module Purpose

Multi-directional strip feature extraction


Module Structure

[Figure: SCB module structure]


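Module Features
  • PSP-like design: four parallel branches extract information along different directions.
  • Compared with the classic horizontal/vertical strip convolutions, two diagonal convolutions are added to better learn oblique lines.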
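The diagonal branches do not use rotated kernels. Instead, they shear the feature map so that a diagonal line becomes a straight column, and then apply an ordinary 9×1 strip. Below is a minimal sketch of that shear. It is written as a standalone function, `h_shear` (a name introduced here), using the same indexing as `h_transform` in the module code. It is applied to a 5×5 map that holds one "/" line.

```python
import torch
import torch.nn.functional as F


def h_shear(x):
    # Same indexing as SCB.h_transform: row i is shifted right by i columns.
    # (B, C, N, N) -> (B, C, N, 2N - 1); assumes a square input.
    b, c, h, w = x.shape
    x = F.pad(x, (0, w))                    # each row padded to length 2N
    x = x.reshape(b, c, -1)[..., :-w]       # drop the trailing N zeros
    return x.reshape(b, c, h, 2 * w - 1)    # re-wrap with row length 2N - 1


# A 5x5 map containing a single anti-diagonal ("/") line: x[i, 4 - i] = 1
n = 5
x = torch.zeros(1, 1, n, n)
for i in range(n):
    x[0, 0, i, n - 1 - i] = 1.0

sheared = h_shear(x)
print(sheared.shape)                      # torch.Size([1, 1, 5, 9])
print(sheared[0, 0].nonzero().tolist())   # [[0, 4], [1, 4], [2, 4], [3, 4], [4, 4]]
```

Every pixel of the line lands in column 4. A k×1 kernel sliding down that column therefore reads the original diagonal, and the inverse reshape maps the result back to N×N.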

Module Code
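```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SCB(nn.Module):
    """Strip Convolution Block (SCB) from CoANet.

    Four parallel strip-convolution branches run on a channel-reduced map:
    horizontal (1x9), vertical (9x1) and two shear-based diagonal branches.
    Assumes in_channels is divisible by 8 and the input is square (H == W),
    because the shear helpers reshape with a single side length.
    """

    def __init__(self, in_channels, n_filters):
        super(SCB, self).__init__()
        # 1x1 conv: reduce channels to in_channels // 4
        self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
        self.bn1 = nn.BatchNorm2d(in_channels // 4)
        self.relu1 = nn.ReLU()
        # The "deconv" names follow the original code; these are ordinary
        # strip convolutions, each producing in_channels // 8 channels.
        self.deconv1 = nn.Conv2d(  # horizontal 1x9
            in_channels // 4, in_channels // 8, (1, 9), padding=(0, 4)
        )
        self.deconv2 = nn.Conv2d(  # vertical 9x1
            in_channels // 4, in_channels // 8, (9, 1), padding=(4, 0)
        )
        self.deconv3 = nn.Conv2d(  # 9x1, applied after h_transform
            in_channels // 4, in_channels // 8, (9, 1), padding=(4, 0)
        )
        self.deconv4 = nn.Conv2d(  # 1x9, applied after v_transform
            in_channels // 4, in_channels // 8, (1, 9), padding=(0, 4)
        )
        # 4 branches x (in_channels // 8) = in_channels // 2 channels
        self.bn2 = nn.BatchNorm2d(in_channels // 4 + in_channels // 4)
        self.relu2 = nn.ReLU()
        # 1x1 conv: fuse the branches and project to n_filters
        self.conv3 = nn.Conv2d(
            in_channels // 4 + in_channels // 4, n_filters, 1)
        self.bn3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x1 = self.deconv1(x)
        x2 = self.deconv2(x)
        # Diagonal branches: shear the map so a diagonal becomes a column
        # (or row), apply the strip conv, then undo the shear.
        # Both branches follow the same anti-diagonal ("/"): branch 4 has the
        # structure of branch 3 applied to the transposed map, and
        # transposition maps each diagonal onto itself.
        x3 = self.inv_h_transform(self.deconv3(self.h_transform(x)))
        x4 = self.inv_v_transform(self.deconv4(self.v_transform(x)))
        x = torch.cat((x1, x2, x3, x4), 1)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        return x

    def h_transform(self, x):
        # Shift row i right by i: (B, C, N, N) -> (B, C, N, 2N - 1)
        shape = x.size()
        x = F.pad(x, (0, shape[-1]))                             # rows of length 2N
        x = x.reshape(shape[0], shape[1], -1)[..., :-shape[-1]]  # drop trailing N zeros
        x = x.reshape(shape[0], shape[1], shape[2], 2*shape[3]-1)
        return x

    def inv_h_transform(self, x):
        # Undo h_transform: (B, C, N, 2N - 1) -> (B, C, N, N)
        shape = x.size()
        x = x.reshape(shape[0], shape[1], -1)
        x = F.pad(x, (0, shape[-2]))
        x = x.reshape(shape[0], shape[1], shape[-2], 2*shape[-2])
        x = x[..., 0: shape[-2]]
        return x

    def v_transform(self, x):
        # Shift column j down by j: (B, C, N, N) -> (B, C, 2N - 1, N)
        x = x.permute(0, 1, 3, 2)
        shape = x.size()
        x = F.pad(x, (0, shape[-1]))
        x = x.reshape(shape[0], shape[1], -1)[..., :-shape[-1]]
        x = x.reshape(shape[0], shape[1], shape[2], 2*shape[3]-1)
        return x.permute(0, 1, 3, 2)

    def inv_v_transform(self, x):
        # Undo v_transform: (B, C, 2N - 1, N) -> (B, C, N, N)
        x = x.permute(0, 1, 3, 2)
        shape = x.size()
        x = x.reshape(shape[0], shape[1], -1)
        x = F.pad(x, (0, shape[-2]))
        x = x.reshape(shape[0], shape[1], shape[-2], 2*shape[-2])
        x = x[..., 0: shape[-2]]
        return x.permute(0, 1, 3, 2)


if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])  # square input, channels divisible by 8
    scb = SCB(in_channels=64, n_filters=64)
    out = scb(x)
    print(out.shape)  # torch.Size([1, 64, 44, 44])
```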
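The diagonal branches rely on reshaping tricks rather than rotated kernels, so it is worth checking which line each branch actually integrates over. The sketch below pushes a single centred impulse through each branch. It uses standalone copies of the four shear helpers, an all-ones kernel and no bias, so the non-zero outputs trace the direction of each branch. `line_taps` and the mirrored variant (`flip_h_transform` / `inv_flip_h_transform`) are hypothetical helpers introduced here. They are not part of the CoANet code.

```python
import torch
import torch.nn.functional as F


# Standalone copies of the SCB shear helpers (square maps only).
def h_transform(x):
    b, c, h, w = x.shape
    x = F.pad(x, (0, w)).reshape(b, c, -1)[..., :-w]
    return x.reshape(b, c, h, 2 * w - 1)            # row i shifted right by i


def inv_h_transform(x):
    b, c, h, _ = x.shape
    x = F.pad(x.reshape(b, c, -1), (0, h))
    return x.reshape(b, c, h, 2 * h)[..., :h]


def v_transform(x):
    return h_transform(x.transpose(-1, -2)).transpose(-1, -2)


def inv_v_transform(x):
    return inv_h_transform(x.transpose(-1, -2)).transpose(-1, -2)


# Hypothetical variant: mirror the width axis around the row shear to reach "\".
def flip_h_transform(x):
    return h_transform(torch.flip(x, dims=[-1]))


def inv_flip_h_transform(x):
    return torch.flip(inv_h_transform(x), dims=[-1])


def identity(x):
    return x


def line_taps(shear, unshear, kernel_size, n=9):
    """Pixels reached by a centred impulse through shear -> strip conv -> unshear."""
    x = torch.zeros(1, 1, n, n)
    x[0, 0, n // 2, n // 2] = 1.0
    kh, kw = kernel_size
    weight = torch.ones(1, 1, kh, kw)                # all-ones strip kernel, no bias
    y = unshear(F.conv2d(shear(x), weight, padding=(kh // 2, kw // 2)))
    return sorted(map(tuple, (y[0, 0].abs() > 0.5).nonzero().tolist()))


horizontal = line_taps(identity, identity, (1, 9))                   # deconv1
vertical = line_taps(identity, identity, (9, 1))                     # deconv2
branch3 = line_taps(h_transform, inv_h_transform, (9, 1))            # deconv3
branch4 = line_taps(v_transform, inv_v_transform, (1, 9))            # deconv4
flipped = line_taps(flip_h_transform, inv_flip_h_transform, (9, 1))  # hypothetical

print(branch3 == branch4)  # True
print(branch3[:3])         # [(0, 8), (1, 7), (2, 6)]
print(flipped[:3])         # [(0, 0), (1, 1), (2, 2)]
```

With this indexing, deconv3 and deconv4 trace the same "/" anti-diagonal. Mirroring the width axis around the row shear traces "\" instead. If both diagonals are wanted, swapping one branch for the mirrored shear is one option. This is a suggestion, not the original CoANet code.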
### Strip Convolution in Deep Learning Frameworks

Strip convolution is a variant of the standard convolution used in deep learning architectures. A regular k×k kernel covers a square neighborhood that spans both height and width. A strip kernel extends along one axis only: it is either 1×k (horizontal) or k×1 (vertical). The kernel still slides over the whole feature map, but its footprint is only one pixel thick. In practice, this means:

- **Horizontal strips (1×k):** each output pixel aggregates k neighbors from the same row.
- **Vertical strips (k×1):** each output pixel aggregates k neighbors from the same column.

This approach is useful when patterns in the data have a strong directional character. For instance, some image features are predominantly horizontal (e.g., horizon lines in landscapes). A horizontal strip captures such structures with k weights per input/output channel pair, instead of the k² that a full 2D kernel of the same length would need[^1].

Implementing a strip convolution needs no special layer. A standard 2D convolution with a rectangular kernel of shape 1×k or k×1 already acts as one. Padding by k//2 along the long axis only keeps the spatial size unchanged.
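The cost difference is easy to quantify. With C input and output channels, a k×k kernel holds C·C·k² weights, while a 1×k strip holds C·C·k, and the multiply-adds per output pixel scale the same way. Here is a quick check in PyTorch, with C = 64 and k = 9 picked only for illustration:

```python
import torch.nn as nn


def n_params(module):
    # Total number of learnable parameters (weights + biases)
    return sum(p.numel() for p in module.parameters())


c, k = 64, 9
full = nn.Conv2d(c, c, kernel_size=k, padding=k // 2)             # k x k kernel
strip = nn.Conv2d(c, c, kernel_size=(1, k), padding=(0, k // 2))  # 1 x k strip

print(n_params(full))   # 331840 = 64*64*81 + 64
print(n_params(strip))  # 36928  = 64*64*9 + 64
```

The strip layer has roughly k = 9 times fewer weights than the square layer of the same reach.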
The following shows how this could look in PyTorch:

```python
from torch import nn


class HorizontalStripConv(nn.Module):
    """1 x k strip convolution; padding = k // 2 keeps H and W unchanged for odd k."""

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=(1, kernel_size), padding=(0, padding))

    def forward(self, x):
        return self.conv(x)


class VerticalStripConv(nn.Module):
    """k x 1 strip convolution; padding = k // 2 keeps H and W unchanged for odd k."""

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=(kernel_size, 1), padding=(padding, 0))

    def forward(self, x):
        return self.conv(x)
```

Spatially separable convolutions save computation by factoring a k×k kernel into a k×1 kernel and a 1×k kernel. However, the factorization is exact only when the kernel has rank 1, so arbitrary kernels cannot be split this way. Using strip convolutions directly, as standalone directional filters, may therefore give targeted improvements depending on the requirements of the application.
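The point about separability can be made concrete. A k×k kernel splits exactly into a k×1 strip followed by a 1×k strip only when it is an outer product (rank 1). The sketch below checks this numerically in PyTorch. The choice k = 5 is arbitrary, and `u` and `v` are illustrative names for random factors.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
k = 5
u = torch.randn(k)              # vertical (k x 1) factor
v = torch.randn(k)              # horizontal (1 x k) factor
x = torch.randn(1, 1, 16, 16)

# Rank-1 kernel: the outer product of u and v
full = F.conv2d(x, torch.outer(u, v).view(1, 1, k, k), padding=k // 2)

# Same result from a k x 1 strip followed by a 1 x k strip
two_strips = F.conv2d(
    F.conv2d(x, u.view(1, 1, k, 1), padding=(k // 2, 0)),
    v.view(1, 1, 1, k),
    padding=(0, k // 2),
)
print(torch.allclose(full, two_strips, atol=1e-4))  # True

# A generic random k x k kernel has full rank, so no single strip pair equals it
rank = torch.linalg.matrix_rank(torch.randn(k, k)).item()
print(rank)  # 5
```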