一、引言
KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。
二、技术与原理简介
1.Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑
其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。,
2.Kolmogorov-Arnold 网络 (KAN)
Kolmogorov-Arnold 表示可以写成矩阵形式
其中
我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:
其中。
定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是
相反,多层感知器由线性层和非线错:
KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。
三、代码详解
该代码旨在将符号表达式(SymPy表达式)编译成多层Kolmogorov-Arnold网络(MultKAN)模型。通过解析表达式的结构,将其分解为基本运算(加法、乘法、函数等),并构建对应的计算图,最终映射到MultKAN的网络结构中。
A.完整代码
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import seaborn as sns
import torch
import warnings
warnings.filterwarnings('ignore') # 忽略警告信息
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
# 读取数据
df = pd.read_excel('预测.xlsx')
df.dropna(inplace=True)
# 定义特征 X 和目标 y
X = df.drop(['Z'], axis=1)
y = df['Z']
# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
df.head()
# 如果GPU可用,利用GPU进行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
B.简单调用
from compiler import kanpiler
from sympy import symbols, exp, sin, pi
from torch import nn, torch
# 定义输入变量和表达式
input_vars = symbols('a b')
expression = exp(sin(pi * input_vars[0]) + input_vars[1]**2)
# 编译为 MultKAN 模型
model = kanpiler(input_vars, expression)
# 使用随机数据进行测试
x = torch.rand(100, 2) * 2 - 1 # 生成 100 个随机输入样本
output = model(x) # 计算模型输出
# 可视化模型结构
model.plot()
C.代码详解
1.next_nontrivial_operation
函数
def next_nontrivial_operation(expr, scale=1, bias=0):
'''
移除表达式中的仿射部分(即线性部分)。
Args:
expr : sympy 表达式
scale : float, 缩放因子
bias : float, 偏置
Returns:
expr : sympy 表达式
scale : float
bias : float
'''
- 功能: 该函数用于从表达式中移除仿射部分(即线性部分),并返回剩余的非线性部分。
- 实现: 通过递归地检查表达式的类型(
Add
或Mul
),并提取其中的常数部分(scale
和bias
),最终返回非线性部分。
2.expr2kan
函数
def expr2kan(input_variables, expr, grid=5, k=3, auto_save=False):
'''
将符号表达式编译为 MultKAN 模型。
Args:
input_variables : 输入变量的符号列表
expr : sympy 表达式
grid : int, 网格间隔数
k : int, 样条阶数
auto_save : bool, 是否自动保存模型
Returns:
MultKAN 模型
'''
- 功能: 该函数将符号表达式编译为
MultKAN
模型,MultKAN
是一个基于样条的神经网络模型。 - 实现: 通过递归地解析符号表达式,构建
Node
和SubNode
对象,最终生成MultKAN
模型。
3.Node
类
class Node:
def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None):
self.expr = expr
self.mult_bool = mult_bool
if self.mult_bool:
self.mult_arity = mult_arity
self.depth = depth
- 功能: 表示表达式树中的一个节点。
- 属性:
expr
: 节点对应的表达式。mult_bool
: 表示节点是否为乘法节点。depth
: 节点在树中的深度。scale
和bias
: 节点的缩放因子和偏置。
4.SubNode
类
class SubNode:
def __init__(self, expr, depth, scale, bias, parent=None):
self.expr = expr
self.depth = depth
- 功能: 表示表达式树中的一个子节点。
- 属性:
expr
: 子节点对应的表达式。depth
: 子节点在树中的深度。scale
和bias
: 子节点的缩放因子和偏置。
5.Connection
类
class Connection:
def __init__(self, affine, fun, fun_name, parent=None, child=None, power_exponent=None):
self.affine = affine
self.fun = fun
self.fun_name = fun_name
self.parent_index = parent.index
self.depth = parent.depth
self.child_index = child.index
self.power_exponent = power_exponent
- 功能: 表示节点之间的连接,包含激活函数和仿射变换。
- 属性:
affine
: 仿射变换参数。fun
: 激活函数。fun_name
: 激活函数的名称。parent_index
和child_index
: 连接的父节点和子节点的索引。
6.create_node
函数
def create_node(expr, parent=None, n_layer=None):
expr, scale, bias = next_nontrivial_operation(expr)
if parent == None:
depth = 0
else:
depth = parent.depth
- 功能: 递归地创建表达式树的节点和子节点。
- 实现: 根据表达式的类型(
Add
或Mul
)创建相应的Node
和SubNode
对象,并建立它们之间的连接。
7.模型构建
Nodes = [[]]
SubNodes = [[]]
Connections = {}
Start_Nodes = []
create_node(expr, n_layer=None)
n_layer = len(Nodes) - 1
Nodes = [[]]
SubNodes = [[]]
Connections = {}
Start_Nodes = []
create_node(expr, n_layer=n_layer)
- 功能: 初始化节点、子节点和连接的存储结构,并通过
create_node
函数构建表达式树。
8.模型参数设置
model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save)
- 功能: 根据表达式树的结构初始化
MultKAN
模型。 - 参数:
width
: 每层的宽度。mult_arity
: 每层的乘法节点数量。grid
: 网格间隔数。k
: 样条阶数。auto_save
: 是否自动保存模型。
9.模型参数更新
for node in Nodes_flat:
node_depth = node.depth
node_index = node.index
kan_node_depth, kan_node_index = node_index_convert[(node_depth,node_index)]
self.node_scale[kan_node_depth-1].data[kan_node_index] = float(node.scale)
self.node_bias[kan_node_depth-1].data[kan_node_index] = float(node.bias)
- 功能: 将表达式树中的节点参数(
scale
和bias
)更新到MultKAN
模型中。
10. 连接参数更新
for connection in Connections_flat:
c_depth = connection.depth
c_j = connection.parent_index
c_i = connection.child_index
kc_depth, kc_j, kc_i = connection_index_convert[(c_depth, c_j, c_i)]
model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False)
model.symbolic_fun[kc_depth].affine.data.reshape(self.width_out[kc_depth+1], self.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine)
- 功能: 将表达式树中的连接参数(
affine
和fun
)更新到MultKAN
模型中。
11. 返回模型
return model
- 功能: 返回构建好的
MultKAN
模型。
12. 别名定义
sf2kan = kanpiler = expr2kan
- 功能: 为
expr2kan
函数定义别名sf2kan
和kanpiler
,方便调用。
四、总结与思考
KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。
1. KAN网络架构
-
关键设计:可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。
-
示例结构:输入层 → 隐藏层:每个输入节点通过单变量函数
连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数
组合得到输出。
2. 优势与特点
-
高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。
-
可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。
-
灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。
3. 挑战与局限
-
计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。
-
泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。
-
训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。
4. 应用场景
-
科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。
-
可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。
-
符号回归:从数据中自动发现数学公式(如物理定律)。
5. 与传统MLP的对比
6. 研究进展
-
近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。
-
开源实现:已有PyTorch等框架的初步实现。
【作者声明】
本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。
【关注我们】
如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!