GCN-assisted attention-guided UNet for automated retinal OCT segmentation | 论文阅读记录

TensorFlow-v2.9

TensorFlow-v2.9

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

文章来源

Expert Systems with Applications, 2024

文章背景与动机

在治疗neovascular age-related macular degeneration (nAMD)时,对视网膜的SD-OCT图像进行分割很重要。
在这里插入图片描述

但人工分割需要大量资源,因此自动分割势在必行。
为了提高自动分割的性能,解决UNet的学习过程会导致空间信息减少

前人工作的缺点:

  • 丢失空间信息
  • 上下文推理存在问题
  • 空间信息推理

提出了graph convolution network (GCN)-assisted attention-guided UNet(结合了图卷积、Transformer与注意力引导的UNet)

研究内容

主要是学习网络的结构,关键是
在这里插入图片描述

GCN-assisted feature embedding

目的:提高模型的图像分割性能,让模型有更强的空间推理(spatial inference)能力
使用了GCN来代替CNN设计嵌入层,这样可以获得远距离的上下文信息
在这里插入图片描述
在这里插入图片描述
其中:

  • X X X C C Cx H H Hx W W W,输入数据
  • X r X_r Xr C r C_r Cr x H W HW HW,降维特征映射
  • X a X_a Xa C n C_n Cn x H W HW HW,投影矩阵
  • X d X_d Xd C n C_n Cn x H H H x W W W,逆投影矩阵
  • H H H(Node Feature): C n C_n Cnx C r C_r Cr X r X_r Xr X a X_a Xa进行矩阵乘法所得
  • F F F C C C x H H H x W W W,经过重塑与卷积变回原本大小

过程

  1. X提供给两个1x1卷积层,并产生一个降维矩阵 X r X_r Xr与一个投影矩阵(projection matrix) X a X_a Xa,其中 X r X_r Xr经过了Reshape, X a X_a Xa经过了Reshape与转置。
  2. 然后 X r X_r Xr X a X_a Xa进行矩阵乘法,得到节点特征 H H H,并将它输入到GCN block中进行图卷积,输出结果为 Z Z Z
  3. X X X经过卷积层的结果重塑为逆投影矩阵(inverse projection matrix) X d X_d Xd,然后将其与 Z Z Z进行矩阵乘法,目的是将数据映射回原本的隐空间
  4. 最后对输出的特征矩阵进行重塑,让其尺寸变为原本的大小

输入与输出的大小是一样,并没有改变原本的尺寸,只是增加了信息,增强模型的空间推理能力

投影矩阵(projection matrix)

是一种用于将高维数据投影到低维空间的线性变换矩阵,通过矩阵乘法来进行降维。
在这里插入图片描述

逆矩阵乘法

是与投影矩阵相反的操作,它用于从低维空间还原到高维空间。逆投影矩阵可以将降维后的特征重新映射回其原始的高维空间,或至少是原始空间的一个近似版本。

图卷积

是一种处理图结构数据的深度学习模型,常用于非欧几里得空间的数据。

在传统的CNN中,卷积操作在规则的二维网格上进行(如图像像素矩阵),但图中的节点和边并不遵循规则的网格结构。
因此,GCN通过将卷积的思想扩展到图结构上,能够捕捉节点之间的局部关系,利用邻居节点的信息更新每个节点的特征表示。

图卷积的核心构成

  • 节点:图的基本单元,每个节点都有自己的特征。将图划分为一个个节点。
  • :表示节点之间的连接或关系
  • 卷积操作:通过邻居节点的特征聚合更新每个节点的特征表示,类似于CNN中卷积核从邻域聚合信息

在这里插入图片描述

在这篇论文中被用于增强特征的空间推理能力,视网膜OCT图像中的像素点之间并不是规则的网格关系,因此将这些像素视为图的节点,通过GCN来捕捉邻域之间的关系,可以更好地建模视网膜结构的长距离依赖性

空间推理(spatial inference)

是指利用空间信息(例如对象的相对位置、形状、距离等)来进行推理和决策的过程,而且空间推理通过需要结合上下文信息
在图像分割邻域,空间推理用于理解图像中像素或区域的相对位置及其所属的对象类别

在这篇论文中,空间推理可以帮助获得更多的远距离上下文信息

Transformer-based reasoning module

这里就是使用了多头注意力机制模块位置编码
多头注意力机制
位置编码
目的:增加注意力信息,让网络可以获得更加丰富的上下文信息,让网络可以捕获远距离像素之间的上下文关系

在这里插入图片描述

long-range context (长距离上下文)

指的是在图像或其他数据中,利用较远位置的元素或信息当前元素进行推断或决策。相对于仅依赖局部信息,长距离上下文强调远处区域的信息在当前区域的推理和理解过程中所起的作用。

在图像中,不同区域的关系可能存在长距离依赖。例如,在医学图像中,某些病灶区域的特征可能跨越较大的空间无法仅通过局部特征进行识别。长距离上下文能帮助模型捕捉到整个结构,识别出相关性更高的远距离区域。

Multi-scale skip connection

原因:经过下采样会导致空间信息的损失,而该模型需要使用空间推理能力来分割图像,因此需要补充这些信息。

从网络结构图中,可以看出不仅解码器块与编码器中的对应层进行了拼接(绿色箭头),不同的解码器块之间也进行拼接(黑色箭头)。
在这里插入图片描述

解码器块之间拼接的运算公式:(看清楚括号)
第n个解码模块的输入由前n-1个解码模块的输出构成,其中第n-1个模块的输出需要经过双线性插值上采样后与剩下的输出进行拼接以此得到第n个模块的输入。然后将其输入到第n个解码模块进行处理。
在这里插入图片描述

  • F n F_n Fn:第n个解码器模块的输出
  • f n ( ) f_n() fn():第n个解码器模块的卷积操作或特征处理操作,也就是对输入数据的处理函数
  • v n − 1 ( F n − 1 ) v_{n-1}(F_{n-1}) vn1(Fn1):对上一层的输出特征图( F n − 1 F_{n-1} Fn1)进行双线性插值上采样操作,以此来匹配当前层所需的分辨率
  • ⊕:拼接操作,即将多个特征图沿特定维度进行拼接,拼接的目的是融合不同层次的特征,以利用它们所包含的多尺度信息。

损失函数

使用了Sorensen–Dice loss与二元交叉熵损失函数(Binary Cross-Entropy, BCE)。

在这里插入图片描述
其中:

  • y i y_i yi:第i个像素的真实值
  • t t t:当前图像的像素个数
  • p i p_i pi:第i个像素的置信度

Sorensen–Dice loss

在这里插入图片描述

Binary Cross-Entropy

二元交叉熵损失函数通常用于二分类任务,也可扩展用于多分类任务.
在这里插入图片描述

实验

评价标准

Dice Score

判断两个数据集的相似度
在这里插入图片描述

Pixel Accuracy

PA=正确分类的像素个数/总像素个数
在这里插入图片描述

Sensitivity

敏感度、真阳性率、召回率
表示在所有真实为正类(例如为患病)的样本中,能被正确预测为正类的比例

  • True Positives (TP):真实为正类且模型预测为正类的样本数。
  • False Negatives (FN):真实为正类但模型错误预测为负类的样本数。
    在这里插入图片描述

敏感度高表示模型能够较好地识别出所有真实的正类样本,不会遗漏正类,但可以会误判

Specificity

真阴性率
表示在所有真实为负类(例如为无病)的样本中,能够正确预测为负类的比例

  • True Negatives (TN):真实为负类且模型正确预测为负类的样本数。
  • False Positives (FP):真实为负类但模型错误预测为正类的样本数。
    在这里插入图片描述

特异度高 表示模型能够较好地识别出所有真实的负类样本,避免将健康的样本误诊为正类,不会误判,但可能会漏判

性能测试

在nAMD进行性能对比测试
IRF、SRF、SHRM、PED是图像中的各个类别

Dice Score

在这里插入图片描述

pixel accuracy

在这里插入图片描述

Sensitivity and Specificity

在这里插入图片描述

消融实验

实验对象:多尺度跳层连接与GCN
在这里插入图片描述

验证在其他数据集中的性能优势

在RETOUCH数据集中测试各个网络的性能,验证网络在其他数据集也有优势
在这里插入图片描述

可视化分割效果

粉色框:被漏分类的
蓝色框:分类错误的

在这里插入图片描述

缺陷

仍会出现分类错误的情况
在这里插入图片描述

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.9

TensorFlow-v2.9

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

### 带有注意力机制的图卷积网络(GCN) #### 背景介绍 图卷积网络 (Graph Convolutional Networks, GCNs) 是一种专门设计用于处理图结构数据的神经网络模型。传统的 GCN 主要通过聚合邻居节点的信息来进行特征学习,但在某些情况下,这种简单的聚合方式可能无法有效捕捉节点之间的重要关系。为了增强 GCN 的表达能力,研究者们引入了注意力机制 (Attention Mechanism),使得模型能够动态地分配不同的权重给邻接节点。 #### 注意力机制的作用 注意力机制允许模型自适应地调整不同输入的重要性程度。例如,在交通预测任务中,可以通过软注意力机制对提取的空间-时间流量特征进行加权求和,从而提高最终预测的效果[^1]。具体来说,注意力机制的核心在于计算上下文向量 (Context Vector),该向量是对编码序列按注意力概率加权后的结果[^2]。这种方法可以直观地反映各个编码字符对于目标预测的影响大小。 #### 实现带注意力机制的 GCN 以下是基于 PyTorch 的一个简单实现示例: ```python import torch import torch.nn as nn import torch.nn.functional as F class GraphAttentionLayer(nn.Module): """ Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 """ def __init__(self, in_features, out_features, dropout, alpha, concat=True): super(GraphAttentionLayer, self).__init__() self.dropout = dropout self.in_features = in_features self.out_features = out_features self.alpha = alpha self.concat = concat # 定义线性变换 W 和 a 参数矩阵 self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) nn.init.xavier_uniform_(self.W.data, gain=1.414) self.a = nn.Parameter(torch.empty(size=(2*out_features, 1))) nn.init.xavier_uniform_(self.a.data, gain=1.414) self.leakyrelu = nn.LeakyReLU(self.alpha) def forward(self, h, adj): Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features) e = self._prepare_attentional_mechanism_input(Wh) zero_vec = -9e15*torch.ones_like(e) attention = torch.where(adj > 0, e, zero_vec) attention = F.softmax(attention, dim=1) attention = F.dropout(attention, self.dropout, training=self.training) h_prime = torch.matmul(attention, Wh) if self.concat: return F.elu(h_prime) else: return h_prime def _prepare_attentional_mechanism_input(self, Wh): # 计算每一对节点之间的注意力分数 N = Wh.size()[0] Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0) Wh_repeated_alternating = Wh.repeat(N, 1) all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1) # all_combinations_matrix.shape == (N * N, 2 * out_features) return self.leakyrelu(torch.matmul(all_combinations_matrix, self.a).squeeze(1)) class GAT(nn.Module): def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads): """Dense version of GAT.""" super(GAT, self).__init__() self.dropout = dropout self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)] for i, attention in enumerate(self.attentions): self.add_module('attention_{}'.format(i), attention) self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False) def forward(self, x, adj): x = F.dropout(x, self.dropout, training=self.training) x = torch.cat([att(x, adj) for att in self.attentions], dim=1) x = F.dropout(x, self.dropout, training=self.training) x = F.elu(self.out_att(x, adj)) return F.log_softmax(x, dim=1) ``` 上述代码定义了一个基本的图注意力网络 (GAT),其中 `GraphAttentionLayer` 类实现了单个注意力头的功能,而 `GAT` 则组合多个这样的层形成完整的网络架构。 #### 渠道注意机制的应用 除了节点级别的注意力外,还可以考虑加入渠道注意机制以进一步提升性能。正如 Hu 等人在其工作中提到的那样,通道注意机制能显著改善特征表示的质量,因为它可以根据重要性重新校准各通道上的响应强度[^3]。这有助于挖掘隐藏在高维张量中的潜在模式并减少冗余信息干扰。 #### 小结 综上所述,将注意力机制融入到传统 GCN 中不仅增强了模型灵活性还提高了准确性。无论是采用单独形式还是与其他技术联合运用都展现了良好效果。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值