FourierKAN源码逐行解读:从__init__到forward的核心计算流程

FourierKAN源码逐行解读:从__init__到forward的核心计算流程

【免费下载链接】FourierKAN 【免费下载链接】FourierKAN 项目地址: https://gitcode.com/GitHub_Trending/fo/FourierKAN

项目概述

FourierKAN是基于Kolmogorov-Arnold网络(KAN)改进的深度学习模型,采用傅里叶系数替代传统样条系数,具有全局表达能力强、数值稳定性高的特点。本文将深入解析核心文件fftKAN.py的实现细节,从网络初始化到前向传播,完整呈现模型的数学原理与工程实现。

类定义与初始化方法

NaiveFourierKANLayer类结构

fftKAN.py定义了NaiveFourierKANLayer类,继承自PyTorch的nn.Module,是整个模型的核心组件。其初始化方法(__init__)负责设置网络参数与可学习权重:

class NaiveFourierKANLayer(th.nn.Module):
    def __init__( self, inputdim, outdim, gridsize, addbias=True):
        super(NaiveFourierKANLayer,self).__init__()
        self.gridsize = gridsize  # 傅里叶级数项数
        self.addbias = addbias    # 是否添加偏置项
        self.inputdim = inputdim  # 输入维度
        self.outdim = outdim      # 输出维度
        
        # 傅里叶系数初始化:形状为(2, outdim, inputdim, gridsize)
        self.fouriercoeffs = th.nn.Parameter( th.randn(2, outdim, inputdim, gridsize) / 
                                             (np.sqrt(inputdim) * np.sqrt(self.gridsize) ) )
        if self.addbias:
            self.bias = th.nn.Parameter( th.zeros(1, outdim))  # 偏置参数

参数初始化策略

傅里叶系数fouriercoeffs采用正态分布初始化,并通过1/(√inputdim * √gridsize)进行标准化,确保不同维度的输出具有单位方差。这种初始化方式在fftKAN.py中实现,是保证模型训练稳定性的关键设计。

前向传播核心计算

输入处理与维度调整

前向传播(forward)方法首先对输入张量进行形状调整,将任意维度的输入统一转换为二维矩阵以简化计算:

def forward(self, x):
    xshp = x.shape
    outshape = xshp[0:-1] + (self.outdim,)  # 保留输入维度结构
    x = th.reshape(x, (-1, self.inputdim))  # 展平为(batch_size, inputdim)

傅里叶级数计算

核心计算流程通过余弦和正弦函数构建傅里叶级数展开,实现从输入空间到输出空间的非线性映射:

# 生成波数向量k (1,1,1,gridsize),起始于1以避免与偏置项重复
k = th.reshape(th.arange(1, self.gridsize+1, device=x.device), (1,1,1,self.gridsize))
xrshp = th.reshape(x, (x.shape[0], 1, x.shape[1], 1))  # 扩展维度以支持广播

# 计算余弦和正弦项
c = th.cos(k * xrshp)  # 形状: (batch_size, 1, inputdim, gridsize)
s = th.sin(k * xrshp)  # 形状: (batch_size, 1, inputdim, gridsize)

特征融合与输出

通过傅里叶系数对三角函数项进行加权求和,得到最终输出:

# 余弦项与正弦项的加权组合
y = th.sum(c * self.fouriercoeffs[0:1], (-2, -1))  # 对输入维度和波数求和
y += th.sum(s * self.fouriercoeffs[1:2], (-2, -1))
if self.addbias:
    y += self.bias  # 添加偏置项
y = th.reshape(y, outshape)  # 恢复原始维度结构
return y

上述过程在fftKAN.py中实现,通过矩阵广播和求和操作,高效完成傅里叶特征的提取与融合。

数学原理可视化

傅里叶KAN的核心思想是将每个输入维度通过傅里叶级数展开为周期函数,再通过线性组合构建高维映射。其数学表达式可表示为:

$$ y_j = \sum_{i=1}^{inputdim} \sum_{k=1}^{gridsize} \left( a_{jik} \cos(k x_i) + b_{jik} \sin(k x_i) \right) + b_j $$

其中$a_{jik}$和$b_{jik}$分别对应fouriercoeffs[0]fouriercoeffs[1]中的系数,$b_j$为偏置项。

代码优化与内存效率

未启用的 einsum 实现

源码中提供了另一种基于einsum的实现方案(被注释掉),通过爱因斯坦求和约定减少中间变量的内存占用:

# 替代实现:使用einsum减少内存占用(当前未启用)
c = th.reshape(c, (1, x.shape[0], x.shape[1], self.gridsize))
s = th.reshape(s, (1, x.shape[0], x.shape[1], self.gridsize))
y2 = th.einsum("dbik,djik->bj", th.concat([c,s], axis=0), self.fouriercoeffs)

该实现通过fftKAN.py中的代码展示,虽然理论上更内存高效,但由于PyTorch对einsum的优化程度较低,实际训练速度可能慢于默认实现。

演示代码解析

网络堆叠与数据处理

demo函数展示了如何构建多层FourierKAN网络,并处理不同维度的输入数据:

def demo():
    bs, L, inputdim, hidden, outdim = 10, 3, 50, 200, 100
    gridsize = 300
    
    # 创建两层FourierKAN网络
    fkan1 = NaiveFourierKANLayer(inputdim, hidden, gridsize)
    fkan2 = NaiveFourierKANLayer(hidden, outdim, gridsize)
    
    # 处理常规输入 (batch_size, inputdim)
    x0 = th.randn(bs, inputdim)
    h = fkan1(x0)
    y = fkan2(h)
    
    # 处理序列输入 (batch_size, seq_len, inputdim)
    xseq = th.randn(bs, L, inputdim)
    h_seq = fkan1(xseq)  # 自动支持任意中间维度

上述代码在fftKAN.py中实现,验证了模型对高维输入的处理能力,以及输出统计特性的稳定性(通过打印均值和方差)。

总结与扩展

FourierKAN通过傅里叶级数替代传统KAN中的样条函数,在保持非线性表达能力的同时,提升了数值稳定性和全局特征捕捉能力。核心实现仅通过约60行代码完成(fftKAN.py),却展现了深刻的数学原理与工程优化思想。后续可进一步研究:

  1. 不同gridsize对模型性能的影响
  2. 与传统KAN在各类任务上的对比
  3. 更高效的傅里叶特征融合策略

通过本文的逐行解析,相信读者已对FourierKAN的实现细节有了清晰认识,可在此基础上进行二次开发与应用探索。

【免费下载链接】FourierKAN 【免费下载链接】FourierKAN 项目地址: https://gitcode.com/GitHub_Trending/fo/FourierKAN

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值