【即插即用】STN注意力机制(附源码)

论文地址:

https://arxiv.org/pdf/1506.02025.pdf

简要介绍:

STN,也就是空间变换网络,是一种很酷的机器学习技术,它能自动地调整图像,让网络更好地处理图片的各种变化,比如扭曲或旋转。这个过程就像给图片做一个微整,让它们在变美的同时,也让计算机更容易识别。

STN由三个主要部分组成:定位网络、网格生成器和采样器。定位网络负责分析图像,找出需要调整的地方,比如哪里需要拉伸,哪里需要旋转。想象一下,这就像有一支神奇的笔,能够精确地标记出图片的每个像素需要怎么移动。

网格生成器根据定位网络提供的信息,制作出一个“模板”,这个模板指导采样器如何从原始图像中取样。采样器就像一个巧手的工匠,按照这个模板,用细腻的手法从原始图片中选取正确的像素,然后把它们重新放置到新的位置,让图片达到最佳的视觉效果。

STN的好处是,它能够自我学习如何处理各种复杂的图像变换,这让它在识别图片时更加精准。而且,STN可以像搭乐高一样,和其他深度学习模型一起使用,让整个系统变得更加强大。

结构图:
Pytorch源码:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


class STN(nn.Module):
    def __init__(self):
        super(STN, self).__init__()
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # 移除fc_loc中第一个nn.Linear的输入尺寸定义
        self.fc_loc = None  # 将在forward中动态创建

        # 用于空间变换网络的权重和偏置初始化参数
        self.fc_loc_output_params = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)

    def forward(self, x):
        xs = self.localization(x)
        xs_size = xs.size()
        # 动态计算全连接层的输入尺寸
        fc_input_size = xs_size[1] * xs_size[2] * xs_size[3]

        # 根据动态计算的输入尺寸创建fc_loc
        if self.fc_loc is None:
            self.fc_loc = nn.Sequential(
                nn.Linear(fc_input_size, 32),
                nn.ReLU(True),
                nn.Linear(32, 3 * 2)
            )
            # 初始化空间变换网络的权重和偏置
            self.fc_loc[2].weight.data.zero_()
            self.fc_loc[2].bias.data.copy_(self.fc_loc_output_params)

        xs = xs.view(-1, fc_input_size)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=True)
        x = F.grid_sample(x, grid, align_corners=True)
        return x


def test_stn(input_size=(1, 1, 32, 32)):
    stn = STN()
    input_tensor = torch.rand(input_size)
    transformed_tensor = stn(input_tensor)
    print(transformed_tensor.shape)

    input_image = input_tensor.numpy()[0][0]
    transformed_image = transformed_tensor.detach().numpy()[0][0]

    plt.figure()
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(input_image, cmap='gray')
    plt.subplot(1, 2, 2)
    plt.title("Transformed Image")
    plt.imshow(transformed_image, cmap='gray')
    plt.show()


# 测试不同尺寸的输入
# test_stn(input_size=(1, 1, 32, 32))
test_stn(input_size=(1, 1, 64, 64))

### STN 中的空间注意力机制 空间变换网络(STN, Spatial Transformer Network)是一种能够使神经网络学习到对输入数据的几何变换具有鲁棒性的模块。它通过引入局部化的注意力建 Mechansim 来增强模型对于不同尺度、旋转和平移变化的学习能力[^2]。 #### 工作原理 STN 的核心在于其三部分结构: 1. **本地网络(Localization Network)**: 这是一个小型卷积神经网络,负责预测仿射变换参数矩阵 \( \theta \),该矩阵定义了如何调整输入特征图的位置和形状。 2. **网格生成器(Grid Generator)**: 基于由 Localization Network 输出的 \( \theta \) 参数计算目标位置坐标网格。这些新坐标的集合表示经过变换后的样本点分布情况。 3. **采样器(Sampler)**: 使用双线性插值法或其他方式从原始图片上提取对应像素值填充至新的指定区域完成最终输出操作过程[^3]。 这种设计允许 STN 自动捕捉最佳视角或者矫正倾斜对象等问题而无需额外标注信息支持下达到良好表现水平同时还能与其他主流架构无缝衔接形成端到端可微分训练框架从而进一步提升整体性能指标比如 mIoU 等评估标准得分更高表明分割准确性更好[^4]。 以下是 Python 实现的一个简单版本: ```python import torch.nn as nn class STNet(nn.Module): def __init__(self): super(STNet, self).__init__() # 定位网络(Localization network) self.localization = nn.Sequential( nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2, stride=2), nn.ReLU(True), nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2, stride=2), nn.ReLU(True) ) # Regressor for the affine matrix self.fc_loc = nn.Sequential( nn.Linear(10 * 3 * 3, 32), nn.ReLU(True), nn.Linear(32, 3 * 2) # Output should be (batch_size, 6) ) def forward(self, x): xs = self.localization(x) xs = xs.view(-1, 10*3*3) theta = self.fc_loc(xs).view(-1, 2, 3) grid = F.affine_grid(theta, x.size()) x = F.grid_sample(x, grid) return x ``` 上述代码展示了如何构建一个基本形式下的 STN 结构并将其嵌入 PyTorch 模型之中以便后续调用执行具体任务需求时可以直接利用此组件来处理复杂场景下的视觉理解挑战项目当中去尝试解决实际应用层面遇到的各种难题状况发生概率较低但仍需保持警惕态度对待可能出现异常情形做好充分准备措施加以应对即可有效降低风险系数提高系统稳定性保障业务连续正常运转不受干扰影响持续创造价值回报社会大众群体共同进步发展繁荣昌盛未来无限美好前景值得期待憧憬向往追求卓越成就非凡人生梦想成真指日可待加油奋斗吧少年们! ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值