(即插即用模块-特征处理部分) 二十五、(2023 ICME) TFF&SFF 时间&通道特征融合模块

在这里插入图片描述

paper:STNet: Spatial and Temporal feature fusion network for change detection in remote sensing images

Code:https://github.com/xwmaxwma/rschange


1、Temporal Feature Fusion Module

遥感图像变化检测中的目标变化区域往往只占图像的一小部分,而非目标变化(如天气变化、光照变化等)可能占据图像的大部分区域,从而导致变化检测任务变得困难。传统的特征融合方法无法有效区分目标变化和非目标变化以及无法有效融合不同尺度的特征图,这就使得现有的研究容易受到非目标变化的干扰、检测结果的细节信息丢失或语义信息不准确,最终导致变化检测结果不准确。论文论证道:变化检测任务需要同时关注变化区域的语义信息和空间细节信息,同时不同尺度的特征图包含不同的信息,如高层特征图语义信息丰富但空间细节信息丢失,低层特征图空间细节信息丰富但语义信息丢失。基于此,这篇论文提出两种特征融合模块 时间特征融合模块(Temporal Feature Fusion Module)通道特征融合模块(Spatial Feature Fusion Module)

TFF 模块通过计算目标变化和非目标变化特征图之间的差异,并利用跨时间门控机制来突出目标变化特征,同时抑制非目标变化特征,从而提高变化检测的准确性和鲁棒性。其主要根据输入特征图的重要性动态地调整输出特征图的权重。

对于输入X,TFF 的实现过程:

  1. 对比度归一化:将输入的双时相特征 R1 和 R2 进行相减,得到粗略的变化表示 Rc。
  2. 深度可分离卷积:将 Rc 分别与 R1 和 R2 拼接,并进行深度可分离卷积,得到 Rc1 和 Rc2。
  3. 门控机制:通过 1x1 卷积和 Sigmoid 函数,计算 Rc1 和 Rc2 的权重 W1 和 W2。
  4. 融合特征:将 W1 和 W2 分别与 R1 和 R2 进行逐元素相乘,并将结果拼接,最后进行深度可分离卷积,得到融合后的特征。

Temporal Feature Fusion Module 结构图:
在这里插入图片描述


2、Spatial Feature Fusion Module

SFF 模块利用高层特征图来指导低层特征图的融合,从而在保留低层特征图空间细节信息的同时,增强其语义信息,从而提高变化检测结果的准确性和鲁棒性。其可以根据不同尺度特征图之间的相关性动态地调整输出特征图的权重。

对于输入X,SFF 的实现过程:

上采样:将高层变化表示进行上采样,使其与低层变化表示具有相同的尺寸。

拼接:将上采样后的高层变化表示与低层变化表示进行拼接。

注意力机制:计算拼接后特征图的查询 Q、键 K 和值 V,并通过加权求和的方式得到最终的融合特征图。


Spatial Feature Fusion Module 结构图:
在这里插入图片描述

3、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_3x3(in_channel, out_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True)
    )


def dsconv_3x3(in_channel, out_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
        nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True)
    )


def conv_1x1(in_channel, out_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True)
    )


class SelfAttentionBlock(nn.Module):
    def __init__(self, key_in_channels, query_in_channels, transform_channels, out_channels,
                 key_query_num_convs, value_out_num_convs):
        super(SelfAttentionBlock, self).__init__()
        self.key_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels,
            num_convs=key_query_num_convs,
        )
        self.query_project = self.buildproject(
            in_channels=query_in_channels,
            out_channels=transform_channels,
            num_convs=key_query_num_convs
        )
        self.value_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels,
            num_convs=value_out_num_convs
        )
        self.out_project = self.buildproject(
            in_channels=transform_channels,
            out_channels=out_channels,
            num_convs=value_out_num_convs
        )
        self.transform_channels = transform_channels

    def forward(self, query_feats, key_feats, value_feats):
        batch_size = query_feats.size(0)

        query = self.query_project(query_feats)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()  # (B, h*w, C)

        key = self.key_project(key_feats)
        key = key.reshape(*key.shape[:2], -1)  # (B, C, h*w)

        value = self.value_project(value_feats)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()  # (B, h*w, C)

        sim_map = torch.matmul(query, key)

        sim_map = (self.transform_channels ** -0.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)  # (B, h*w, K)

        context = torch.matmul(sim_map, value)  # (B, h*w, C)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])  # (B, C, h, w)

        context = self.out_project(context)  # (B, C, h, w)
        return context

    def buildproject(self, in_channels, out_channels, num_convs):
        convs = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        for _ in range(num_convs - 1):
            convs.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )
        if len(convs) > 1:
            return nn.Sequential(*convs)
        return convs[0]


class TFF(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(TFF, self).__init__()
        self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
        self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
        self.catconv = dsconv_3x3(in_channel * 2, out_channel)
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, xA, xB):
        x_diff = xA - xB

        x_diffA = self.catconvA(torch.cat([x_diff, xA], dim=1))
        x_diffB = self.catconvB(torch.cat([x_diff, xB], dim=1))

        A_weight = self.sigmoid(self.convA(x_diffA))
        B_weight = self.sigmoid(self.convB(x_diffB))

        xA = A_weight * xA
        xB = B_weight * xB

        x = self.catconv(torch.cat([xA, xB], dim=1))

        return x


class SFF(nn.Module):
    def __init__(self, in_channel):
        super(SFF, self).__init__()
        self.conv_small = conv_1x1(in_channel, in_channel)
        self.conv_big = conv_1x1(in_channel, in_channel)
        self.catconv = conv_3x3(in_channel * 2, in_channel)
        self.attention = SelfAttentionBlock(
            key_in_channels=in_channel,
            query_in_channels=in_channel,
            transform_channels=in_channel // 2,
            out_channels=in_channel,
            key_query_num_convs=2,
            value_out_num_convs=1
        )

    def forward(self, x_small, x_big):
        img_size = x_big.size(2), x_big.size(3)
        x_small = F.interpolate(x_small, img_size, mode="bilinear", align_corners=False)
        x = self.conv_small(x_small) + self.conv_big(x_big)
        new_x = self.attention(x, x, x_big)

        out = self.catconv(torch.cat([new_x, x_big], dim=1))
        return out


if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7).cuda()
    y = torch.randn(4, 512, 7, 7).cuda()
    # model = TFF(512, 512).cuda()
    model = SFF(512).cuda()
    out = model(x, y)
    print(out.shape)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值