ADMM-ESINet 一种用于脑电扩展源成像 成像(ESI)方法 是一种基于深度展开网络(deep unfolding network)的 EEG 源

ADMM-ESINet 核心解读与复现指南

一、核心原理与贡献

ADMM-ESINet 是一种基于深度展开网络(deep unfolding network)的 EEG 源成像(ESI)方法,旨在解决传统模型驱动方法实时性差、纯深度学习方法泛化能力弱的问题。其核心思路是将 交替方向乘子法(ADMM) 的迭代求解过程“展开”为神经网络层,融合模型先验知识与数据驱动学习的优势。

  1. 核心问题:EEG 源成像是典型的“病态逆问题”(电极数量远少于潜在皮质源),需通过先验约束缩小解空间。
  2. 方法创新
    • 采用结构化稀疏约束( L 21 L_{21} L21-范数,同时约束源域和变异域),提升扩展源(extended sources)的重建精度。
    • 将 ADMM 迭代步骤“展开”为级联网络结构,保留迭代过程的可解释性,同时支持端到端训练。
    • 从数据中学习正则化参数和空间变换算子,平衡泛化能力与实时性。
  3. 优势:相比传统模型驱动方法(如 wMNE、LORETA)速度更快(毫秒级),相比纯深度学习方法(如 DeepSIF)泛化能力更强,能准确重建源的位置、范围和时间动态。
二、网络结构与关键模块

ADMM-ESINet 由多个重复的“ADMM 块”组成,每个块对应 ADMM 迭代的一步,包含 3 个核心层:

模块功能数学基础
重建层( S ( n ) S^{(n)} S(n)求解源信号的估计值 S ( n ) = ( L T L + ρ ( n ) I ) − 1 [ L T X + ρ ( n ) ( Z ( n − 1 ) − M ( n − 1 ) ) ] S^{(n)} = (L^TL + \rho^{(n)}I)^{-1}[L^TX + \rho^{(n)}(Z^{(n-1)} - M^{(n-1)})] S(n)=(LTL+ρ(n)I)1[LTX+ρ(n)(Z(n1)M(n1))]
辅助变量层( Z ( n ) Z^{(n)} Z(n)通过梯度下降更新辅助变量,引入结构化稀疏约束基于 L 21 L_{21} L21-范数的子梯度更新,结合卷积层提取空间特征
乘子更新层( M ( n ) M^{(n)} M(n)更新拉格朗日乘子,保证约束条件满足 M ( n ) = η 1 ( n ) M ( n − 1 ) + η 2 ( n ) S ( n ) − η 3 ( n ) Z ( n ) M^{(n)} = \eta_1^{(n)}M^{(n-1)} + \eta_2^{(n)}S^{(n)} - \eta_3^{(n)}Z^{(n)} M(n)=η1(n)M(n1)+η2(n)S(n)η3(n)Z(n)

网络输入为 EEG 信号 X X X 和导联场矩阵 L L L,输出为重建的皮质源信号 S S S,损失函数为重建源与真实源的均方误差(MSE)。

三、复现指南与核心代码

ADMM-ESINet 的源代码已开源,可直接基于官方仓库复现:
开源地址https://github.com/hangj-cache/ADMM-ESINet

1. 复现步骤概览
  1. 环境准备

    • 依赖:Python 3.x、PyTorch、MATLAB(用于生成合成数据)、Brainstorm(用于头模型和导联场矩阵计算)。
    • 安装命令:pip install -r requirements.txt(仓库中提供依赖列表)。
  2. 数据生成

    • 用 MATLAB 脚本(仓库 MATLAB/Data Generate 文件夹)生成合成源信号和对应的 EEG 数据。
    • 关键参数:源大小(5-32 cm²)、信噪比(-5 dB 至 10 dB)、导联场矩阵(通过 OpenMEEG 计算)。
  3. 模型训练

    • 核心文件:model.py(网络定义)、train.py(训练循环)。
    • 训练配置:Adam 优化器(初始学习率 0.003,每 25 轮衰减一半)、最大 epoch 200、批大小根据 GPU 调整。
  4. 测试与评估

    • 用测试集评估指标:AUC(敏感性)、DLE(定位误差)、SD(空间弥散)、RMSE(时间序列误差)。
2. 核心代码片段解析

以下为网络核心模块的简化实现(基于 PyTorch),完整代码请参考仓库:

import torch
import torch.nn as nn
import torch.optim as optim


class ADMMBlock(nn.Module):
    """ADMM 迭代块:包含 S、Z、M 三个子层"""
    def __init__(self, n_sources, rho_init=600):
        super(ADMMBlock, self).__init__()
        self.n_sources = n_sources
        # 可学习的惩罚参数 rho
        self.rho = nn.Parameter(torch.tensor(rho_init, dtype=torch.float32))
        # Z 层的卷积模块(空间变换算子 V)
        self.conv_V = nn.Conv1d(in_channels=n_sources, out_channels=n_sources, 
                               kernel_size=16, stride=1, padding=7)  # 保持维度
        # 非线性变换(对应 L21 范数的子梯度)
        self.h = lambda y, lam: y * torch.clamp(1 - lam / torch.norm(y, dim=1, keepdim=True), min=0)

    def forward(self, X, L, S_prev, Z_prev, M_prev):
        """
        X: 输入 EEG 信号 (batch, n_channels, n_time)
        L: 导联场矩阵 (n_channels, n_sources)
        S_prev: 上一轮的源估计 (batch, n_sources, n_time)
        Z_prev: 上一轮的辅助变量 (batch, n_sources, n_time)
        M_prev: 上一轮的乘子 (batch, n_sources, n_time)
        """
        # 1. 重建层 S^(n)
        L_T = L.t()  # 转置
        term1 = torch.matmul(L_T, X)  # L^T X
        term2 = self.rho * (Z_prev - M_prev)  # rho (Z_prev - M_prev)
        inv_matrix = torch.inverse(torch.matmul(L_T, L) + self.rho * torch.eye(self.n_sources, device=L.device))
        S = torch.matmul(inv_matrix, term1 + term2)  # 式(11)

        # 2. 辅助变量层 Z^(n)(简化版:单步梯度下降)
        Z0 = S + M_prev  # 初始值 Z^(n,0)
        VZ_prev = self.conv_V(Z_prev)  # V Z_prev(空间变换)
        grad = self.rho * (Z_prev - Z0) + self.h(VZ_prev, lam=0.01)  # 梯度
        Z = Z_prev - 0.01 * grad  # 梯度下降更新(步长 0.01)

        # 3. 乘子更新层 M^(n)
        M = 1.0 * M_prev + 1.0 * (S - Z)  # 式(13),可学习参数 eta1, eta2, eta3 简化为 1.0

        return S, Z, M


class ADMM_ESINet(nn.Module):
    """ADMM-ESINet 整体网络"""
    def __init__(self, n_blocks=6, n_sources=1024, n_channels=64):
        super(ADMM_ESINet, self).__init__()
        self.blocks = nn.ModuleList([ADMMBlock(n_sources) for _ in range(n_blocks)])
        self.L = nn.Parameter(torch.randn(n_channels, n_sources))  # 导联场矩阵(可学习或固定)

    def forward(self, X):
        # 初始化 S, Z, M
        batch_size, n_channels, n_time = X.shape
        S = torch.zeros(batch_size, self.L.shape[1], n_time, device=X.device)
        Z = torch.zeros_like(S)
        M = torch.zeros_like(S)
        
        # 多块迭代
        for block in self.blocks:
            S, Z, M = block(X, self.L, S, Z, M)
        return S


# 训练示例
if __name__ == "__main__":
    # 配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ADMM_ESINet(n_blocks=6).to(device)
    criterion = nn.MSELoss()  # 损失函数:均方误差
    optimizer = optim.Adam(model.parameters(), lr=0.003)

    # 模拟数据(batch=8, 64通道, 100时间点)
    X = torch.randn(8, 64, 100, device=device)  # 输入 EEG
    S_gt = torch.randn(8, 1024, 100, device=device)  # 真实源信号

    # 训练循环
    for epoch in range(200):
        optimizer.zero_grad()
        S_pred = model(X)
        loss = criterion(S_pred, S_gt)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 25 == 0:
            # 学习率衰减
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
            print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
三、复现注意事项
  1. 导联场矩阵(L):需根据实际电极配置计算(仓库用 OpenMEEG 生成),代码中可固定或作为可学习参数。
  2. 数据维度:输入 EEG 需投影到 temporal basis functions(TBFs)子空间(通过 SVD 降维,见文档 Eq.2),仓库数据生成脚本已包含此步骤。
  3. 计算效率:6 块网络在 RTX 3090 上约 67 ms/样本,减少块数(如 4 块)可提速 36%,精度损失极小。
  4. 真实数据适配:对新 EEG 系统,建议用对应导联场矩阵重新训练,以适配硬件特性(通道布局、噪声水平等)。
四、参考资源
  • 源代码仓库:https://github.com/hangj-cache/ADMM-ESINet(含数据生成、训练、测试完整流程)。
  • 关键依赖:PyTorch 1.8+、MATLAB 2020+(数据生成)、Brainstorm(头模型与导联场计算)。
  • 评估指标:仓库 utils/metrics.py 实现了 AUC、DLE、SD、RMSE 的计算,可直接调用。

通过以上步骤,可复现 ADMM-ESINet 并验证其在 EEG 源成像中的性能。若需适配特定场景(如癫痫病灶定位),可参考仓库中真实数据处理示例。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值