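Global–Local Fourier Neural Operator (GL-FNO): Principles, Rigorous Derivation, and a PINN Hybrid Strategy

Contents

Abstract

1. Background: From Operator Learning to the Fourier Neural Operator (Review of Core Concepts and Theorems)

1.1 Formalizing Operator Learning

1.2 Key Elements of the Fourier Neural Operator (FNO) Construction (Review)

2. GL-FNO Architecture and Numerical Motivation (Global–Local Branch Design)

2.1 Intuitive Motivation

2.2 Architectural Components (Formal)

3. Theoretical Derivation: Approximation-Error Decomposition and a Mathematical Account of Its Advantages

3.1 Setting and Goals

3.2 Error Decomposition (Basic Identity)

3.3 Quantitative Account of Global/Local Branch Approximation

3.4 Complexity and Rates (Quantifying FFT Cost)

4. Layer-by-Layer Mathematical Formulation of GL-FNO (From Operator Layer to Full Network)

4.1 Fourier-Domain Parameterization of an FNO Layer (Formal Derivation)

4.2 Operator Definitions of the Global and Local Branches

5. Combining with PINNs: A Formal Hybrid Framework and an Asymptotic-Consistency Argument (Key Derivation Steps)

5.1 Goal: Satisfy PDE Physics Constraints While Retaining Data-Driven Fitting Accuracy

5.2 Fourier-Domain Expression of the Residual (For Theoretical Analysis)

5.3 Asymptotic Consistency (Sketch of Proof)

6. Algorithmic Details: Training Procedure, Residual Computation, and Numerical Implementation (Pseudocode + Justification)

6.1 Constructing the Training Loss (Term by Term)

6.2 Spectral Differentiation and Handling of Nonlinear Terms

6.3 Pseudocode (One Batch of a Training Epoch)

6.4 Numerical Stability and Hyperparameter Tips (Theoretical Support)

7. Further Theoretical Notes and Extensions (Suitable as Textbook Exercises and Future Directions)

7.1 Functional Approximation and Discretization Invariance of Neural Operators

7.2 Practical Experience and Caveats for GL-FNO on MHD/Coronal Problems

8. Summary

9. Code Reproduction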


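Abstract

This chapter systematically derives the mathematical principles of the global–local Fourier neural operator (GL-FNO): its approximation properties, its frequency-domain/physical-domain processing, and the convergence and consistency of its hybrid with physics constraints (PINN). We first review the basic definitions and theoretical guarantees of neural operators and of the Fourier neural operator. We then construct, step by step, the operator layers of GL-FNO, its global and local branches, and the fusion strategy, give an error decomposition and a complexity analysis, and show why it is advantageous on tasks of the form "low-frequency global structure + high-frequency local detail". Finally, we derive in detail how to embed a PINN (physics-informed neural network) loss into GL-FNO training and argue, at the level of a proof sketch, that under several reasonable assumptions the operator approximation error and the PDE residual can be jointly kept within controllable bounds, which yields asymptotic consistency with the physical solution. The presentation aims to be self-contained; key conclusions are marked with their sources and required assumptions, and each derivation step is developed in textbook style. (GitHub; arXiv)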


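1. Background: From Operator Learning to the Fourier Neural Operator (Review of Core Concepts and Theorems)

1.1 Formalizing Operator Learning

1.2 Key Elements of the Fourier Neural Operator (FNO) Construction (Review)


2. GL-FNO Architecture and Numerical Motivation (Global–Local Branch Design)

2.1 Intuitive Motivation

In many physical fields (for example, the solar coronal magnetic field), the solution is clearly multiscale. On the one hand there is a smooth, energy-dominant, low-frequency (large-scale) global structure; on the other there are local high-frequency details (spikes, boundary layers, local nonlinear coupling). A single-scale Fourier operator captures global low frequencies efficiently, but to represent high-frequency local details or locally non-stationary structures it may need to retain a large number of spectral modes, which sharply increases computation and parameter count. The core idea of GL-FNO is to split the task into a "global, low-resolution branch" and a "local, high-resolution branch": the former reconstructs the global low-frequency pattern on a coarse grid, the latter focuses on recovering high-frequency detail at the original (or a higher) resolution, and a fusion module combines the two into the final high-resolution prediction. This design keeps the frequency-domain efficiency of FNO while letting the local branch refine high frequencies, so that overall parameter and compute efficiency can be better than that of a single high-resolution FNO. In experiments on the coronal magnetic field problem, this architecture showed significant speedups and accuracy gains (see the empirical results). (arXiv)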
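To make the "low-frequency global + high-frequency local" split concrete, the NumPy sketch below builds a 2D periodic field from one smooth large-scale mode plus a small, localized high-frequency oscillation, separates it with a sharp spectral low-pass filter, and reports how much spectral energy lies below the cutoff. The test field, the cutoff `k_cut`, and the helper name `spectral_split` are illustrative assumptions; GL-FNO itself separates scales by downsampling plus a learned global branch, not by a fixed filter.

```python
import numpy as np

def spectral_split(u, k_cut):
    """Split a periodic 2D field into low- and high-frequency parts (illustrative helper).

    Modes with max(|kx|, |ky|) <= k_cut (integer wavenumbers) form the 'global'
    part; everything else forms the 'local' part.
    """
    nx, ny = u.shape
    kx = np.fft.fftfreq(nx, d=1.0 / nx)          # integer wavenumbers
    ky = np.fft.fftfreq(ny, d=1.0 / ny)
    KX, KY = np.meshgrid(kx, ky, indexing="ij")
    mask = np.maximum(np.abs(KX), np.abs(KY)) <= k_cut
    u_hat = np.fft.fft2(u)
    u_low = np.real(np.fft.ifft2(u_hat * mask))
    u_high = u - u_low
    energy = np.abs(u_hat) ** 2
    low_fraction = energy[mask].sum() / energy.sum()
    return u_low, u_high, low_fraction

# Assumed test case on a periodic 128 x 128 grid (endpoint excluded)
n = 128
x = np.arange(n) / n
X, Y = np.meshgrid(x, x, indexing="ij")
u_global = np.cos(2 * np.pi * (X + Y))                        # large-scale mode
bump = np.exp(-((X - 0.5) ** 2 + (Y - 0.5) ** 2) / 0.005)      # localized envelope
u_local = 0.1 * np.sin(2 * np.pi * 24 * X) * bump              # high-frequency detail
u = u_global + u_local

u_low, u_high, frac = spectral_split(u, k_cut=8)
print(f"energy fraction in modes <= 8: {frac:.5f}")
print("max |u_high - u_local|:", np.abs(u_high - u_local).max())
```

Almost all of the energy sits in the single global mode, while the small local detail is recovered almost exactly by the high-pass remainder; this is the kind of scale separation the two branches are meant to exploit.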

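2.2 Architectural Components (Formal)


3. Theoretical Derivation: Approximation-Error Decomposition and a Mathematical Account of Its Advantages

3.1 Setting and Goals

3.2 Error Decomposition (Basic Identity)

3.3 Quantitative Account of Global/Local Branch Approximation

3.4 Complexity and Rates (Quantifying FFT Cost)


4. Layer-by-Layer Mathematical Formulation of GL-FNO (From Operator Layer to Full Network)

4.1 Fourier-Domain Parameterization of an FNO Layer (Formal Derivation)

4.2 Operator Definitions of the Global and Local Branches


5. Combining with PINNs: A Formal Hybrid Framework and an Asymptotic-Consistency Argument (Key Derivation Steps)

5.1 Goal: Satisfy PDE Physics Constraints While Retaining Data-Driven Fitting Accuracy

5.2 Fourier-Domain Expression of the Residual (For Theoretical Analysis)

5.3 Asymptotic Consistency (Sketch of Proof)

We give a conditional convergence statement (not a fully rigorous proof of a theorem, but with clearly stated assumptions and conclusions in textbook style):

Proposition (asymptotic consistency):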


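6. Algorithmic Details: Training Procedure, Residual Computation, and Numerical Implementation (Pseudocode + Justification)

6.1 Constructing the Training Loss (Term by Term)

6.2 Spectral Differentiation and Handling of Nonlinear Terms

6.3 Pseudocode (One Batch of a Training Epoch)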

 
for batch a in DataLoader:
    # 1. 全局分支(下采样)
    a_g = Downsample(a)
    u_g = GLOB_FNO(a_g)           # 在粗网格上获得全局预测
    U_g = Upsample(u_g)           # 插值到高分辨率

    # 2. 局部分支(高分辨率)
    u_loc = LOC_FNO(a)            # 在高分辨率上获得局部预测

    # 3. 融合
    hat_u = Fusion(U_g, u_loc)    # 可学习权重或小网络

    # 4. 物理残差(谱法)
    hat_u_spec = FFT(hat_u)
    phys_spec = L_hat(kappa)*hat_u_spec + FFT( Nonlinear( IFFT(hat_u_spec) ) )
    L_phys = mean(|IFFT(phys_spec)|^2)

    # 5. 数据损失(若有观测)
    L_data = mean(|hat_u - u_obs|^2)

    # 6. 合并与反向
    L = L_data + lambda * L_phys + beta * regularization
    L.backward(); optimizer.step()
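Step 4 relies on the fact that on a periodic grid the Laplacian is diagonal in Fourier space, with symbol -|κ|². The NumPy check below is a sketch under stated assumptions: the manufactured solution u = sin(2πx)·sin(4πy) on the unit square and the helper names `spectral_laplacian_np` / `residual_np` are illustrative, not part of the original code. It confirms that the FFT-based Laplacian is exact to round-off for a band-limited periodic field, and that the residual R = -Δu + k·u + u³ can then be assembled with the nonlinearity evaluated pointwise in physical space.

```python
import numpy as np

def spectral_laplacian_np(u, lx=1.0, ly=1.0):
    """FFT-based Laplacian of a real periodic field u[nx, ny] on [0, lx) x [0, ly)."""
    nx, ny = u.shape
    kx = 2 * np.pi * np.fft.fftfreq(nx, d=lx / nx)   # angular wavenumbers 2*pi*k/L
    ky = 2 * np.pi * np.fft.fftfreq(ny, d=ly / ny)
    KX, KY = np.meshgrid(kx, ky, indexing="ij")
    return np.real(np.fft.ifft2(-(KX**2 + KY**2) * np.fft.fft2(u)))

def residual_np(u, k=1.0):
    """Pseudo-spectral residual R = -Lap(u) + k*u + u**3 (with f = 0)."""
    return -spectral_laplacian_np(u) + k * u + u**3

n = 64
x = np.arange(n) / n                       # periodic grid: the endpoint x = 1 is excluded
X, Y = np.meshgrid(x, x, indexing="ij")
u = np.sin(2 * np.pi * X) * np.sin(4 * np.pi * Y)
lap_exact = -((2 * np.pi) ** 2 + (4 * np.pi) ** 2) * u
print("max Laplacian error:", np.abs(spectral_laplacian_np(u) - lap_exact).max())

# Choose a forcing f so that u solves -Lap(u) + k*u + u^3 = f exactly;
# the pseudo-spectral residual then matches f up to round-off.
f = -lap_exact + 1.0 * u + u**3
print("max |R - f|:", np.abs(residual_np(u) - f).max())
```

For non-periodic data (e.g. a grid built with an included endpoint), the same formula still runs but incurs Gibbs-type boundary errors, so the periodicity assumption matters.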

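6.4 Numerical Stability and Hyperparameter Tips (Theoretical Support)


7. Further Theoretical Notes and Extensions (Suitable as Textbook Exercises and Future Directions)

7.1 Functional Approximation and Discretization Invariance of Neural Operators

According to neural operator theory (Kovachki et al.), if the target operator is continuous (under mild conditions on the function spaces), the neural operator family (including FNO) can approximate it to arbitrary accuracy, uniformly over compact sets of input functions. FNO's spectral parameterization, which restricts the kernel to a representation in the Fourier domain, provides parameter sharing and discretization invariance in practice. This theory gives a deeper footing for GL-FNO's two-branch design: the global and local branches each approximate the operator at a different discretization scale, and can therefore share parameters and generalize across resolutions. (Journal of Machine Learning Research)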
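Discretization invariance can be illustrated in one dimension: a spectral layer stores weights only for the lowest `modes` Fourier coefficients, so the same weights can be applied to the same function sampled at different resolutions. In the NumPy sketch below, the random weights, `modes = 8`, the band-limited test input, and the helper name `spectral_layer_1d` are assumptions for illustration. Using `norm="forward"` makes the Fourier coefficients independent of the number of grid points, so the outputs on a 64-point and a 256-point grid agree at shared locations to round-off.

```python
import numpy as np

def spectral_layer_1d(v, weights):
    """Apply a truncated spectral multiplier to a real periodic signal v.

    Only the lowest len(weights) rFFT coefficients are multiplied by the
    (complex) weights; all higher modes are set to zero.
    norm="forward" makes the coefficients resolution-independent.
    """
    n = v.shape[0]
    modes = weights.shape[0]
    v_hat = np.fft.rfft(v, norm="forward")
    out_hat = np.zeros_like(v_hat)
    out_hat[:modes] = weights * v_hat[:modes]
    return np.fft.irfft(out_hat, n=n, norm="forward")

rng = np.random.default_rng(0)
modes = 8
weights = rng.standard_normal(modes) + 1j * rng.standard_normal(modes)

def sample(n):
    """Band-limited test input (wavenumbers 1, 3, 5) on an n-point periodic grid."""
    x = np.arange(n) / n
    return np.sin(2 * np.pi * x) + 0.5 * np.cos(6 * np.pi * x) + 0.2 * np.sin(10 * np.pi * x)

out_coarse = spectral_layer_1d(sample(64), weights)
out_fine = spectral_layer_1d(sample(256), weights)
# Every 4th fine-grid point coincides with a coarse-grid point.
print("max mismatch:", np.abs(out_fine[::4] - out_coarse).max())
```

The agreement is exact only because the input is band-limited and both grids resolve it; for inputs with energy above the coarse grid's Nyquist limit, aliasing breaks this equivalence.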

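7.2 Practical Experience and Caveats for GL-FNO on MHD/Coronal Problems

Empirical results show that GL-FNO outperforms several baseline models (FNO, U-FNO, ViT, CNN-RNN, etc.) in both accuracy and inference time on coronal magnetic field modeling, and achieves very large speedups on some test sets. The paper reports a speedup of "more than 20,000×" relative to a conventional numerical solver as an empirical figure; note that this factor includes the solver's cost for high-accuracy time stepping and boundary handling. For engineering use, we recommend consulting the authors' open-source implementation for reasonable network hyperparameters and data-preprocessing details (e.g., number of modes, number of hidden channels, fusion network structure). (arXiv)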
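Because the spectral weights dominate an FNO's parameter budget, the number of modes and the hidden width have a direct cost. The back-of-the-envelope helper below is a sketch assuming the common rfft2 layer form with two complex weight blocks (like `weights1`/`weights2` in the simplified `FNO2DLayer` at the end of this article), plus an optional 1×1 convolution bypass; the helper name `spectral_layer_params` and the example settings are illustrative only.

```python
def spectral_layer_params(width, modes_x, modes_y, bypass=True):
    """Real-valued parameter count of one 2D spectral layer (assumed rfft2 form).

    Two complex weight blocks of shape [width, width, modes_x, modes_y]
    (positive and negative x-frequencies); each complex entry counts as 2 reals.
    The optional 1x1 convolution bypass adds width*width weights + width biases.
    """
    spectral = 2 * width * width * modes_x * modes_y * 2
    conv1x1 = width * width + width if bypass else 0
    return spectral + conv1x1

for width, modes in [(64, 12), (64, 28), (32, 28)]:
    count = spectral_layer_params(width, modes, modes)
    print(f"width={width:3d} modes={modes:2d}: {count:,} parameters per layer")
```

The count scales as width² × modes², so raising the retained modes from 12 to 28 at fixed width multiplies the spectral weights by about 5.4; halving the width recovers a factor of 4.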


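8. Summary

In both theory and practice, the global–local Fourier neural operator (GL-FNO) can be viewed as a multiresolution, parallel-branch generalization of the Fourier neural operator. By assigning low-frequency global structure to a downsampled branch and high-frequency detail to a full-resolution local branch, and by alternating between the spectral and physical domains when evaluating nonlinear terms, GL-FNO achieves a favorable trade-off between computational efficiency and accuracy. Adding PINN physics constraints explicitly to the GL-FNO training objective yields accurate PDE residual estimates via spectral differentiation, which helps enforce physical consistency when data are scarce or noisy. Under reasonable assumptions (approximation capacity of the neural operator, optimizability of the loss, continuity of the residual operator), the jointly trained model attains asymptotic consistency on the training distribution. In practice, numerical stability requires attention to spectral truncation, compatibility of interpolation between grids, and tuning of the physics-loss weight.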
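One concrete reason spectral truncation and interpolation compatibility matter: the global branch works on a grid downsampled by a factor of 4 (the stride used in the Section 9 script), so it can only resolve wavenumbers up to the coarse grid's Nyquist limit, and anything above that is aliased onto a lower mode rather than simply dropped. The NumPy sketch below (test wavenumbers and the helper name `dominant_wavenumber` are illustrative assumptions) samples cos(2πkx) on a 128-point grid, keeps every 4th point, and reports which wavenumber each grid sees.

```python
import numpy as np

def dominant_wavenumber(v):
    """Integer wavenumber with the largest rFFT magnitude, ignoring the mean (DC)."""
    mag = np.abs(np.fft.rfft(v))
    mag[0] = 0.0
    return int(np.argmax(mag))

n_fine, factor = 128, 4
x = np.arange(n_fine) / n_fine
for k in (6, 12, 20, 28):
    v_fine = np.cos(2 * np.pi * k * x)
    v_coarse = v_fine[::factor]          # 32 points, Nyquist wavenumber = 16
    print(f"k={k:2d}: fine grid sees {dominant_wavenumber(v_fine):2d}, "
          f"coarse grid sees {dominant_wavenumber(v_coarse):2d}")
```

Modes 20 and 28 reappear as 12 and 4 on the coarse grid. The Gaussian blur applied before striding in the Section 9 script damps such components first, acting as an approximate anti-aliasing filter.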

9. Code Reproduction

#!/usr/bin/env python3
# glfno_pinn.py
#
# Single-file, complete and runnable PyTorch implementation of a Global-Local Fourier Neural Operator (GL-FNO)
# with a PINN-style spectral physics residual loss. This script:
#  - builds synthetic 2D training data (smooth global + high-frequency local details)
#  - defines FNO2D (spectral layers), a global (downsampled) branch, a local (full-res) branch,
#    and a learned fusion network
#  - computes physics residuals using spectral differentiation (FFT-based Laplacian + simple nonlinearity)
#  - trains the GL-FNO with combined data + physics loss
#
# Requirements: Python 3.8+, PyTorch 1.12+ (with torch.fft), numpy, tqdm
#
# Usage:
#   python glfno_pinn.py
#
# The script is intentionally self-contained and organized for clarity and engineering use.
# No additional files are required.

import math
import argparse
import time
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ---------------------------
# Utilities
# ---------------------------

def get_device():
    return "cuda" if torch.cuda.is_available() else "cpu"

def meshgrid_2d(nx, ny, device='cpu'):
    x = torch.linspace(0, 1, nx, device=device)
    y = torch.linspace(0, 1, ny, device=device)
    xv, yv = torch.meshgrid(x, y, indexing='xy')
    return xv, yv

# ---------------------------
# Synthetic dataset
# ---------------------------
class SyntheticCoronalDataset(Dataset):
    """
    Creates synthetic 2D fields with a global low-frequency component and local high-frequency details.
    Input `a` simulates coarse boundary/parameter maps; target `u` is a field derived from `a`.
    __getitem__ returns (a, u) with shapes [1, ceil(nx/4), ceil(ny/4)] and [1, nx, ny].
    """
    def __init__(self, n_samples=2000, nx=128, ny=128, seed=0):
        super().__init__()
        self.n_samples = n_samples
        self.nx = nx
        self.ny = ny
        rng = np.random.RandomState(seed)
        # Pre-generate random parameter fields controlling amplitudes and phases
        self.params = []
        for i in range(n_samples):
            p = {
                'amp_global': float(0.5 + rng.rand()*1.5),
                'amp_local': float(0.02 + rng.rand()*0.2),
                'freq_local': int(4 + rng.randint(1, 10)),
                'phase': float(rng.rand()*2*np.pi)
            }
            self.params.append(p)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        p = self.params[idx]
        nx, ny = self.nx, self.ny
        # grid
        xs = np.linspace(0, 1, nx)
        ys = np.linspace(0, 1, ny)
        xv, yv = np.meshgrid(xs, ys, indexing='ij')  # 'ij' -> arrays of shape [nx, ny]
        # global low-frequency component: sum of a few smooth Gaussians / cos modes
        g = (np.cos(2*np.pi*(1.0*xv + 0.5*yv - 0.1*p['phase'])) +
             0.5*np.cos(2*np.pi*(0.5*xv - 0.3*yv + 0.2*p['phase'])))
        g = p['amp_global'] * gaussian_blur_np(g, sigma=0.05*nx)
        # local high-frequency detail: radial oscillations modulated by random centers
        k = p['freq_local']
        r = np.sqrt((xv-0.5)**2 + (yv-0.5)**2)
        local = p['amp_local'] * np.sin(2*np.pi*k*(xv + yv) + p['phase']) * np.exp(-50*(r-0.3)**2)
        # assemble
        u = g + local
        # input a: a smoothed, 4x-downsampled copy of u (no noise is added), standing in for coarse boundary data
        a = gaussian_blur_np(u, sigma=3.0)  # low-pass first (also limits aliasing when striding)
        a = a[::4, ::4]  # downsample by 4: this is the coarse grid seen by the global branch
        # convert to torch tensors
        a_t = torch.tensor(a, dtype=torch.float32)
        u_t = torch.tensor(u, dtype=torch.float32)
        return a_t.unsqueeze(0), u_t.unsqueeze(0)  # [1, nx_g, ny_g] and [1, nx, ny]

def gaussian_blur_np(img, sigma):
    # simple gaussian blur via FFT-based multiplication (fast)
    nx, ny = img.shape
    kx = np.fft.fftfreq(nx)
    ky = np.fft.fftfreq(ny)
    KX, KY = np.meshgrid(kx, ky, indexing='ij')  # 'ij' matches img.shape == (nx, ny)
    kernel = np.exp(-0.5*( (2*np.pi*sigma*KX)**2 + (2*np.pi*sigma*KY)**2 ))
    img_f = np.fft.fft2(img)
    blurred = np.real(np.fft.ifft2(img_f * kernel))
    return blurred

# ---------------------------
# FNO2D: Fourier Neural Operator layer and stack
# ---------------------------
class SpectralConv2d(nn.Module):
    """
    2D Fourier layer: real FFT, multiply the lowest `modes` frequencies by learned complex weights
    (channel mixing per mode), inverse real FFT, and add a pointwise (1x1 conv) linear path.
    """
    def __init__(self, in_channels, out_channels, modes_x, modes_y):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes_x = modes_x  # retained x-modes, taken at both ends (positive and negative frequencies)
        self.modes_y = modes_y  # retained y-modes (rfft2 stores only non-negative y-frequencies)
        scale = 1 / (in_channels * out_channels)
        # Two complex weight blocks stored as real tensors with a trailing dim of 2 (real, imag):
        # weights1 for positive x-frequencies, weights2 for negative x-frequencies.
        self.weights1 = nn.Parameter(scale * torch.randn(in_channels, out_channels, modes_x, modes_y, 2))
        self.weights2 = nn.Parameter(scale * torch.randn(in_channels, out_channels, modes_x, modes_y, 2))
        # pointwise linear mixing (local path)
        self.w0 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    @staticmethod
    def compl_mul2d(input_ft, weights):
        # input_ft: [batch, in_channels, mx, my] complex; weights: [in_channels, out_channels, mx, my, 2] real
        # returns [batch, out_channels, mx, my] complex
        return torch.einsum("bixy,ioxy->boxy", input_ft, torch.view_as_complex(weights))

    def forward(self, x):
        """
        x shape: [batch, in_channels, nx, ny] real
        requires 2 * modes_x <= nx and modes_y <= ny // 2 + 1 (retained blocks must not overlap)
        """
        batch, _, nx, ny = x.shape
        mx, my = self.modes_x, self.modes_y
        # real FFT: [batch, in_channels, nx, ny//2 + 1] complex
        x_ft = torch.fft.rfft2(x, norm='ortho')
        out_ft = torch.zeros(batch, self.out_channels, nx, ny // 2 + 1, dtype=torch.cfloat, device=x.device)
        # low positive and low negative x-frequencies; all higher modes stay zero (truncation)
        out_ft[:, :, :mx, :my] = self.compl_mul2d(x_ft[:, :, :mx, :my], self.weights1)
        out_ft[:, :, -mx:, :my] = self.compl_mul2d(x_ft[:, :, -mx:, :my], self.weights2)
        # inverse real FFT back to the physical grid
        out_spec = torch.fft.irfft2(out_ft, s=(nx, ny), norm='ortho')
        # add the pointwise path
        return self.w0(x) + out_spec

class FNO2d(nn.Module):
    def __init__(self, in_channels, width, modes_x, modes_y, n_layers=4):
        super().__init__()
        self.in_channels = in_channels
        self.width = width
        self.fc0 = nn.Conv2d(in_channels, self.width, kernel_size=1)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(SpectralConv2d(self.width, self.width, modes_x, modes_y))
            self.layers.append(nn.GELU())
        self.fc1 = nn.Conv2d(self.width, 128, kernel_size=1)
        self.fc2 = nn.Conv2d(128, 1, kernel_size=1)

    def forward(self, x):
        # x: [batch, in_channels, nx, ny]
        x = self.fc0(x)
        for layer in self.layers:
            x = layer(x)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x  # [batch, 1, nx, ny]

# ---------------------------
# GL-FNO model: global branch (coarse), local branch (full-res), fusion network
# ---------------------------
class GLFNO(nn.Module):
    def __init__(self,
                 nx, ny,
                 nx_glob, ny_glob,
                 global_modes=(16, 16),
                 local_modes=(32, 32),
                 width_global=64,
                 width_local=64):
        super().__init__()
        self.nx = nx
        self.ny = ny
        self.nxg = nx_glob
        self.nyg = ny_glob

        # global branch operates on downsampled inputs (assume input 'a' is already downsampled)
        self.global_fno = FNO2d(in_channels=1, width=width_global,
                                 modes_x=global_modes[0], modes_y=global_modes[1], n_layers=4)

        # local branch operates on full-res inputs (we will upsample coarse 'a' to full-res as additional channel)
        self.local_fno = FNO2d(in_channels=2, width=width_local,
                                modes_x=local_modes[0], modes_y=local_modes[1], n_layers=4)

        # fusion network (small conv net) takes concatenation of upsampled global prediction and local prediction
        self.fusion = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 1, kernel_size=1)
        )
        # gating network to adaptively weight global vs local per spatial location
        self.gate = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def upsample_global(self, u_g):
        # u_g: [batch, 1, nxg, nyg] -> upsample to [batch, 1, nx, ny] using bilinear
        return F.interpolate(u_g, size=(self.nx, self.ny), mode='bilinear', align_corners=True)

    def forward(self, a_coarse, a_fullres=None):
        """
        a_coarse: [batch, 1, nxg, nyg]  (input for global branch)
        a_fullres: [batch, 1, nx, ny] (input for local branch; optional: if None we upsample a_coarse)
        returns: prediction [batch, 1, nx, ny]
        """
        batch = a_coarse.shape[0]
        # Global branch (coarse)
        u_g = self.global_fno(a_coarse)  # [batch,1,nxg,nyg]
        u_g_up = self.upsample_global(u_g)  # [batch,1,nx,ny]

        # prepare local input channels: full-res input plus the upsampled coarse input
        a_up = F.interpolate(a_coarse, size=(self.nx, self.ny), mode='bilinear', align_corners=True)
        if a_fullres is None:
            a_fullres = a_up
        x_local = torch.cat([a_fullres, a_up], dim=1)  # [batch,2,nx,ny]; channels coincide if a_fullres is itself upsampled
        # local branch
        u_loc = self.local_fno(x_local)  # [batch,1,nx,ny]

        # gating and fusion
        gate_in = torch.cat([u_g_up, u_loc], dim=1)
        alpha = self.gate(gate_in)  # [batch,1,nx,ny] in (0,1)
        fused = alpha * u_g_up + (1.0 - alpha) * u_loc
        # final refinement
        out = self.fusion(torch.cat([fused, u_loc], dim=1))
        return out, u_g_up, u_loc, alpha

# ---------------------------
# Physics loss: spectral Laplacian + simple cubic nonlinearity residual
# R[u] = -Delta u + k*u + u^3 - f  (for this synthetic demonstration f = 0)
# Linear terms are computed exactly in spectral space; the nonlinearity is computed in physical space.
# Caveats: spectral differentiation assumes a periodic field on [0, lx) x [0, ly), and the synthetic
# targets above are NOT solutions of this PDE, so here the physics term only demonstrates the
# mechanics of a spectral PINN loss (keep w_phys small).
# ---------------------------

def spectral_laplacian(u, lx=1.0, ly=1.0):
    """
    Compute the Laplacian via FFT for a real tensor u of shape [batch, 1, nx, ny]
    on a periodic domain of size lx x ly. Returns a real tensor of the same shape.
    """
    # FFT conventions: torch.fft.fft2 / ifft2 with norm='ortho'
    u_c = torch.fft.fft2(u, norm='ortho')
    nx, ny = u.shape[-2], u.shape[-1]
    device = u.device
    # angular wavenumbers 2*pi*k/L
    kx = torch.fft.fftfreq(nx, d=lx / nx).to(device) * 2 * math.pi
    ky = torch.fft.fftfreq(ny, d=ly / ny).to(device) * 2 * math.pi
    KX, KY = torch.meshgrid(kx, ky, indexing='ij')  # [nx, ny], matches the last two dims of u
    K2 = (KX ** 2 + KY ** 2).unsqueeze(0).unsqueeze(0)  # shape [1,1,nx,ny]
    lap_spec = -K2 * u_c
    return torch.fft.ifft2(lap_spec, norm='ortho').real

def physics_residual(u_pred, coef_k=1.0):
    """
    Compute physics residual R = -Delta u + k*u + u^3
    u_pred: [batch,1,nx,ny] real
    returns residual of the same shape
    """
    lap = spectral_laplacian(u_pred)
    nonlinear = u_pred ** 3
    return -lap + coef_k * u_pred + nonlinear

# ---------------------------
# Training utilities
# ---------------------------
def collate_fn(batch):
    # batch elements: (a_coarse [1,nxg,nyg], u_full [1,nx,ny])
    # We will stack them into tensors.
    a_list = [b[0] for b in batch]
    u_list = [b[1] for b in batch]
    a_batch = torch.stack(a_list, dim=0)  # [B,1,nxg,nyg]
    u_batch = torch.stack(u_list, dim=0)  # [B,1,nx,ny]
    return a_batch, u_batch

# ---------------------------
# Main training routine
# ---------------------------
def train_glfno(args):
    device = get_device()
    print("Using device:", device)
    # dataset & loader
    dataset = SyntheticCoronalDataset(n_samples=args.n_train + args.n_val, nx=args.nx, ny=args.ny, seed=args.seed)
    train_set, val_set = torch.utils.data.random_split(dataset, [args.n_train, args.n_val], generator=torch.Generator().manual_seed(args.seed+1))
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    # Determine coarse grid sizes from dataset (downsample factor fixed in dataset code)
    nxg = dataset[0][0].shape[-2]
    nyg = dataset[0][0].shape[-1]
    print(f"Full res: {args.nx}x{args.ny}, Coarse res: {nxg}x{nyg}")

    # model
    model = GLFNO(nx=args.nx, ny=args.ny,
                  nx_glob=nxg, ny_glob=nyg,
                  global_modes=(args.global_modes, args.global_modes),
                  local_modes=(args.local_modes, args.local_modes),
                  width_global=args.width_global,
                  width_local=args.width_local).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma)
    mse = nn.MSELoss()

    best_val = 1e9
    t0 = time.time()
    for epoch in range(1, args.epochs+1):
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=False)
        for a_coarse, u_true in pbar:
            a_coarse = a_coarse.to(device)
            u_true = u_true.to(device)
            # upsample coarse input to full-res for local branch
            a_full = F.interpolate(a_coarse, size=(args.nx, args.ny), mode='bilinear', align_corners=True)

            optimizer.zero_grad()
            u_pred, u_g_up, u_loc, alpha = model(a_coarse, a_full)

            # data loss (MSE)
            L_data = mse(u_pred, u_true)

            # physics residual (spectral). compute on prediction
            res = physics_residual(u_pred)
            L_phys = torch.mean(res**2)

            # total loss
            loss = args.w_data * L_data + args.w_phys * L_phys

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item() * a_coarse.shape[0]
            pbar.set_postfix({"loss": f"{loss.item():.4e}", "Ldata": f"{L_data.item():.4e}", "Lphys": f"{L_phys.item():.4e}"})

        scheduler.step()
        train_loss /= len(train_loader.dataset)

        # validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for a_coarse, u_true in val_loader:
                a_coarse = a_coarse.to(device)
                u_true = u_true.to(device)
                a_full = F.interpolate(a_coarse, size=(args.nx, args.ny), mode='bilinear', align_corners=True)
                u_pred, u_g_up, u_loc, alpha = model(a_coarse, a_full)
                Ld = mse(u_pred, u_true)
                res = physics_residual(u_pred)
                Lp = torch.mean(res**2)
                loss = args.w_data * Ld + args.w_phys * Lp
                val_loss += loss.item() * a_coarse.shape[0]
            val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch}  TrainLoss {train_loss:.5e}  ValLoss {val_loss:.5e}  time {time.time()-t0:.1f}s")
        # save best
        if val_loss < best_val:
            best_val = val_loss
            torch.save({'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch}, args.checkpoint)
            print(f"  Saved checkpoint to {args.checkpoint} (val {best_val:.4e})")

    print("Training completed. Best val loss:", best_val)

# ---------------------------
# Argument parsing
# ---------------------------
def parse_args():
    parser = argparse.ArgumentParser(description="GL-FNO + spectral-PINN demo (single-file)")
    parser.add_argument("--nx", type=int, default=128, help="full-resolution x grid")
    parser.add_argument("--ny", type=int, default=128, help="full-resolution y grid")
    parser.add_argument("--n_train", type=int, default=800, help="num training samples")
    parser.add_argument("--n_val", type=int, default=200, help="num validation samples")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--lr_step", type=int, default=10)
    parser.add_argument("--lr_gamma", type=float, default=0.5)
    parser.add_argument("--weight_decay", type=float, default=1e-6)
    parser.add_argument("--global_modes", type=int, default=12)
    parser.add_argument("--local_modes", type=int, default=28)
    parser.add_argument("--width_global", type=int, default=64)
    parser.add_argument("--width_local", type=int, default=64)
    parser.add_argument("--w_data", type=float, default=1.0)
    parser.add_argument("--w_phys", type=float, default=1e-3)
    parser.add_argument("--checkpoint", type=str, default="glfno_checkpoint.pth")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()

# ---------------------------
# Entry point
# ---------------------------
if __name__ == "__main__":
    args = parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    train_glfno(args)

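The Fourier Neural Operator (FNO) is a deep learning method for solving parameterized partial differential equations (PDEs). By combining the Fourier transform with neural networks, it can efficiently learn nonlinear mappings from input functions to output functions. The core idea of FNO is to perform convolution in Fourier space to capture spatial dependencies, which improves both representational power and computational efficiency.

### Principle

FNO is built on mappings between function spaces. Conventional neural networks usually operate on vectors in Euclidean space, whereas FNO learns a mapping directly between function spaces. Concretely, FNO maps a function to a function in the following steps:

1. **Input lifting**: the input function (e.g., initial or boundary conditions) is first mapped to a higher-dimensional feature space.
2. **Fourier transform**: the high-dimensional features are Fourier-transformed into the frequency domain. This step captures global spatial dependencies cheaply.
3. **Linear transform in Fourier space**: a learnable linear transform (usually a per-mode, pointwise multiplication) is applied in Fourier space; this emulates a convolution and processes the frequency-domain features.
4. **Inverse Fourier transform**: the transformed frequency-domain features are mapped back to the physical (spatial) domain.
5. **Nonlinear activation and residual connection**: a nonlinear activation and a residual (pointwise linear) path are added to increase expressiveness.
6. **Output projection**: the final features are projected back to the target function space, giving the PDE solution.

This design helps FNO maintain its performance across different resolutions and gives it good generalization.

### Applications

FNO has been widely applied to parameterized PDE problems, especially in physical simulation, fluid dynamics, climate modeling, and materials science. For example:

- In fluid dynamics, FNO can predict the evolution of flow fields at different Reynolds numbers.
- In climate modeling, FNO can predict the spatiotemporal variation of meteorological variables such as temperature and wind speed.
- In financial engineering, FNO can be used to solve the PDEs that arise in option-pricing models.

These applications illustrate FNO's strength on high-dimensional, nonlinear, parameterized PDE problems.

### Implementation

Implementing FNO involves the following key steps:

1. **Data preprocessing**: convert the input data (e.g., grid coordinates and function values) into a form the model can consume, typically including normalization and grid generation.
2. **Network design**: build the FNO layers; a network usually stacks several Fourier layers, each consisting of a Fourier transform, a linear transform, an inverse Fourier transform, and a nonlinear activation.
3. **Loss function**: mean squared error (MSE) is the usual loss, measuring the discrepancy between the predicted and the true solution.
4. **Training**: train the model with an optimizer such as Adam, adjusting the parameters to minimize the loss.
5. **Generalization assessment**: test the model under different parameter configurations or at higher resolution to verify its generalization.

A simplified FNO layer implementation is shown below:

```python
import torch
import torch.nn as nn
import torch.fft as fft

class FNO2DLayer(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(FNO2DLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # retained modes along the first spatial axis (taken at both ends)
        self.modes2 = modes2  # retained modes along the second axis (rfft2 keeps only non-negative ones)
        self.scale = 1 / (in_channels * out_channels)
        # weights1: positive first-axis frequencies; weights2: negative first-axis frequencies
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x):
        batch_size, channels, height, width = x.shape
        x_ft = fft.rfft2(x, dim=[2, 3])  # [B, C_in, H, W//2 + 1]
        out_ft = torch.zeros(batch_size, self.out_channels, height, width // 2 + 1, dtype=torch.cfloat, device=x.device)
        # per-mode channel mixing on the retained low-frequency blocks; all other modes stay zero
        out_ft[:, :, :self.modes1, :self.modes2] = torch.einsum("bixy,ioxy->boxy", x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = torch.einsum("bixy,ioxy->boxy", x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)
        x = fft.irfft2(out_ft, s=(height, width))
        return x
```

This code shows the basic flow of an FNO layer: Fourier transform, linear transform in the frequency domain, and inverse Fourier transform. Stacking several such layers (together with the pointwise linear path and activation from step 5) yields a complete FNO network.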
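Step 3 describes the Fourier-space linear transform as a pointwise multiplication that emulates a convolution. This is the convolution theorem: on a periodic grid, multiplying DFT coefficients pointwise is exactly a circular convolution in physical space. The short NumPy check below compares both routes; the random signal and kernel are arbitrary test data, and the function names are illustrative. An FNO layer differs mainly in that it learns the multipliers directly, mixes channels per mode, and truncates them to the lowest modes.

```python
import numpy as np

def circular_conv_direct(v, h):
    """(v * h)[j] = sum_m v[m] * h[(j - m) mod n], computed directly in O(n^2)."""
    n = v.shape[0]
    return np.array([sum(v[m] * h[(j - m) % n] for m in range(n)) for j in range(n)])

def circular_conv_fft(v, h):
    """The same circular convolution via pointwise multiplication of DFTs, O(n log n)."""
    return np.real(np.fft.ifft(np.fft.fft(v) * np.fft.fft(h)))

rng = np.random.default_rng(1)
v = rng.standard_normal(64)
h = rng.standard_normal(64)
print("max difference:", np.abs(circular_conv_direct(v, h) - circular_conv_fft(v, h)).max())
```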