A Multilevel Multimodal Fusion Transformer for Remote Sensing Semantic Segmentation(论文阅读笔记)

A Multilevel Multimodal Fusion Transformer for Remote Sensing Semantic Segmentation
多级多模态融合变压器遥感语义分割
Xianping Ma,Xiaokang Zhang , Member, IEEE, Man-On Pun , Senior Member, IEEE, and Ming Liu
论文地址
代码地址

1.INTRODUCTION

只有少数工作致力于多模态融合任务(根据不同模态的特性融合多模态信息),强调根据不同模态的特性融合多模态信息。与自然图像相比,高分辨率遥感图像具有更严重的光谱异质性和更复杂的空间结构,此外,遥感数据中的地物在尺寸和形状上表现出更大的变化,这使得定位和识别目标变得困难,结果表明,从CV领域衍生出的基于CNN和Transformer的模型在有效学习判别性综合特征方面仍存在局限性。
解决多模态的三种策略:

  1. 早期融合:需要对多模态数据进行适当的对齐,并且可能缺乏对任务无关信息的鲁棒性。
  2. 晚期融合:利用多模态数据之间的交叉相关性方面能力有限。
  3. 中期融合:能够捕捉特征表示的跨模态依赖关系,使其在表示学习的背景下更加有效。
    现有工作通常基于求和或拼接进行单层特征融合,忽略了不同特征层次上的长距离跨模态依赖关系。

2. 网络结构

请添加图片描述
在CNN主干网络中,SFF模块被用于浅层特征融合,而在FViT中的Ada-MBA层则被设计用于深层特征融合。双模态数据(即可见光图像和DSM数据)被用来详细说明所提出的FTransUNet,可见光图像被视为主要模态,而DMS数据则作为辅助模态,因为主要模态通常比辅助模态为地表分类提高更多的信息。
总体来说,这个网络框架也是一种编码器-解码器的结构,然后引入了自注意力机制、交叉注意力机制、残差网络,其他的就是对特征图的下采样和上采样。对图片的处理利用了两个ResNet分支结构,来分别提取可见光图像特征(VIS)和数字表面模型(DSM)的特征。每次下采样之后会通过SFF模块进行特征增强(特征图的拼接),经过四次下采样之后,将特征图展平(这里就和ViT的处理方式一样,其实我感觉前面的步骤除了SFF(浅特征融合)模块,其他的就和TransUNet的处理方式一模一样),然后将VIS和DSM展开后送入FViT处理,这里面包含了自注意力和交叉注意力,处理完之后就进行解码操作了,我还是觉得主要的变化全是在编码阶段,无论是跳连接还是解码操作和U-Net完全一样。

2.1 CNN融合

请添加图片描述
FTransNet在对图片处理阶段类似于TransUNet,它这里的输入分为了单通道的数字表面模型特征和可见光图片,对于每个分支来说,它通过四次卷积进行下采样,用于提取多尺度特征,特征图一次变为1/2、1/4、1/8,每次卷积操作后有一次SFF(浅特征融合),融合后的特征会被融合到下次的VIS中,并且还会通过跳连接,与解码阶段对应的特征图进行拼接融合。下面是SFF模块图示:
请添加图片描述
SFF模块就比较简单了,对VIS和DSM进行全局平均池化,然后是两个1*1的卷积,以及ReLU和Sigmoid函数,处理完之后,将来自VIS和DSM的特征经过加权并通过逐元素相加的方式融合,从而生成最终的浅层融合特征。

2.2 Fusion Vision Transformer

请添加图片描述
接着上面的下采样过程后,最后一次采样后的特征图进行了Embedding(这里是ViT里面的操作),传入到FViT中的特征图分别为xI(VIS)和yl(DSM),原文中是这么说的“ x l x_l xl y l y_l yl首先通过两个嵌入层和一个重塑操作进行标记化处理。嵌入层将输入的通道数从 C l C_l Cl改变为 C h i d C_{hid} Chid,随后重塑操作将嵌入层的输出战平为两个二维补丁序列,分别记为 z x 0 z_x^0 zx0 z y 0 z_y^0 zy0,其大小为 C h i d × L C_{hid} \times L Chid×L,其中 L = ( H × W ) / ( 2 I − 1 × 2 I − 1 ) L=(H \times W)/(2^{I-1} \times 2^{I-1}) L=(H×W)/(2I1×2I1)是序列长度。为了保留位置信息,特定的位置嵌入被添加到向量化补丁 z x 0 z_x^0 zx0 z y 0 z_y^0 zy0中。之后,标记 z x 0 z_x^0 zx0 z y 0 z_y^0 zy0被输入到FViT中”。简单来说就是下采样结束后将特征图进行PatchEmbedding,这里就是ViT中的操作。接下来是FViT中的具体操作:
在FViT中以此有有 N 1 N_1 N1个SA操作、 N 2 N_2 N2个Ada-MBA操作、 N 3 N_3 N3个SA操作,SA层是用于深层特征增强的自注意力层,Ada-MBA是用于深层特征融合,SA层用于融合特征增强。下图是SA和Ada-MBA的具体操作:请添加图片描述
这里的 z x n z_x^n zxn z y n z_y^n zyn分别为第 n n n层在VIS分支和DSM分支中的隐藏特征,其中 n ∈ { 1 , 2... , N 1 + N 2 + N 3 } n \in \{1,2...,N_1+N_2+N3\} n{1,2...,N1+N2+N3}(最开始我懵了一下,为什么 n n n的大小可以到 N 1 + N 2 + N 3 N_1+N_2+N_3 N1+N2+N3,看了结构图发现,堆叠后这些隐藏特征是连续传入的)。SA层由两个SA模块、两个多层感知机(MLP)和层归一化(LN)算子组成,SA层的作用就是利用多头自注意力机制推到没中模态的全局关系。SA完成深层特征增强后,FViT进一步使用 N 2 N_2 N2个Ada-MBA层,Ada-MBA模块同时计算交叉注意力(CA)和自注意力(SA),来学习主模态和辅助模态之间的关系(就是DSM和VIS之间的关系),Ada-MBA模块如下图所示:请添加图片描述
作者将多模态特征输入 z x n − 1 z_x^{n-1} zxn1 z y n − 1 z_y^{n-1} zyn1划分为H个等长段(这里就是transformer里面的多头自注意力机制的思想),现在一共划分为了h个头,对于每个头来说,都有两组矩阵 { Q x , K x , V x } \{ Q_x,K_x,V_x \} {Qx,Kx,Vx} { Q y , K y , V y } \{ Q_y,K_y,V_y \} {Qy,Ky,Vy}通过线性投影 U x q k v U_x^{qkv} Uxqkv U y q k v U_y^{qkv} Uyqkv分别计算,接下来同时计算自注意力信息(SA)和交叉注意力信息(CA),下面是原文中的数学公式:请添加图片描述
接下来是融合SA和CA: g x n = λ x s a s a x + λ x c a c a x g_x^n=\lambda_x^{sa}sa_x+\lambda_x^{ca}ca_x gxn=λxsasax+λxcacax g y n = λ y s a s a y + λ y c a c a y g_y^n=\lambda_y^{sa}sa_y+\lambda_y^{ca}ca_y gyn=λysasay+λycacay,其中的 λ x s a \lambda_x^{sa} λxsa等参数都是可学习的加权系数,用于平衡来自SA和CA的贡献。
最后,融合后的特征图通过 N 3 N_3 N3个SA层进行增强。最终的输出记为 z N z_N zN,这是从最后一个SA层提取的特征图。

以上便是整个编码阶段了。

2.3 级联解码器

解码阶段就是通过利用多个上采样模块来恢复隐藏的融合特征。首先使用Reconstruction将而为输入序列 z N z_N zN重塑为大小为 C d e c × ( H / 2 I − 1 ) × ( W / 2 I − 1 ) C_{dec} \times (H/2^{I-1}) \times (W/2^{I-1}) Cdec×(H/2I1)×(W/2I1)的三维张量,随后,多个级联解码器块通过连接来自对应CNN主干网络的跳跃连接,恢复 H × W H \times W H×W

3. 实验结果

时隔一周,终于把实验做了,用的服务器是特斯拉P100,导师给的服务器,没办法,算力有限,训练用了差不多两天,得到权重文件后进行训练。在训练的时候报了许多错的,我忘记记录了,都是些小错误,然后训练的时候一堆警告,下面是我修改后的测试代码:

import numpy as np
import cv2
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import random
import time
import itertools
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from utils import *
from torch.autograd import Variable
from IPython.display import clear_output
from model.vitcross_seg_modeling_heatmap import VisionTransformer as ViT_seg
from model.vitcross_seg_modeling_heatmap import CONFIGS as CONFIGS_ViT_seg
try:
    from urllib.request import URLopener
except ImportError:
    from urllib import URLopener
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from pynvml import *
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(int(os.environ["CUDA_VISIBLE_DEVICES"]))
print("Device :", nvmlDeviceGetName(handle))

BATCH_SIZE = 1
config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 6
config_vit.n_skip = 3
config_vit.patches.grid = (int(256 / 16), int(256 / 16))
net = ViT_seg(config_vit, img_size=256, num_classes=6).cuda()
net.load_from(weights=np.load(config_vit.pretrained_path))
params = 0
for name, param in net.named_parameters():
    params += param.nelement()
print(params)
# Load the datasets

print("training : ", train_ids)
print("testing : ", test_ids)
print("BATCH_SIZE: ", BATCH_SIZE)
print("Stride Size: ", Stride_Size)
train_set = ISPRS_dataset(train_ids, cache=CACHE)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=BATCH_SIZE)

# for data, dsm, target in train_loader:
#     print(data.shape, dsm.shape, target.shape)
#     break

base_lr = 0.01
params_dict = dict(net.named_parameters())
params = []
for key, value in params_dict.items():
    if '_D' in key:
        # Decoder weights are trained at the nominal learning rate
        params += [{'params':[value],'lr': base_lr}]
    else:
        # Encoder weights are trained at lr / 2 (we have VGG-16 weights as initialization)
        params += [{'params':[value],'lr': base_lr / 2}]

optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0005)
# We define the scheduler
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 35, 45], gamma=0.1)


def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE):
    # Use the network on the test set
    ## Potsdam
    if DATASET == 'Potsdam':
        test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id))[:, :, :3], dtype='float32') for id in test_ids)
        # test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id))[:, :, (3, 0, 1, 2)][:, :, :3], dtype='float32') for id in test_ids)
    ## Vaihingen
    else:
        test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids)
    test_dsms = (np.asarray(io.imread(DSM_FOLDER.format(id)), dtype='float32') for id in test_ids)
    test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(id)), dtype='uint8') for id in test_ids)
    eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(id))) for id in test_ids)
    all_preds = []
    all_gts = []

    # Switch the network to inference mode
    index = 0
    for img, dsm, gt, gt_e in tqdm(zip(test_images, test_dsms, test_labels, eroded_labels), total=len(test_ids), leave=False):
        pred = np.zeros(img.shape[:2] + (N_CLASSES,))

        total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size
        for i, coords in enumerate(
                tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total,
                    leave=True)):
            # Build the tensor
            image_patches = [np.copy(img[x:x + w, y:y + h]).transpose((2, 0, 1)) for x, y, w, h in coords]
            image_patches = np.asarray(image_patches)


            gt_patches = [np.copy(gt[x:x + w, y:y + h]).transpose((2, 0, 1)) for x, y, w, h in coords]
            gt_patches = np.asarray(gt_patches)


            min = np.min(dsm)
            max = np.max(dsm)
            dsm = (dsm - min) / (max - min)
            dsm_patches = [np.copy(dsm[x:x + w, y:y + h]) for x, y, w, h in coords]
            dsm_patches = np.asarray(dsm_patches)
            with torch.no_grad():
                image_patches = Variable(torch.from_numpy(image_patches).cuda())
                gt_patches = Variable(torch.from_numpy(gt_patches).cuda())
                dsm_patches = Variable(torch.from_numpy(dsm_patches).cuda())

            # Do the inference
            outs, heatmap1, heatmap2, heatmap3 = net(image_patches, dsm_patches)
            outs = outs.data.cpu().numpy()
            image_patches = np.asarray(255 * torch.squeeze(image_patches).cpu(), dtype='uint8').transpose((1, 2, 0))
            gt_patches = np.asarray(torch.squeeze(gt_patches).cpu(), dtype='uint8').transpose((1, 2, 0))
            heatmap1 = cv2.resize(heatmap1, (256, 256))
            # heatmap[heatmap < 0.7] = 0
            heatmap1 = np.uint8(255 * heatmap1)
            heatmap1 = cv2.applyColorMap(heatmap1, cv2.COLORMAP_JET)
            heatmap1 = heatmap1[:, :, (2, 1, 0)]

            epsilon = 1e-7
            heatmap2 = np.nan_to_num(heatmap2, nan=0.0, posinf=1.0, neginf=0.0)
            heatmap2 /= (np.max(heatmap2) + epsilon)

            heatmap2 = cv2.resize(heatmap2, (256, 256))
            # heatmap[heatmap < 0.7] = 0
            heatmap2 = np.nan_to_num(heatmap2, nan=0.0, posinf=1.0, neginf=0.0)
            heatmap2 = np.clip(heatmap2, 0.0, 1.0)
            heatmap2 = np.uint8(255 * heatmap2)
            heatmap2 = cv2.applyColorMap(heatmap2, cv2.COLORMAP_JET)
            heatmap2 = heatmap2[:, :, (2, 1, 0)]
            heatmap3 = cv2.resize(heatmap3, (256, 256))
            # heatmap[heatmap < 0.7] = 0
            heatmap3 = np.uint8(255 * heatmap3)
            heatmap3 = cv2.applyColorMap(heatmap3, cv2.COLORMAP_JET)
            heatmap3 = heatmap3[:, :, (2, 1, 0)]
            x_comp = 65
            y_comp = 100
            fig = plt.figure()
            fig.add_subplot(1, 5, 1)
            plt.imshow(image_patches)
            # plt.title('CFNet', y=-0.1)
            plt.axis('off')
            
            plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
            
            fig.add_subplot(1, 5, 2)
            plt.imshow(heatmap1)
            # heatmap_str = './CFNet_features' + str(featureid) + '.jpg'
            # cv2.imwrite(heatmap_str, heatmap1)
            plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
            plt.axis('off')
            fig.add_subplot(1, 5, 3)
            plt.imshow(heatmap2)
            # heatmap_str = './CFNet_features' + str(featureid+1) + '.jpg'
            # cv2.imwrite(heatmap_str, heatmap2)
            plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
            plt.axis('off')
            fig.add_subplot(1, 5, 4)
            plt.imshow(heatmap3)
            # heatmap_str = './CFNet_features' + str(featureid+1) + '.jpg'
            # cv2.imwrite(heatmap_str, heatmap2)
            plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
            plt.axis('off')
            
            
            fig.add_subplot(1, 5, 5)
            plt.imshow(gt_patches)
            plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
            clear_output()
            plt.axis('off')
            plt.savefig('./seg_results/heatmap_f_tree' + str(index) + '.pdf', dpi=1200)
            plt.close(fig)
            index += 1
            # plt.show()
            # plt.savefig('heatmap.png', dpi=1200)


            # Fill in the results array
            for out, (x, y, w, h) in zip(outs, coords):
                out = out.transpose((1, 2, 0))
                pred[x:x + w, y:y + h] += out
            del (outs)

        pred = np.argmax(pred, axis=-1)
        all_preds.append(pred)
        all_gts.append(gt_e)
        clear_output()
            
    accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]),
                       np.concatenate([p.ravel() for p in all_gts]).ravel())
    if all:
        return accuracy, all_preds, all_gts
    else:
        return accuracy


def train(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch=1):
    losses = np.zeros(1000000)
    mean_losses = np.zeros(100000000)
    weights = weights.cuda()

    criterion = nn.NLLLoss2d(weight=weights)
    iter_ = 0
    acc_best = 90.0

    for e in range(1, epochs + 1):
        if scheduler is not None:
            scheduler.step()
        net.train()
        for batch_idx, (data, dsm, target) in enumerate(train_loader):
            data, dsm, target = Variable(data.cuda()), Variable(dsm.cuda()), Variable(target.cuda())
            optimizer.zero_grad()
            output = net(data, dsm)
            loss = CrossEntropy2d(output, target, weight=weights)
            loss.backward()
            optimizer.step()

            losses[iter_] = loss.data
            mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_])

            if iter_ % 100 == 0:
                clear_output()
                rgb = np.asarray(255 * np.transpose(data.data.cpu().numpy()[0], (1, 2, 0)), dtype='uint8')
                pred = np.argmax(output.data.cpu().numpy()[0], axis=0)
                gt = target.data.cpu().numpy()[0]
                print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss.data, accuracy(pred, gt)))
            iter_ += 1

            del (data, target, loss)

            # if e % save_epoch == 0:
            if iter_ % 500 == 0:
                net.eval()
                acc = test(net, test_ids, all=False, stride=Stride_Size)
                net.train()
                if acc > acc_best:
                    torch.save(net.state_dict(), './resultsv_se_ablation/segnet256_epoch{}_{}'.format(e, acc))
                    acc_best = acc
    print('acc_best: ', acc_best)

# time_start=time.time()
# train(net, optimizer, 50, scheduler)
# time_end=time.time()
# print('Total Time Cost: ',time_end-time_start)

#####   test   ####
state_dict = torch.load('./resultsv_se_ablation/segnet256_epoch8_92.16619266956769', weights_only=True)
net.load_state_dict(state_dict)
# net.load_state_dict(torch.load(''), weights_only=True)
net.eval()
acc, all_preds, all_gts = test(net, test_ids, all=True, stride=256)
# print("Acc: ", acc)
# for p, id_ in zip(all_preds, test_ids):
#     img = convert_to_color(p)
#     # plt.imshow(img) and plt.show()
#     io.imsave('./resultsp_cross_transunet/inference_9101_tile{}.png'.format(id_), img)

主要就是进度条、heatmap2等的警告,不影响结果,只是在控制台一直警告看着很不爽。下面是控制台输出图:
在这里插入图片描述
真的是算力有限,有用的都已经不错了,下面是一些热力图,我就只跑了这篇论文的模型,也没有做对比、没有调参数:
请添加图片描述
请添加图片描述

上面就是这篇论文所有的东西了,这篇论文首先从多模态入手,从不同通道的图片去提取特征,然后进行一个特征的融合,基于U-Net模型,主要修改的是编码器阶段,通过自注意力机制和交叉注意力机制,对不同模态的特征图进行特征提取,大概就是这样。

### 单层多级实现或解释 单层多级调度(Single-Level Multilevel Implementation)通常指的是操作系统中的进程调度策略之一。这种机制通过多个队列来管理不同优先级的任务,从而优化资源分配和提高系统性能[^2]。 #### 基本概念 在单层多级调度中,所有的任务被划分为不同的优先级类别,并放置到对应的队列中。每个队列可以采用独立的调度算法进行处理。例如,高优先级队列可能使用先来先服务(First Come First Served, FCFS),而低优先级队列则可能应用轮转法(Round Robin)。当某个队列为空时,调度器会自动切换至下一个较低优先级的队列继续执行任务[^2]。 #### 实现细节 以下是单层多级调度的一个典型实现方式: 1. **队列划分**: 将所有待处理的任务按照其重要性和紧急程度分成若干个等级,并分别放入相应的队列。 2. **时间片设置**: 对于每一个级别的队列设定特定的时间片长度。较高优先级的队列往往拥有更短但更为频繁的时间片,以便快速响应关键操作;而对于较不重要的后台作业,则给予较长却较少触发的机会去占用CPU周期。 3. **动态调整**: 如果某项工作未能在一个固定时间段内完成计算需求,则它可能会降级进入下一层更低优序别的等待序列里重新排队等候再次获得运行权限直到最终结束为止。 下面是一个简单的伪代码表示如何基于Python模拟这样的一个多级别反馈队列(Multilevel Feedback Queue): ```python class Process: def __init__(self, pid, burst_time): self.pid = pid self.burst_time = burst_time self.remaining_time = burst_time def round_robin(queue, time_quantum): current_time = 0 while any(proc.remaining_time > 0 for proc in queue): for i in range(len(queue)): if queue[i].remaining_time > 0: if queue[i].remaining_time <= time_quantum: current_time += queue[i].remaining_time print(f"Process {queue[i].pid} completed at time {current_time}") queue[i].remaining_time = 0 else: current_time += time_quantum queue[i].remaining_time -= time_quantum print(f"Time quantum expired for process {queue[i].pid}, remaining time={queue[i].remaining_time}") # Example usage processes = [Process(1, 10), Process(2, 5), Process(3, 8)] round_robin(processes, 2) ``` 此脚本展示了基本的轮询方法论应用于一组假定进程中所遵循的原则——即轮流给各个程序分配一定量的工作时段直至它们全部消耗完毕各自所需的总耗时为止。 尽管上述例子仅演示了一个单独层次上的循环安排过程,但在实际应用场景当中我们还可以进一步扩展成包含多层次结构的形式以适应更加复杂的情况下的负载均衡考量因素等等诸多方面的要求。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值