Abstract
手术场景分割是机器人辅助手术中的一项关键任务。然而,手术场景的复杂性,主要包括局部特征相似性(例如,在不同解剖组织之间)、术中复杂的伪影和难以区分的边界,对准确分割构成了重大挑战。为了解决这些问题,我们提出了长条核注意力网络(LSKANet),包括两个设计良好的模块,分别是双块大核注意力模块(DLKA)和多尺度亲和特征融合模块(MAFF),可以实现手术图像的精确分割。具体来说,通过在两个块中引入具有不同拓扑结构(级联和并行)的条形卷积和大内核设计,DLKA可以充分利用区域和条状手术特征,并提取视觉和结构信息,以减少由局部特征相似性引起的错误分割。在MAFF中,从多尺度特征图计算的亲和矩阵作为特征融合权重应用,这有助于通过抑制不相关区域的激活来解决伪影的干扰。此外,该文提出边界引导头(BGH)的混合损失,以帮助网络有效地分割难以区分的边界。我们在具有不同手术场景的三个数据集上评估了所提出的LSKANet。实验结果表明,我们的方法在所有三个数据集上都取得了最新的结果,mIoU分别提高了2.6%、1.4%和3.4%。此外,我们的方法与不同的backbone兼容,可以显著提高它们的分割精度。
I. Overview
II. Dual-block Large Kernel Attention
DLKA由两个子Block构成,分别为Anatomy Block和Instrument Block。两种block的共有特点是都用条形卷积(Strip Convolution)来实现高效大卷积核设计,区别在于Anatomy Block的条形卷积采用串行设计(先1×W再H×1),而Instrument Block的条形卷积采用并行设计(同时1×W,H×1再相加)。
Q&A
Q1: 大卷积核优缺点。
A1: 优点,扩大感受野,捕捉长距离依赖;缺点,容易引入不相关区域中的噪声。
Q2: 为什么Anatomy Block是串行而Instrument Block是并行。
A2: 串行建模的是矩形区域,并行建模的仍是条状区域,与内镜组织和手术器械的客观形状相符。
Abla Study
将DLKA模块替换为标准卷积后, 完整模型参数由72.26M降低至72.25M, mIoU由69.6降低至68.8。注意原文并未说明这里所述的"标准卷积"的尺寸。
移除整个DLKA模块,mIoU由69.6降低至68.0。
移除DLKA中的Anatomy Block,mIoU由69.6降低至69.5。这里mIoU只降了0.1是因为Anatomy类性能降低的同时Instrument类的性能有所上升。
移除DLKA中的Instrument Block,mIoU由69.6降低至69.0。
Code-AnaBlock
import torch
import torch.nn as nn
from thop import profile
class AnaBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
self.conv3_1 = nn.Conv2d(dim, dim, (1, 31), padding=(0, 15), groups=dim)
self.conv3_2 = nn.Conv2d(dim, dim, (31, 1), padding=(15, 0), groups=dim)
self.conv3 = nn.Conv2d(dim, dim, 1)
def forward(self, x):
u = x.clone()
attn = self.conv0(x)
attn_1 = self.conv1_1(attn)
attn_1 = self.conv1_2(attn_1)
attn_2 = self.conv2_1(attn)
attn_2 = self.conv2_2(attn_2)
attn_3 = self.conv3_1(attn)
attn_3 = self.conv3_2(attn_3)
attn = attn + attn_1 + attn_2 + attn_3
attn = self.conv3(attn)
return attn * u
if __name__ == '__main__':
net = AnaBlock(dim=512)
x = torch.randn([1, 512, 22, 22])
out = net(x)
flops, params = profile(net, inputs=(x, ))
print(out.shape) # 1, 512, 22, 22
print('FLOPs = ' + str