Edge-aware U-net with gated convolution for retinal vessel segmentation

本文介绍了一种结合边缘感知流、门控卷积的U-Net模型,用于提高视网膜血管尤其是微血管边缘的精确分割。通过多任务训练策略,模型在DRIVE、STARE和CHASE_DB1数据集上展示了改进的性能,特别是在微血管边缘检测方面有显著提升。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

带门控卷积的边缘感知U-net用于视网膜血管分割

Biomedical Signal Processing and Control 【2021】

背景

视网膜血管,特别是微小血管的形态变化在特定心血管和眼科疾病的诊断和临床预后中起着重要作用。然而,视网膜血管,特别是视网膜血管分支的血管末端的精确分割往往较差。将新的边缘感知流引入U-Net编码器-解码器架构中,以指导视网膜血管分割,这使得分割对毛细血管的精细边缘更加敏感。具有门控卷积的边缘门控流只关注边缘表示,并使用从编码器路径中提取的特征来学习强调血管边缘,并导出边缘预测结果。边缘下采样流然后从边缘预测结果中提取边缘特征,并将它们反馈到解码器路径中以细化分割结果。

贡献

  1. 提出了一种基于U-Net编码器-解码器结构的具有门控卷积的边缘感知深度网络,用于视网膜血管的有效语义分割。
  2. 设计了语义分割流边缘门控流边缘下采样流来形成多任务训练。边缘门控流使用从语义分割流编码器路径提取的特征来学习突出边界,并导出血管边缘预测结果
  3. 引入了一种边缘下采样流,将从边缘门控流输出中提取的多尺度边缘特征提前融合到语义分割流的解码器路径中,以在血管边缘周围产生更清晰的预测,这显著提高了对微小血管的性能。

实验

数据集:公开的视网膜眼底图像数据集来评估我们提出的方法。DRIVE数据集[17]包含40张健康成年人和患有轻度糖尿病视网膜病变的成年人的眼底图像,包括20张训练图像和20张测试图像。每张图像的分辨率为584×565。STARE数据集包含20个彩色眼底图像,其中10个图像为正常眼底图像,其他10个图像具有不同程度的病变。每张图像的分辨率为700×605。CHASE_DB1数据库[18]包含28个眼底视网膜彩色图像。我们选择前20张图像进行训练,其余8张图像进行测试。图像大小为960×999像素。

在实验开始前,每个图像及其标签都被裁剪成四块。通过这种方式,训练集中的数据得到了扩充。然后,训练集达到DRIVE的80张图像、STARE的40张图像和CHASE_DB1的80张图片。
在这里插入图片描述

方法

在这里插入图片描述

Semantic segmentation flow

解码器路径还集成了从边缘下采样流获得的边缘特征,然后将结果与来自编码器路径的相应特征图连接起来,这实现了更准确的分割结果,特别是小血管的边缘

Edge-gated flow

在这里插入图片描述

受这两种方法Gated-SCNN[13]edge-gated CNNs[14] 的启发,我们提出的方法还包括边缘门控流。与前两种方法不同,边缘门控流的输出是逐步下采样的。提取的边缘特征被转移到U-Net解码器路径的相应级别,以提前参与分割流

确保该流程只处理图像的边界信息。它包括一个1×1卷积,然后重复应用边缘门控块和残差块。我们采用二进制交叉熵(BCE)损失函数和边缘地面实况(GT)进行监督训练来预测最终的图像边界。边缘GT是通过canny边缘滤波器检索的图像梯度

Edge-Downsampling flow

如前所述,通过边缘下采样流从边缘门控流输出中逐步提取边缘特征。通过第一次下采样提取D1边缘特征层,然后通过跳跃连接与U-Net的最后一个上采样层连接。然后通过下一次下采样从D1层中提取D2边缘特征层,并与U-Net的相应比例上采样层连接。我们对D3层和D4层执行类似的操作

训练策略

Multitask training
整个网络由三部分组成:语义分割流、边缘门控流和边缘下采样流。多任务训练的想法借鉴了FusionNet[15]Fusionnet: Edge aware deep convolutional networks for semantic segmentation of remote sensing harbor images,但不同的是,我们的网络训练是分步骤进行的,而不是并行进行的。我们将模型的训练过程分为两个步骤。首先,边缘门控流使用从语义分割流编码器路径中提取的特征图来学习强调边缘,并导出边缘预测结果。该训练步骤由Le损失函数监督。当第一步完成时,语义分割流编码器路径和边缘门控流的网络参数都被冻结。其次,边缘下采样流从边缘结果中提取边缘特征,并将其反馈到语义分割流解码器路径中,以细化分割结果。然后,语义分割流程整合所有数据特征来预测最终的分割结果。我们采用Ls损失函数来监督这一训练步骤。

损失函数

在这里插入图片描述
都是二元交叉熵

Thinking

使用语义分割流编码器的每一层特征提取边缘特征,边缘特征图和边缘图计算二元交叉熵Le,把最后输出的边缘特征经过逐层下采样拼接到语义分割流的解码器上。辅助分割

### SCS-Net 架构实现 SCS-Net 是一种专门设计用于视网膜血管分割的尺度和上下文敏感网络。该模型通过引入多尺度特征融合机制来增强对不同尺度血管结构的学习能力,从而提高分割精度。 #### 多尺度特征提取模块 为了有效捕捉不同尺度下的血管细节,SCS-Net 设计了一个多分支卷积层结构,在每个分支中应用不同大小的感受野来进行特征图的并行计算[^1]。具体来说: - **浅层分支**:负责捕获细小血管的局部纹理信息; - **深层分支**:专注于较大范围内的血管形态变化; - **中间层次分支**:兼顾两者之间的过渡区域; 这些来自多个路径的信息最终会被聚合起来形成更加丰富的表示形式。 ```python import torch.nn as nn class MultiScaleFeatureExtractor(nn.Module): def __init__(self, in_channels=3): super(MultiScaleFeatureExtractor, self).__init__() # 浅层分支 self.branch1 = nn.Sequential( nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) # 中间层次分支 self.branch2 = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) ) # 深层分支 self.branch3 = nn.Sequential( nn.MaxPool2d(kernel_size=4, stride=4), nn.Conv2d(in_channels, 64, kernel_size=7, padding=3), nn.ReLU(inplace=True), nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False) ) def forward(self, x): out_branch1 = self.branch1(x) out_branch2 = self.branch2(x) out_branch3 = self.branch3(x) return torch.cat([out_branch1, out_branch2, out_branch3], dim=1) ``` #### 上下文感知模块 除了关注于单个像素级别的分类决策外,SCS-Net 还特别强调了全局上下文的理解对于改善整体性能的重要性。为此,引入了一种基于注意力机制的方法来动态调整各部分权重分配,使得模型能够更好地聚焦于那些更具区分性的位置上。 ```python class ContextAwareModule(nn.Module): def __init__(self, channels): super(ContextAwareModule, self).__init__() self.conv = nn.Conv2d(channels, channels//8, kernel_size=1) self.softmax = nn.Softmax(dim=-1) def forward(self, feat_map): batch_size, C, height, width = feat_map.size() proj_query = self.conv(feat_map).view(batch_size,-1,width*height).permute(0,2,1) proj_key = self.conv(feat_map).view(batch_size,-1,width*height) energy = torch.bmm(proj_query,proj_key) attention = self.softmax(energy) proj_value = feat_map.view(batch_size,-1,width*height) out = torch.bmm(attention.permute(0,2,1),proj_value) out = out.view(batch_size,C,height,width) return out + feat_map ``` #### 完整架构集成 最后,将上述两个主要组件组合在一起构成完整的 SCS-Net 结构,并采用端到端的方式进行训练优化过程中的参数更新操作。 ```python class SCS_Net(nn.Module): def __init__(self): super(SCS_Net, self).__init__() self.multi_scale_extractor = MultiScaleFeatureExtractor() self.context_aware_module = ContextAwareModule(channels=192) self.final_conv = nn.Conv2d(192, 1, kernel_size=1) def forward(self, input_image): multi_scale_features = self.multi_scale_extractor(input_image) enhanced_features = self.context_aware_module(multi_scale_features) output = self.final_conv(enhanced_features) return torch.sigmoid(output) ```
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值