A Multilevel Multimodal Fusion Transformer for Remote Sensing Semantic Segmentation
Xianping Ma, Xiaokang Zhang, Member, IEEE, Man-On Pun, Senior Member, IEEE, and Ming Liu
Paper link
Code link
1. Introduction
Only a few works have been devoted to the multimodal fusion task, i.e., fusing multimodal information according to the characteristics of each modality. Compared with natural images, high-resolution remote sensing images exhibit more severe spectral heterogeneity and more complex spatial structures. In addition, ground objects in remote sensing data vary far more in size and shape, which makes localizing and identifying targets difficult. As a result, CNN- and Transformer-based models derived from the CV field are still limited in effectively learning discriminative, comprehensive features.
Three strategies for handling multimodal data (a toy sketch contrasting them is given after this list):
- Early fusion: requires proper alignment of the multimodal data and may lack robustness to task-irrelevant information.
- Late fusion: has limited ability to exploit the cross-correlations between multimodal data.
- Middle fusion: can capture cross-modal dependencies in the feature representations, making it more effective for representation learning.

Existing works usually perform single-level feature fusion based on summation or concatenation, ignoring long-range cross-modal dependencies across different feature levels.
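To make the three strategies concrete, here is a purely illustrative sketch (not from the paper); all the callables (`encoder`, `head`, the per-stage lists, `fuse_blocks`) are hypothetical placeholders standing in for real networks.

```python
import torch

def early_fusion(vis, dsm, encoder, head):
    # Early fusion: stack the modalities at the input and run a single network;
    # this needs pixel-aligned modalities and is exposed to task-irrelevant content.
    return head(encoder(torch.cat([vis, dsm], dim=1)))

def late_fusion(vis, dsm, encoder_vis, encoder_dsm, head):
    # Late fusion: two independent encoders, merged only at the very end,
    # so cross-modal correlations are barely exploited.
    return head(torch.cat([encoder_vis(vis), encoder_dsm(dsm)], dim=1))

def middle_fusion(vis, dsm, vis_stages, dsm_stages, fuse_blocks, head):
    # Middle fusion: exchange/merge features stage by stage inside the encoders,
    # so cross-modal dependencies are modeled at the representation level.
    f_vis, f_dsm, fused = vis, dsm, None
    for s_vis, s_dsm, fuse in zip(vis_stages, dsm_stages, fuse_blocks):
        f_vis, f_dsm = s_vis(f_vis), s_dsm(f_dsm)
        fused = fuse(f_vis, f_dsm)
    return head(fused)
```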
2. Network Architecture
In the CNN backbone, SFF modules are used for shallow feature fusion, while the Ada-MBA layers in the FViT are designed for deep feature fusion. Bimodal data (i.e., visible-light imagery and DSM data) are used to illustrate the proposed FTransUNet. The visible-light image is treated as the primary modality and the DSM data as the auxiliary modality, because the primary modality usually provides more information for land-cover classification than the auxiliary one.
Overall, the framework is again an encoder-decoder structure, into which self-attention, cross-attention, and residual networks are introduced; the rest is downsampling and upsampling of the feature maps. Two ResNet branches are used to extract the visible-light image (VIS) features and the digital surface model (DSM) features, respectively. After each downsampling step, the features are enhanced through an SFF module (fusing the feature maps); after four downsampling stages, the feature maps are flattened (this is the same treatment as in ViT, and honestly, apart from the SFF (shallow feature fusion) module, the earlier steps feel identical to how TransUNet handles things). The flattened VIS and DSM features are then fed into the FViT, which contains self-attention and cross-attention, and afterwards decoding takes place. I still think the main changes are all in the encoding stage; the skip connections and the decoding are exactly the same as in U-Net.
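The rough pseudo-forward pass below summarizes how I understand the data flow; the argument names are placeholders for the real ResNet stages, SFF modules, FViT, and decoder, not the authors' actual interfaces.

```python
def ftransunet_forward(vis, dsm, vis_stages, dsm_stages, sff_blocks, fvit, decoder):
    # Placeholder data flow: dual CNN branches with per-level SFF, then FViT, then decoding.
    skips = []
    f_vis, f_dsm = vis, dsm
    for stage_vis, stage_dsm, sff in zip(vis_stages, dsm_stages, sff_blocks):
        f_vis = stage_vis(f_vis)        # downsample the VIS branch
        f_dsm = stage_dsm(f_dsm)        # downsample the DSM branch
        f_vis = sff(f_vis, f_dsm)       # shallow feature fusion feeds the next VIS stage
        skips.append(f_vis)             # ... and the skip connection to the decoder
    tokens = fvit(f_vis, f_dsm)         # deep fusion: SA -> Ada-MBA -> SA layers
    return decoder(tokens, skips)       # cascaded, U-Net-style upsampling back to H x W
```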
2.1 CNN Fusion
In the image-processing stage, FTransUNet resembles TransUNet. The input is split into a single-channel digital surface model and a visible-light image. Each branch is downsampled through four convolution stages to extract multi-scale features, with the feature maps successively reduced to 1/2, 1/4, and 1/8 of the input resolution. After each convolution stage there is one SFF (shallow feature fusion) step; the fused features are merged into the next VIS stage and are also passed through the skip connections to be concatenated with the corresponding feature maps in the decoding stage. The SFF module is illustrated below:
The SFF module itself is fairly simple: global average pooling is applied to the VIS and DSM features, followed by two 1x1 convolutions with ReLU and Sigmoid activations. The resulting weights are applied to the VIS and DSM features, which are then fused by element-wise addition to produce the final shallow fused features.
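A minimal PyTorch sketch of the SFF idea as described above; the channel widths, the reduction ratio, and the exact gating layout are my assumptions, not the paper's code.

```python
import torch
import torch.nn as nn

class SFFSketch(nn.Module):
    """Sketch of shallow feature fusion: GAP, two 1x1 convs with ReLU/Sigmoid,
    then a weighted element-wise sum of the two modalities (assumed layout)."""
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)                  # global average pooling
        self.gate = nn.Sequential(                           # two 1x1 convs + ReLU + Sigmoid
            nn.Conv2d(2 * channels, channels // reduction, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, 2 * channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, f_vis, f_dsm):
        w = self.gate(self.pool(torch.cat([f_vis, f_dsm], dim=1)))
        w_vis, w_dsm = torch.chunk(w, 2, dim=1)              # per-modality channel weights
        return w_vis * f_vis + w_dsm * f_dsm                 # weighted element-wise addition

# e.g. SFFSketch(64)(torch.randn(1, 64, 128, 128), torch.randn(1, 64, 128, 128)).shape -> (1, 64, 128, 128)
```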
2.2 Fusion Vision Transformer
Following the downsampling process above, the feature maps from the last stage are embedded (this is the operation from ViT). The feature maps passed into the FViT are $x_l$ (VIS) and $y_l$ (DSM). The original paper puts it as follows: "$x_l$ and $y_l$ are first tokenized through two embedding layers and a reshape operation. The embedding layers change the number of input channels from $C_l$ to $C_{hid}$, and the subsequent reshape operation flattens the embedding outputs into two 2-D patch sequences, denoted $z_x^0$ and $z_y^0$, of size $C_{hid} \times L$, where $L=(H \times W)/(2^{I-1} \times 2^{I-1})$ is the sequence length. To preserve positional information, specific position embeddings are added to the vectorized patches $z_x^0$ and $z_y^0$. The tokens $z_x^0$ and $z_y^0$ are then fed into the FViT." In short, after the downsampling ends, the feature maps go through patch embedding, exactly as in ViT.
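Here is how I picture the tokenisation step for one modality; whether the embedding layer is a 1x1 convolution and how the position embedding is initialised are my guesses, not details from the paper.

```python
import torch
import torch.nn as nn

class TokenizeSketch(nn.Module):
    """Per-modality tokenisation sketch: C_l -> C_hid embedding, flatten into L tokens,
    add a learnable position embedding (assumed layout, not the authors' code)."""
    def __init__(self, c_in, c_hid, num_tokens):
        super().__init__()
        self.embed = nn.Conv2d(c_in, c_hid, kernel_size=1)            # C_l -> C_hid
        self.pos = nn.Parameter(torch.zeros(1, num_tokens, c_hid))    # position embedding

    def forward(self, feat):                  # feat: (B, C_l, H/2^{I-1}, W/2^{I-1})
        z = self.embed(feat)                  # (B, C_hid, h, w)
        z = z.flatten(2).transpose(1, 2)      # reshape into a token sequence (B, L, C_hid)
        return z + self.pos                   # z^0 with positional information

# e.g. TokenizeSketch(1024, 768, 16 * 16)(torch.randn(1, 1024, 16, 16)).shape -> (1, 256, 768)
```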
Now for the specific operations inside the FViT. The FViT contains, in order, $N_1$ SA layers, $N_2$ Ada-MBA layers, and another $N_3$ SA layers. The first SA layers perform deep feature enhancement via self-attention, the Ada-MBA layers perform deep feature fusion, and the final SA layers enhance the fused features. The figure below shows the detailed operations of SA and Ada-MBA:
Here $z_x^n$ and $z_y^n$ are the hidden features of the $n$-th layer in the VIS branch and the DSM branch, respectively, where $n \in \{1, 2, \dots, N_1+N_2+N_3\}$ (at first I was confused about why $n$ can go all the way up to $N_1+N_2+N_3$; looking at the architecture diagram, the hidden features are passed continuously through the stacked layers). An SA layer consists of two SA modules, two multilayer perceptrons (MLPs), and layer normalization (LN) operators; its role is to use multi-head self-attention to derive the global relationships within each modality. After the SA layers complete the deep feature enhancement, the FViT further applies $N_2$ Ada-MBA layers. An Ada-MBA module computes cross-attention (CA) and self-attention (SA) simultaneously to learn the relationships between the primary and auxiliary modalities (i.e., between DSM and VIS). The Ada-MBA module is shown in the figure below:
The authors split the multimodal inputs $z_x^{n-1}$ and $z_y^{n-1}$ into $h$ equal-length segments (this is the multi-head self-attention idea from the Transformer), giving $h$ heads in total. For each head, two sets of matrices $\{Q_x, K_x, V_x\}$ and $\{Q_y, K_y, V_y\}$ are computed through the linear projections $U_x^{qkv}$ and $U_y^{qkv}$, respectively. Self-attention (SA) and cross-attention (CA) are then computed in parallel, SA within each modality and CA across the two modalities; the corresponding formulas are given in the original paper.
Next, SA and CA are fused: $g_x^n=\lambda_x^{sa}\,sa_x+\lambda_x^{ca}\,ca_x$ and $g_y^n=\lambda_y^{sa}\,sa_y+\lambda_y^{ca}\,ca_y$, where $\lambda_x^{sa}$ and the other coefficients are learnable weights that balance the contributions from SA and CA.
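To make the Ada-MBA computation concrete, here is a minimal sketch of a single Ada-MBA attention step under my reading of the description above; the exact Q/K/V pairing used for cross-attention and the projection layout are assumptions on my part, not the authors' released code.

```python
import torch
import torch.nn as nn

class AdaMBASketch(nn.Module):
    """Sketch of one Ada-MBA attention step: per-modality QKV projections,
    parallel SA and CA, then a learnable weighted fusion of the two."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv_x = nn.Linear(dim, 3 * dim)                   # U_x^{qkv}
        self.qkv_y = nn.Linear(dim, 3 * dim)                   # U_y^{qkv}
        # learnable weights balancing SA and CA for each modality (lambda_x, lambda_y)
        self.lam_x = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.lam_y = nn.Parameter(torch.tensor([0.5, 0.5]))

    def _split(self, t):                       # (B, L, C) -> (B, h, L, C/h)
        B, L, C = t.shape
        return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    def _attn(self, q, k, v):                  # scaled dot-product attention
        w = torch.softmax(q @ k.transpose(-2, -1) / self.head_dim ** 0.5, dim=-1)
        return w @ v

    def forward(self, zx, zy):                 # zx, zy: (B, L, C)
        qx, kx, vx = map(self._split, self.qkv_x(zx).chunk(3, dim=-1))
        qy, ky, vy = map(self._split, self.qkv_y(zy).chunk(3, dim=-1))
        sa_x, sa_y = self._attn(qx, kx, vx), self._attn(qy, ky, vy)   # self-attention
        ca_x, ca_y = self._attn(qx, ky, vy), self._attn(qy, kx, vx)   # cross-attention (assumed pairing)
        gx = self.lam_x[0] * sa_x + self.lam_x[1] * ca_x              # g_x^n
        gy = self.lam_y[0] * sa_y + self.lam_y[1] * ca_y              # g_y^n
        gx = gx.transpose(1, 2).reshape(zx.shape)                     # merge heads back to (B, L, C)
        gy = gy.transpose(1, 2).reshape(zy.shape)
        return gx, gy
```

In the full FViT, this attention step would be wrapped with layer normalization, MLPs, and residual connections, and stacked $N_2$ times between the two groups of SA layers.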
Finally, the fused feature maps are enhanced by $N_3$ SA layers. The final output, denoted $z_N$, is the feature map extracted from the last SA layer.
That covers the entire encoding stage.
2.3 Cascaded Decoder
The decoding stage recovers the hidden fused features through multiple upsampling modules. A reconstruction step first reshapes the 2-D input sequence $z_N$ into a 3-D tensor of size $C_{dec} \times (H/2^{I-1}) \times (W/2^{I-1})$; then several cascaded decoder blocks, concatenated with the skip connections from the corresponding CNN backbone, restore the full $H \times W$ resolution.
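As a reference for what one such decoder block might look like, here is a minimal U-Net-style sketch; the kernel sizes, normalization, and exact layer count are my assumptions rather than the paper's implementation.

```python
import torch
import torch.nn as nn

class DecoderBlockSketch(nn.Module):
    """One cascaded decoder block: upsample, concatenate the skip feature, convolve."""
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip=None):
        x = self.up(x)                                   # 2x spatial upsampling
        if skip is not None:
            x = torch.cat([x, skip], dim=1)              # fuse the skip feature from the CNN backbone
        return self.conv(x)

# Reconstruction step (illustrative only): reshape the token sequence z_N back into a
# C_dec x (H / 2**(I-1)) x (W / 2**(I-1)) feature map before the first decoder block, e.g.
# z = z_N.transpose(1, 2).reshape(B, C_dec, H // 2 ** (I - 1), W // 2 ** (I - 1))
```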
3. Experimental Results
A week later, I finally ran the experiments. The server is a Tesla P100 provided by my advisor; with limited compute, training took roughly two days, and after obtaining the weight file I ran the test. Training threw quite a few errors that I forgot to record, all minor ones, plus a pile of warnings. Below is my modified test code:
import os
import numpy as np
import cv2
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import random
import time
import itertools
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from utils import *
from torch.autograd import Variable
from IPython.display import clear_output
from model.vitcross_seg_modeling_heatmap import VisionTransformer as ViT_seg
from model.vitcross_seg_modeling_heatmap import CONFIGS as CONFIGS_ViT_seg
try:
from urllib.request import URLopener
except ImportError:
from urllib import URLopener
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from pynvml import *
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(int(os.environ["CUDA_VISIBLE_DEVICES"]))
print("Device :", nvmlDeviceGetName(handle))
BATCH_SIZE = 1
config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 6
config_vit.n_skip = 3
config_vit.patches.grid = (int(256 / 16), int(256 / 16))
net = ViT_seg(config_vit, img_size=256, num_classes=6).cuda()
net.load_from(weights=np.load(config_vit.pretrained_path))
params = 0
for name, param in net.named_parameters():
params += param.nelement()
print(params)
# Load the datasets
print("training : ", train_ids)
print("testing : ", test_ids)
print("BATCH_SIZE: ", BATCH_SIZE)
print("Stride Size: ", Stride_Size)
train_set = ISPRS_dataset(train_ids, cache=CACHE)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=BATCH_SIZE)
# for data, dsm, target in train_loader:
# print(data.shape, dsm.shape, target.shape)
# break
base_lr = 0.01
params_dict = dict(net.named_parameters())
params = []
for key, value in params_dict.items():
if '_D' in key:
# Decoder weights are trained at the nominal learning rate
params += [{'params':[value],'lr': base_lr}]
else:
# Encoder weights are trained at lr / 2 (we have VGG-16 weights as initialization)
params += [{'params':[value],'lr': base_lr / 2}]
optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0005)
# We define the scheduler
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 35, 45], gamma=0.1)
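# Sliding-window inference over the full test tiles; besides the class predictions it also
# builds and saves feature/attention heatmap figures for inspection.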
def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE):
# Use the network on the test set
## Potsdam
if DATASET == 'Potsdam':
test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id))[:, :, :3], dtype='float32') for id in test_ids)
# test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id))[:, :, (3, 0, 1, 2)][:, :, :3], dtype='float32') for id in test_ids)
## Vaihingen
else:
test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids)
test_dsms = (np.asarray(io.imread(DSM_FOLDER.format(id)), dtype='float32') for id in test_ids)
test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(id)), dtype='uint8') for id in test_ids)
eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(id))) for id in test_ids)
all_preds = []
all_gts = []
# Switch the network to inference mode
index = 0
for img, dsm, gt, gt_e in tqdm(zip(test_images, test_dsms, test_labels, eroded_labels), total=len(test_ids), leave=False):
pred = np.zeros(img.shape[:2] + (N_CLASSES,))
total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size
for i, coords in enumerate(
tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total,
leave=True)):
# Build the tensor
image_patches = [np.copy(img[x:x + w, y:y + h]).transpose((2, 0, 1)) for x, y, w, h in coords]
image_patches = np.asarray(image_patches)
gt_patches = [np.copy(gt[x:x + w, y:y + h]).transpose((2, 0, 1)) for x, y, w, h in coords]
gt_patches = np.asarray(gt_patches)
min = np.min(dsm)
max = np.max(dsm)
dsm = (dsm - min) / (max - min)
dsm_patches = [np.copy(dsm[x:x + w, y:y + h]) for x, y, w, h in coords]
dsm_patches = np.asarray(dsm_patches)
with torch.no_grad():
image_patches = Variable(torch.from_numpy(image_patches).cuda())
gt_patches = Variable(torch.from_numpy(gt_patches).cuda())
dsm_patches = Variable(torch.from_numpy(dsm_patches).cuda())
# Do the inference
outs, heatmap1, heatmap2, heatmap3 = net(image_patches, dsm_patches)
outs = outs.data.cpu().numpy()
image_patches = np.asarray(255 * torch.squeeze(image_patches).cpu(), dtype='uint8').transpose((1, 2, 0))
gt_patches = np.asarray(torch.squeeze(gt_patches).cpu(), dtype='uint8').transpose((1, 2, 0))
heatmap1 = cv2.resize(heatmap1, (256, 256))
# heatmap[heatmap < 0.7] = 0
heatmap1 = np.uint8(255 * heatmap1)
heatmap1 = cv2.applyColorMap(heatmap1, cv2.COLORMAP_JET)
heatmap1 = heatmap1[:, :, (2, 1, 0)]
epsilon = 1e-7
heatmap2 = np.nan_to_num(heatmap2, nan=0.0, posinf=1.0, neginf=0.0)
heatmap2 /= (np.max(heatmap2) + epsilon)
heatmap2 = cv2.resize(heatmap2, (256, 256))
# heatmap[heatmap < 0.7] = 0
heatmap2 = np.nan_to_num(heatmap2, nan=0.0, posinf=1.0, neginf=0.0)
heatmap2 = np.clip(heatmap2, 0.0, 1.0)
heatmap2 = np.uint8(255 * heatmap2)
heatmap2 = cv2.applyColorMap(heatmap2, cv2.COLORMAP_JET)
heatmap2 = heatmap2[:, :, (2, 1, 0)]
heatmap3 = cv2.resize(heatmap3, (256, 256))
# heatmap[heatmap < 0.7] = 0
heatmap3 = np.uint8(255 * heatmap3)
heatmap3 = cv2.applyColorMap(heatmap3, cv2.COLORMAP_JET)
heatmap3 = heatmap3[:, :, (2, 1, 0)]
x_comp = 65
y_comp = 100
fig = plt.figure()
fig.add_subplot(1, 5, 1)
plt.imshow(image_patches)
# plt.title('CFNet', y=-0.1)
plt.axis('off')
plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
fig.add_subplot(1, 5, 2)
plt.imshow(heatmap1)
# heatmap_str = './CFNet_features' + str(featureid) + '.jpg'
# cv2.imwrite(heatmap_str, heatmap1)
plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
plt.axis('off')
fig.add_subplot(1, 5, 3)
plt.imshow(heatmap2)
# heatmap_str = './CFNet_features' + str(featureid+1) + '.jpg'
# cv2.imwrite(heatmap_str, heatmap2)
plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
plt.axis('off')
fig.add_subplot(1, 5, 4)
plt.imshow(heatmap3)
# heatmap_str = './CFNet_features' + str(featureid+1) + '.jpg'
# cv2.imwrite(heatmap_str, heatmap2)
plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
plt.axis('off')
fig.add_subplot(1, 5, 5)
plt.imshow(gt_patches)
plt.gca().add_patch(plt.Rectangle((x_comp - 2, y_comp - 2), 2, 2, color='red', fill=False, linewidth=1))
clear_output()
plt.axis('off')
plt.savefig('./seg_results/heatmap_f_tree' + str(index) + '.pdf', dpi=1200)
plt.close(fig)
index += 1
# plt.show()
# plt.savefig('heatmap.png', dpi=1200)
# Fill in the results array
for out, (x, y, w, h) in zip(outs, coords):
out = out.transpose((1, 2, 0))
pred[x:x + w, y:y + h] += out
del (outs)
pred = np.argmax(pred, axis=-1)
all_preds.append(pred)
all_gts.append(gt_e)
clear_output()
accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]),
np.concatenate([p.ravel() for p in all_gts]).ravel())
if all:
return accuracy, all_preds, all_gts
else:
return accuracy
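# Training loop kept from the training script; it is not invoked below (only the test is run,
# see the commented-out train(...) call further down).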
def train(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch=1):
losses = np.zeros(1000000)
mean_losses = np.zeros(100000000)
weights = weights.cuda()
criterion = nn.NLLLoss2d(weight=weights)
iter_ = 0
acc_best = 90.0
for e in range(1, epochs + 1):
if scheduler is not None:
scheduler.step()
net.train()
for batch_idx, (data, dsm, target) in enumerate(train_loader):
data, dsm, target = Variable(data.cuda()), Variable(dsm.cuda()), Variable(target.cuda())
optimizer.zero_grad()
output = net(data, dsm)
loss = CrossEntropy2d(output, target, weight=weights)
loss.backward()
optimizer.step()
losses[iter_] = loss.data
mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_])
if iter_ % 100 == 0:
clear_output()
rgb = np.asarray(255 * np.transpose(data.data.cpu().numpy()[0], (1, 2, 0)), dtype='uint8')
pred = np.argmax(output.data.cpu().numpy()[0], axis=0)
gt = target.data.cpu().numpy()[0]
print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {}'.format(
e, epochs, batch_idx, len(train_loader),
100. * batch_idx / len(train_loader), loss.data, accuracy(pred, gt)))
iter_ += 1
del (data, target, loss)
# if e % save_epoch == 0:
if iter_ % 500 == 0:
net.eval()
acc = test(net, test_ids, all=False, stride=Stride_Size)
net.train()
if acc > acc_best:
torch.save(net.state_dict(), './resultsv_se_ablation/segnet256_epoch{}_{}'.format(e, acc))
acc_best = acc
print('acc_best: ', acc_best)
# time_start=time.time()
# train(net, optimizer, 50, scheduler)
# time_end=time.time()
# print('Total Time Cost: ',time_end-time_start)
##### test ####
state_dict = torch.load('./resultsv_se_ablation/segnet256_epoch8_92.16619266956769', weights_only=True)
net.load_state_dict(state_dict)
# net.load_state_dict(torch.load(''), weights_only=True)
net.eval()
acc, all_preds, all_gts = test(net, test_ids, all=True, stride=256)
# print("Acc: ", acc)
# for p, id_ in zip(all_preds, test_ids):
# img = convert_to_color(p)
# # plt.imshow(img) and plt.show()
# io.imsave('./resultsp_cross_transunet/inference_9101_tile{}.png'.format(id_), img)
The warnings were mainly about the progress bar, heatmap2, and the like; they do not affect the results, it is just annoying to see them flood the console. Below is the console output:
Compute really is limited, so getting usable results is already something. Below are a few heatmaps; I only ran the model from this paper, with no comparison experiments and no hyperparameter tuning:
That is everything in this paper. It approaches the problem from the multimodal angle, extracting features from images of different modalities and then fusing them. The model is built on U-Net, with the main modifications in the encoder stage, where self-attention and cross-attention are used to extract and fuse features from the different modalities. That is roughly it.