【开源计划】图像配准中常用损失函数的pytorch实现

本文介绍了基于无监督学习的图像非刚性配准模型的损失函数,包括相似性测度(如交叉互相关)和空间正则化(如变形场的L2范数平方)。通过参考VoxelMorph的代码,作者提供了局部和全局交叉互相关以及空间正则化的PyTorch实现。此外,还讨论了仿射变换的损失函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

按照开源计划的预告,我们首先从基于深度学习的图像配准任务中常用的损失函数的代码实现开始。从我最开始的那一篇博客,即基于深度学习的医学图像配准综述,可以看出,目前基于无监督学习的图像非刚性配准模型成为了一个比较流行的研究方向。这是因为基于监督学习的方法过分依赖传统方法或者模拟变形的方法来提供监督信息,这样既吃力又不讨好。以我的探索经验来讲,以传统配准方法产生的变形场作为监督信息,对网络进行训练,很容易造成过拟合问题。因此,我综合文献综述的结论与探索的经验(有兴趣的话,我可以总结一下我的探索经历以及经验教训),最终选择了基于无监督学习的配准模型,则本文主要介绍这种模型框架下常用的损失函数。实际上,监督学习的损失函数也比较简单,只需要使用深度学习框架(如TensorFlow、PyTorch)提供的函数计算误差即可,本文使用PyTorch进行实现。

损失函数

基于无监督学习的图像非刚性配准模型的损失函数通常是由两部分组成,一个是参考图像与变形后的浮动图像的相似性测度,一个是网络预测变形场的空间正则化。以比较有名的VoxelMorph为例,它的GitHub仓库可以点击此链接。按照他最早的发表在CVPR上的论文,损失函数如下:

其中第一项就是相似性测度,后面一项就是空间正则化项,用以约束变形场的空间平滑性。下面我们分别对其进行介绍。

相似性测度

常用于测量图像的相似性测度有三个,一个是图像灰度的均方差(mean squared voxel differece),一个是交叉互相关(cross-correlation),一个是互信息(mutual information)。前两个通常用于单模态的图像,而第一个的鲁棒性相比于交叉互相关更差一些,比较容易受图像灰度分布与对比度等的影响。互信息通常用于多模态的图像,在单模态图像的鲁棒性更好,但是到目前为止还没有发现它被用于深度学习网络训练的损失函数中,我的猜想是互信息的计算是基于统计的,不方便进行梯度计算,与反向传播原则相违背。(该看法亟待进一步的考证。好久没看Voxelmorph的开源代码,现在已经有互信息的实现了,有时间可以研究一下)

因此,主要是]使用交叉互相关作为图像配准的损失函数。交叉互相关的公式为:(摘自VoxelMorph)

他们的代码实现-TensorFlow版请查看链接中的NCC。需要指出的是,在他们的实现版本当中,他们对于三维图像使用了一个9*9*9的窗口来计算相似性,因此成为local cross-correlation,即局部交叉互相关。(没想到现在voxelmorph还提供了p

### 图像网络模型的实现方法 #### 基于卷积神经网络的图像概述 图像是指通过对两幅或多幅图像的空间变换,使它们在几何位置上对齐的过程。基于卷积神经网络(CNN)的图像方法因其强大的特征提取能力和自适应学习能力,在现代图像处理领域得到了广泛应用[^1]。 #### 主要编程语言和技术框架 目前主流的图像开源项目大多采用 Python 编程语言,并依赖深度学习框架如 TensorFlow 或 PyTorch实现 CNN 模型训练和推理过程。例如,一个典型的基于 CNN 的图像项目可能涉及以下关键技术: - 数据预处理:包括图像裁剪、缩放以及标化操作。 - 特征提取:利用 CNN 提取输入图像的关键特征。 - 变换估计:通过回归或分类的方式预测空间变换参数。 - 后处理优化:应用插值或其他技术完成最终的图像对齐。 以下是使用 PyTorch 构建简单 CNN 模型的一个例子: ```python import torch import torch.nn as nn import torch.optim as optim class RegistrationNet(nn.Module): def __init__(self): super(RegistrationNet, self).__init__() self.conv_layers = nn.Sequential( nn.Conv2d(2, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2), nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ) self.fc_layers = nn.Sequential( nn.Linear(64 * (image_height//4) * (image_width//4), 128), nn.ReLU(), nn.Linear(128, 6) # 输出仿射变换矩阵的六个参数 ) def forward(self, x): features = self.conv_layers(x) flattened = features.view(features.size(0), -1) theta = self.fc_layers(flattened) theta = theta.view(-1, 2, 3) # 调整形状为仿射变换矩阵形式 grid = F.affine_grid(theta, size=x.shape) transformed_x = F.grid_sample(x[:, :1], grid) # 对参考图像进行变换 return transformed_x, theta # 初始化模型并定义损失函数与优化器 model = RegistrationNet() criterion = nn.MSELoss() # 使用均方误差作为损失函数 optimizer = optim.Adam(model.parameters(), lr=0.001) # 训练循环示例 for epoch in range(num_epochs): model.train() for batch_data in dataloader: input_images, target_transforms = batch_data optimizer.zero_grad() output_image, predicted_theta = model(input_images) loss = criterion(predicted_theta, target_transforms) loss.backward() optimizer.step() ``` 上述代码展示了一个简单的端到端 CNN 注册网络架构设计思路及其基本训练流程。 #### 多模态图像的具体策略 对于多模态图像间的显著非线性强度差异问题,可以借鉴文献提出的分阶段解决方案[^4]。具体而言: - **预阶段**:借助改进版 SIFT 算法快速粗略地调整不同模态间的大致相对位姿关系; - **精阶段**:引入块 HARRIS 角点探测机制选取稳定可靠的局部区域兴趣点;再结合各向异性结构张量特性构建鲁棒性强的描述子表示形式;最后融合 TENSOR 方向一致性和 GRADIENT MUTUAL INFORMATION 度量指标共同指导全局最优解搜寻路径直至收敛得到精确映射结果。 #### 数学模型驱动下的迭代求参方式 除了纯数据导向的学习范式外,还有部分研究工作尝试将传统刚体/弹性变形场表达式的解析公式融入到整个端到端可微分计算图当中去形成混合动力学体系结构。这类做法通常先设定好目标能量泛函E(T),其中包含保真项F(I,T)衡量原始采样点集合I经过当前候选转换操作T作用后所得新坐标系下重投影偏差程度大小以及正则化惩罚因子R(T)约束整体形变幅度不至于过大从而保持自然物理意义合理范围之内。接着运用数值最优化手段反复修正初始猜测直到满足终止条件为止[^3]。 --- ###
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值