【基于时间特征交互和引导细化的遥感变化检测 】2022TGRS

变化检测TGRS2022《Remote Sensing Change Detection via Temporal Feature Interaction and Guided Refinement》

ABSTRACT

遥感变化检测(RSCD)从注册的一对遥感图像中识别变化和不变的像素,最近取得了显著的成功。然而,在RSCD中,定位具有精细结构细节的变化对象仍然是一个具有挑战性的问题。在本文中,我们提出了一种新的基于时间特征交互和引导求精的RSCD网络(TFI-GR)来解决这个问题。具体来说,与以往的方法不同,这些方法只使用一个单一的级联或减法操作来进行双时间特征融合,我们设计了一个时间特征交互模块(TFIM)来增强双时间特征之间的交互,并在不同的特征级别捕获时间差异信息。然后,重复执行一个引导细化模块(GRM),该模块聚合低级和高级时间差分表示,以抛光高级特征的位置信息并过滤低级特征的背景杂波。最后,对多层次时间差分特征进行逐步融合,生成用于变化检测的变化图。为了证明所提出的TFIGR的有效性,在三个高空间分辨率RSCD数据集上进行了综合实验。实验结果表明,该方法优于其他最先进的变化检测方法。

网络框架

在这里插入图片描述

网络分为三个阶段,分别是特征提取(Conv-1到Res-5),时间融合(TFIM),特征变化推理(GRM)。
从创新上来说,网络分为三个模块。分别是时间特征提取模块(TFIM),引导精炼模块(GRM),变化信息提取模块(CIEM)。其中GRM包含了CIEM。

网络流程

首先双时相图像T1,T2送入到基于resnet18的暹罗网络中进行特征提取,得到4对不同阶段的特征图。之后分别送入TFIM阶段用来关注差异特征。之后送入GRM用来探索补充信息。之后得到不同层次的特征用于特征聚合,从而在浅层到深层的融合过程中生成变化图。

时间特征提取模块(TFIM)

在这里插入图片描述

首先将得到的t1和t2时刻的特征图进行做差,得到差异图。之后将差异图送入3x3的卷积,之后分别与T1和T2的特征进行相乘,之后再与T1和T2进行相加,两个分支分别再经过一个3x3卷积,之后进行拼接进行通道变化,之后在与差异特征进行相加。之后经过1x1的卷积来减少通道维度,最后得到输出特征。

引导精炼模块(GRM)

在这里插入图片描述

GRM模块主要包含了CIEM和四个网络分支,采用多级输入和输出的方法对特征进行提取。

变化信息提取模块(CIEM)

在这里插入图片描述

CIEM模块首先对输入的不同层次的特征图进行上采样,大小相同之后进行特征拼接,之后将拼接后的特征送入通道注意力模块,之后与自身相乘。然后通过3x3的卷积进行特征学习,之后将特征划分为相应层次,并使用自适应平均池化来进行特征复原,之后将不同层次特征与原始特征进行相加,最后将不容层次特征进行融合得到最终的变化图。

损失函数

论文采用BCE和DICE联合损失作为损失函数。
对于变化检测任务,在大多数情况下,变化区域的比率远远小于不变区域的比率,从而导致类不平衡问题。为了缓解这个问题并引导网络从复杂场景中学习,我们采用了一种混合损失,包括二进制交叉熵损失Lbce和骰子损失Ldice。

消融实验

在这里插入图片描述

从消融实验中可以看出,两层GRM可以达到最好的效果,过多的GRM模块可能会导致过拟合。

对比实验

在这里插入图片描述

可以看出该方法再sysu数据集上达到了较好的效果,其iou为72.40。sysu应该是目前最高的结果。

论文代码 pytorch 网络结构

import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import resnet18


class TemporalFeatureInteractionModule(nn.Module):
    def __init__(self, in_d, out_d):
        super(TemporalFeatureInteractionModule, self).__init__()
        self.in_d = in_d
        self.out_d = out_d
        self.conv_sub = nn.Sequential(
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.conv_diff_enh1 = nn.Sequential(
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.conv_diff_enh2 = nn.Sequential(
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.conv_cat = nn.Sequential(
            nn.Conv2d(self.in_d * 2, self.in_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.conv_dr = nn.Sequential(
            nn.Conv2d(self.in_d, self.out_d, kernel_size=1, bias=True),
            nn.BatchNorm2d(self.out_d),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        # difference enhance
        x_sub = self.conv_sub(torch.abs(x1 - x2))
        x1 = self.conv_diff_enh1(x1.mul(x_sub) + x1)
        x2 = self.conv_diff_enh2(x2.mul(x_sub) + x2)
        # fusion
        x_f = torch.cat([x1, x2], dim=1)
        x_f = self.conv_cat(x_f)
        x = x_sub + x_f
        x = self.conv_dr(x)
        return x

if __name__ == "__main__":
    x = torch.randn(1, 64, 16, 16)
    y = torch.randn(1, 64, 16, 16)
    net1 = TemporalFeatureInteractionModule(64)
    x1 = net1(x,y)
    #print(x1)
    print(x1.shape)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class ChangeInformationExtractionModule(nn.Module):
    def __init__(self, in_d, out_d):
        super(ChangeInformationExtractionModule, self).__init__()
        self.in_d = in_d
        self.out_d = out_d
        self.ca = ChannelAttention(self.in_d * 4, ratio=16)
        self.conv_dr = nn.Sequential(
            nn.Conv2d(self.in_d * 4, self.in_d, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.pools_sizes = [2, 4, 8]
        self.conv_pool1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=self.pools_sizes[0], stride=self.pools_sizes[0]),
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.conv_pool2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=self.pools_sizes[1], stride=self.pools_sizes[1]),
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.conv_pool3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=self.pools_sizes[2], stride=self.pools_sizes[2]),
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1, bias=False)
        )

    def forward(self, d5, d4, d3, d2):
        # upsampling
        d5 = F.interpolate(d5, d2.size()[2:], mode='bilinear', align_corners=True)
        d4 = F.interpolate(d4, d2.size()[2:], mode='bilinear', align_corners=True)
        d3 = F.interpolate(d3, d2.size()[2:], mode='bilinear', align_corners=True)
        # fusion
        x = torch.cat([d5, d4, d3, d2], dim=1)
        x_ca = self.ca(x)
        x = x * x_ca
        x = self.conv_dr(x)

        # feature = x[0:1, 0:64, 0:64, 0:64]
        # vis.visulize_features(feature)

        # pooling
        d2 = x
        d3 = self.conv_pool1(x)
        d4 = self.conv_pool2(x)
        d5 = self.conv_pool3(x)

        return d5, d4, d3, d2


class GuidedRefinementModule(nn.Module):
    def __init__(self, in_d, out_d):
        super(GuidedRefinementModule, self).__init__()
        self.in_d = in_d
        self.out_d = out_d
        self.conv_d5 = nn.Sequential(
            nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.out_d),
            nn.ReLU(inplace=True)
        )
        self.conv_d4 = nn.Sequential(
            nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.out_d),
            nn.ReLU(inplace=True)
        )
        self.conv_d3 = nn.Sequential(
            nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.out_d),
            nn.ReLU(inplace=True)
        )
        self.conv_d2 = nn.Sequential(
            nn.Conv2d(self.in_d, self.out_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.out_d),
            nn.ReLU(inplace=True)
        )

    def forward(self, d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p):
        # feature refinement
        d5 = self.conv_d5(d5_p + d5)
        d4 = self.conv_d4(d4_p + d4)
        d3 = self.conv_d3(d3_p + d3)
        d2 = self.conv_d2(d2_p + d2)

        return d5, d4, d3, d2


class Decoder(nn.Module):
    def __init__(self, in_d, out_d):
        super(Decoder, self).__init__()
        self.in_d = in_d
        self.out_d = out_d
        self.conv_sum1 = nn.Sequential(
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.conv_sum2 = nn.Sequential(
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.conv_sum3 = nn.Sequential(
            nn.Conv2d(self.in_d, self.in_d, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.in_d),
            nn.ReLU(inplace=True)
        )
        self.cls = nn.Conv2d(self.in_d, self.out_d, kernel_size=1, bias=False)

    def forward(self, d5, d4, d3, d2):

        d5 = F.interpolate(d5, d4.size()[2:], mode='bilinear', align_corners=True)
        d4 = self.conv_sum1(d4 + d5)
        d4 = F.interpolate(d4, d3.size()[2:], mode='bilinear', align_corners=True)
        d3 = self.conv_sum1(d3 + d4)
        d3 = F.interpolate(d3, d2.size()[2:], mode='bilinear', align_corners=True)
        d2 = self.conv_sum1(d2 + d3)

        mask = self.cls(d2)

        return mask


class BaseNet(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(BaseNet, self).__init__()
        self.backbone = resnet18(pretrained=True)
        self.mid_d = 64
        self.TFIM5 = TemporalFeatureInteractionModule(512, self.mid_d)
        self.TFIM4 = TemporalFeatureInteractionModule(256, self.mid_d)
        self.TFIM3 = TemporalFeatureInteractionModule(128, self.mid_d)
        self.TFIM2 = TemporalFeatureInteractionModule(64, self.mid_d)

        self.CIEM1 = ChangeInformationExtractionModule(self.mid_d, output_nc)
        self.GRM1 = GuidedRefinementModule(self.mid_d, self.mid_d)

        self.CIEM2 = ChangeInformationExtractionModule(self.mid_d, output_nc)
        self.GRM2 = GuidedRefinementModule(self.mid_d, self.mid_d)

        self.decoder = Decoder(self.mid_d, output_nc)

    def forward(self, x1, x2):
        # forward backbone resnet
        x1_1, x1_2, x1_3, x1_4, x1_5 = self.backbone.base_forward(x1)
        x2_1, x2_2, x2_3, x2_4, x2_5 = self.backbone.base_forward(x2)
        # feature difference
        d5 = self.TFIM5(x1_5, x2_5)  # 1/32
        d4 = self.TFIM4(x1_4, x2_4)  # 1/16
        d3 = self.TFIM3(x1_3, x2_3)  # 1/8
        d2 = self.TFIM2(x1_2, x2_2)  # 1/4

        # change information guided refinement 1
        d5_p, d4_p, d3_p, d2_p = self.CIEM1(d5, d4, d3, d2)
        d5, d4, d3, d2 = self.GRM1(d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p)

        # change information guided refinement 2
        d5_p, d4_p, d3_p, d2_p = self.CIEM2(d5, d4, d3, d2)
        d5, d4, d3, d2 = self.GRM2(d5, d4, d3, d2, d5_p, d4_p, d3_p, d2_p)

        # decoder
        mask = self.decoder(d5, d4, d3, d2)
        mask = F.interpolate(mask, x1.size()[2:], mode='bilinear', align_corners=True)
        mask = torch.sigmoid(mask)

        return mask

论文地址

https://ieeexplore.ieee.org/document/9863802

### TGRS 遥感图像语义分割方法技术 #### 数据集与挑战 遥感图像的语义分割面临诸多独特挑战,包括复杂的背景、多样的物体尺度以及高分辨率带来的计算负担。为了应对这些挑战,研究人员提出了多种针对遥感图像特点的方法策略[^1]。 #### 自监督学习方法 自监督学习成为解决遥感图像语义分割的一个重要途径。通过利用未标记的数据,可以有效减少对标记数据的需求。例如,SeCo[58] 利用了季节变化来增强正样本间的一致性,这种方法特别适用于具有显著时间变化特性的空中场景。此外,还有研究将时间信息地理位置融入 MoCo-V2 框架中,进一步提升了模型的表现力[^3]。 #### 多模态融合技术 除了单源遥感影像外,结合来自不同传感器或多时相的信息能够提高语义分割的效果。这不仅限于光学图像与其他类型的电磁波谱(如 SAR),还包括气象条件等外部因素的影响分析。这种跨域特征提取有助于捕捉更丰富的上下文信息,从而改善最终的结果质量[^2]。 #### 特征表达优化 对于特定任务而言,设计有效的网络结构至关重要。一些工作专注于改进基础卷积神经网络的设计,比如引入注意力机制以突出关键区域;或是基于Transformer架构构建更适合处理长距离依赖关系的新颖模块。这类创新使得模型能够在保持较高精度的同时降低过拟合风险。 ```python import torch.nn as nn class AttentionModule(nn.Module): def __init__(self, channels): super(AttentionModule, self).__init__() self.conv = nn.Conv2d(channels, channels//4, kernel_size=1) self.softmax = nn.Softmax(dim=-1) def forward(self, x): b, c, h, w = x.size() proj_query = self.conv(x).view(b,-1,w*h).permute(0,2,1) proj_key = self.conv(x).view(b,-1,w*h) energy = torch.bmm(proj_query,proj_key) attention = self.softmax(energy) out = torch.bmm(attention.permute(0,2,1), x.view(b,c,h*w)).view(b,c,h,w) return out + x ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

只能凯瑞一点

你的鼓励是我最大的动力,多交流

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值