FourierKAN源码逐行解读:从__init__到forward的核心计算流程
【免费下载链接】FourierKAN 项目地址: https://gitcode.com/GitHub_Trending/fo/FourierKAN
项目概述
FourierKAN是基于Kolmogorov-Arnold网络(KAN)改进的深度学习模型,采用傅里叶系数替代传统样条系数,具有全局表达能力强、数值稳定性高的特点。本文将深入解析核心文件fftKAN.py的实现细节,从网络初始化到前向传播,完整呈现模型的数学原理与工程实现。
类定义与初始化方法
NaiveFourierKANLayer类结构
fftKAN.py定义了NaiveFourierKANLayer类,继承自PyTorch的nn.Module,是整个模型的核心组件。其初始化方法(__init__)负责设置网络参数与可学习权重:
class NaiveFourierKANLayer(th.nn.Module):
def __init__( self, inputdim, outdim, gridsize, addbias=True):
super(NaiveFourierKANLayer,self).__init__()
self.gridsize = gridsize # 傅里叶级数项数
self.addbias = addbias # 是否添加偏置项
self.inputdim = inputdim # 输入维度
self.outdim = outdim # 输出维度
# 傅里叶系数初始化:形状为(2, outdim, inputdim, gridsize)
self.fouriercoeffs = th.nn.Parameter( th.randn(2, outdim, inputdim, gridsize) /
(np.sqrt(inputdim) * np.sqrt(self.gridsize) ) )
if self.addbias:
self.bias = th.nn.Parameter( th.zeros(1, outdim)) # 偏置参数
参数初始化策略
傅里叶系数fouriercoeffs采用正态分布初始化,并通过1/(√inputdim * √gridsize)进行标准化,确保不同维度的输出具有单位方差。这种初始化方式在fftKAN.py中实现,是保证模型训练稳定性的关键设计。
前向传播核心计算
输入处理与维度调整
前向传播(forward)方法首先对输入张量进行形状调整,将任意维度的输入统一转换为二维矩阵以简化计算:
def forward(self, x):
xshp = x.shape
outshape = xshp[0:-1] + (self.outdim,) # 保留输入维度结构
x = th.reshape(x, (-1, self.inputdim)) # 展平为(batch_size, inputdim)
傅里叶级数计算
核心计算流程通过余弦和正弦函数构建傅里叶级数展开,实现从输入空间到输出空间的非线性映射:
# 生成波数向量k (1,1,1,gridsize),起始于1以避免与偏置项重复
k = th.reshape(th.arange(1, self.gridsize+1, device=x.device), (1,1,1,self.gridsize))
xrshp = th.reshape(x, (x.shape[0], 1, x.shape[1], 1)) # 扩展维度以支持广播
# 计算余弦和正弦项
c = th.cos(k * xrshp) # 形状: (batch_size, 1, inputdim, gridsize)
s = th.sin(k * xrshp) # 形状: (batch_size, 1, inputdim, gridsize)
特征融合与输出
通过傅里叶系数对三角函数项进行加权求和,得到最终输出:
# 余弦项与正弦项的加权组合
y = th.sum(c * self.fouriercoeffs[0:1], (-2, -1)) # 对输入维度和波数求和
y += th.sum(s * self.fouriercoeffs[1:2], (-2, -1))
if self.addbias:
y += self.bias # 添加偏置项
y = th.reshape(y, outshape) # 恢复原始维度结构
return y
上述过程在fftKAN.py中实现,通过矩阵广播和求和操作,高效完成傅里叶特征的提取与融合。
数学原理可视化
傅里叶KAN的核心思想是将每个输入维度通过傅里叶级数展开为周期函数,再通过线性组合构建高维映射。其数学表达式可表示为:
$$ y_j = \sum_{i=1}^{inputdim} \sum_{k=1}^{gridsize} \left( a_{jik} \cos(k x_i) + b_{jik} \sin(k x_i) \right) + b_j $$
其中$a_{jik}$和$b_{jik}$分别对应fouriercoeffs[0]和fouriercoeffs[1]中的系数,$b_j$为偏置项。
代码优化与内存效率
未启用的 einsum 实现
源码中提供了另一种基于einsum的实现方案(被注释掉),通过爱因斯坦求和约定减少中间变量的内存占用:
# 替代实现:使用einsum减少内存占用(当前未启用)
c = th.reshape(c, (1, x.shape[0], x.shape[1], self.gridsize))
s = th.reshape(s, (1, x.shape[0], x.shape[1], self.gridsize))
y2 = th.einsum("dbik,djik->bj", th.concat([c,s], axis=0), self.fouriercoeffs)
该实现通过fftKAN.py中的代码展示,虽然理论上更内存高效,但由于PyTorch对einsum的优化程度较低,实际训练速度可能慢于默认实现。
演示代码解析
网络堆叠与数据处理
demo函数展示了如何构建多层FourierKAN网络,并处理不同维度的输入数据:
def demo():
bs, L, inputdim, hidden, outdim = 10, 3, 50, 200, 100
gridsize = 300
# 创建两层FourierKAN网络
fkan1 = NaiveFourierKANLayer(inputdim, hidden, gridsize)
fkan2 = NaiveFourierKANLayer(hidden, outdim, gridsize)
# 处理常规输入 (batch_size, inputdim)
x0 = th.randn(bs, inputdim)
h = fkan1(x0)
y = fkan2(h)
# 处理序列输入 (batch_size, seq_len, inputdim)
xseq = th.randn(bs, L, inputdim)
h_seq = fkan1(xseq) # 自动支持任意中间维度
上述代码在fftKAN.py中实现,验证了模型对高维输入的处理能力,以及输出统计特性的稳定性(通过打印均值和方差)。
总结与扩展
FourierKAN通过傅里叶级数替代传统KAN中的样条函数,在保持非线性表达能力的同时,提升了数值稳定性和全局特征捕捉能力。核心实现仅通过约60行代码完成(fftKAN.py),却展现了深刻的数学原理与工程优化思想。后续可进一步研究:
- 不同
gridsize对模型性能的影响 - 与传统KAN在各类任务上的对比
- 更高效的傅里叶特征融合策略
通过本文的逐行解析,相信读者已对FourierKAN的实现细节有了清晰认识,可在此基础上进行二次开发与应用探索。
【免费下载链接】FourierKAN 项目地址: https://gitcode.com/GitHub_Trending/fo/FourierKAN
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



