Efficient-KAN 源码详解

Efficient-KAN源码链接

Efficient-KAN (GitHub)

改进细节

1.内存效率提升

KAN网络的原始实现的性能问题主要在于它需要扩展所有中间变量以执行不同的激活函数。对于具有in_features个输入和out_features个输出的层,原始实现需要将输入扩展为shape为(batch_size, out_features, in_features)的tensor以执行激活函数。然而,所有激活函数都是一组固定基函数(3阶B样条)的线性组合。鉴于此,拟将计算重新表述为不同的基函数激活输入,然后将它们线性组合。这种重新表述可以显著减少内存消耗,并使计算变得更加简单的矩阵乘法,自然地适用于前向和后向传递。

2.正则化方法的改变

稀疏化被认为对KAN的可解释性至关重要。作者提出了一种定义在输入样本上的L1正则化,它需要对**(batch_size, out_features, in_features)** tensor进行非线性操作,因此与重新表述不兼容。拟改为对权重进行L1正则化,这在NN中更为常见,并且与重新表述兼容。

3.激活函数缩放选项

除了可学习的激活函数(B样条),原始实现还包括对每个激活函数的可学习缩放 ( w s ) (w_s) (ws)。拟提供一个名为enable_standalone_scale_spline的选项,默认情况下为True,以包含此功能。禁用它会使模型更高效,但可能会影响结果。这需要更多实验验证。

4.参数初始化的改变

为了解决在MNIST数据集上的性能问题,该代码修改了参数的初始化方式,采用Kaiming初始化

KAN_fast.py解析

基本参数和类定义

import torch
import torch.nn.functional as F
import math

class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,  # 网格大小,默认为 5
        spline_order=3, # 分段多项式的阶数,默认为 3
        scale_noise=0.1,  # 缩放噪声,默认为 0.1
        scale_base=1.0,   # 基础缩放,默认为 1.0
        scale_spline=1.0,    # 分段多项式的缩放,默认为 1.0
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,  # 基础激活函数,默认为 SiLU(Sigmoid Linear Unit)
        grid_eps=0.02,
        grid_range=[-1, 1],  # 网格范围,默认为 [-1, 1]
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size # 设置网格大小和分段多项式的阶数
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size   # 计算网格步长

生成网格

 grid = ( # 生成网格
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0] 
            )
            .expand(in_features, -1)
            .contiguous()
        )
self.register_buffer("grid", grid)  # 将网格作为缓冲区注册

1. torch.arange(-spline_order, grid_size + spline_order + 1)
  • **torch.arange(start, end)**:生成一个从 startend-1 的整数序列(左闭右开区间)。
  • **-spline_order**:从负的 spline_order 开始。
  • **grid_size + spline_order + 1**:终止于 grid_size + spline_order(不包括 +1)。

这个序列的长度是 grid_size + 2 * spline_order + 1,用于涵盖所有需要的网格点,包括两端的扩展区域。

2. * h
  • 这一步将生成的整数序列乘以步长 h,将索引序列转换为实际的网格位置。
3. + grid_range[0]
  • 这一步将整个网格位置进行平移,使得网格的起始点与 grid_range[0] 对齐。

如果 grid_range[0] = -1,则每个位置都会减去 1

4. .expand(in_features, -1)
  • **.expand()**:将这个网格复制 in_features 次,以适应输入特征的维度。具体来说,它将原本的一维网格向量扩展成一个 in_features × (grid_size + 2 * spline_order + 1) 的二维张量。 其中每一行都是相同的网格向量。
5. .contiguous()
  • **.contiguous()**:确保扩展后的张量在内存中是连续存储的,方便后续的计算和操作。虽然在大多数情况下这个操作是可选的,但它可以提高计算效率并避免潜在的问题。

最终效果:

这段代码生成了一个二维张量 grid,它的形状为 [in_features, grid_size + 2 * spline_order + 1],其中每一行都是相同的、覆盖整个 grid_range 并适当扩展的网格点序列。这个网格用于模型中的 B 样条或其他基函数计算,使得模型可以在输入数据范围内执行灵活的插值和拟合操作。

初始化可训练参数

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) # 初始化基础权重和分段多项式权重
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:  # 如果启用独立的分段多项式缩放,则初始化分段多项式缩放参数
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

1. self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) ( w b ) (w_b) (wb)

  • **torch.Tensor(out_features, in_features)**:创建一个形状为 (out_features, in_features) 的未初始化张量,用于存储基础线性层的权重。这个张量的元素初始时没有具体的数值,通常在后续的 reset_parameters() 方法中进行初始化。
  • **torch.nn.Parameter**:将这个张量封装成 torch.nn.Parameter 对象。这意味着这个张量会被视为模型的可训练参数,PyTorch 会自动将其包含在模型的参数列表中,并在反向传播时更新其值。
  • **self.base_weight**:这个属性存储的是基础线性变换的权重矩阵。这个矩阵将在前向传播过程中被用来对输入特征进行线性变换。

2. self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order)) ( c i ) (c_i) (ci)

  • **torch.Tensor(out_features, in_features, grid_size + spline_order)**:创建一个形状为 (out_features, in_features, grid_size + spline_order) 的未初始化张量,用于存储分段多项式的权重。这些权重将用于 B 样条或其他类似方法的计算。
  • **torch.nn.Parameter**:同样地,将这个张量封装成 torch.nn.Parameter,使其成为模型的可训练参数。
  • **self.spline_weight**:这个属性存储的是与分段多项式相关的权重。这些权重决定了如何将输入特征映射到输出特征,特别是在使用 B 样条等非线性激活函数时。
为什么 spline_weight 的形状是 (out_features, in_features, grid_size + spline_order)
  • **out_features****in_features**:与 base_weight 类似,表示输出和输入的特征数量。
  • **grid_size + spline_order**:这个维度表示在 B 样条或其他分段多项式方法中,每个输入特征需要使用的基函数的数量。通过这些基函数的线性组合,可以生成灵活的非线性激活。

3. if enable_standalone_scale_spline:

  • 这个条件语句检查 enable_standalone_scale_spline 是否为 True。如果为 True,则会为每个分段多项式激活函数引入一个独立的缩放参数。

4. self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features)) ( w s ) (w_s) (ws)

  • **torch.Tensor(out_features, in_features)**:创建一个形状为 (out_features, in_features) 的张量,用于存储独立的分段多项式缩放参数。
  • **torch.nn.Parameter**:将张量封装成 torch.nn.Parameter,使其成为可训练参数。
  • **self.spline_scaler**:这个属性存储的是分段多项式的缩放参数。每个 spline_weight 都有一个对应的缩放参数,可以单独调整其幅度,从而提供更大的灵活性。

其他实例属性

        self.scale_noise = scale_noise # 保存缩放噪声、基础缩放、分段多项式的缩放、是否启用独立的分段多项式缩放、基础激活函数和网格范围的容差
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()  # 重置参数

Kaiming初始化权重(reset_parameters)

def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)# 使用 Kaiming 均匀初始化基础权重
        with torch.no_grad():
            noise = (# 生成缩放噪声
                (
                    torch.rand(self.grid_size + 1
<think>好的,我需要解决用户在安装Python项目时遇到的错误,错误信息是缺少setup.py或pyproject.toml文件。用户提到使用pip install -e时出现这个问题,可能是在尝试以可编辑模式安装项目。 首先,我应该回忆一下pip安装项目的基本要求。根据引用[3],pip需要setup.py或pyproject.toml来识别项目元数据。如果这两个文件都不存在,pip就无法处理安装,所以报错。 接下来,用户的错误可能出现在几种情况下。比如,用户可能在错误的目录下运行命令,或者项目本身缺少必要的构建文件。比如,引用[1]和[2]中的例子,都是因为路径下的项目缺少这两个文件。 解决方案的第一步应该是检查当前目录是否正确。用户需要确认自己是否在项目的根目录下,因为setup.py或pyproject.toml通常位于这里。比如,如果项目结构正确,用户可能不小心进入了子目录,如tests或docs,这时候需要切换到上级目录。 如果确认目录正确,下一步就是检查项目结构。用户需要确保项目包含setup.py或pyproject.toml。对于现代项目,可能使用pyproject.toml和setuptools或poetry。如果用户的项目是从源代码仓库克隆的,可能需要生成这些文件。比如,有些项目使用setup.cfg或者其他构建工具,可能需要先运行生成脚本,或者安装构建依赖。 如果项目本身没有这些文件,用户可能需要手动创建。比如,使用setuptools编写一个简单的setup.py,或者配置pyproject.toml。不过这可能对新手来说有些难度,需要提供示例。 另外,用户可能尝试安装的是一个Whl文件,但路径错误。根据引用[5],用户之前通过将.whl文件放在特定目录并正确指定路径解决了问题。所以,如果用户是安装本地包,需要确认路径是否正确,以及文件是否完整。 对于可编辑安装(-e选项),通常用于开发模式,需要项目有setup.py。如果项目使用pyproject.toml,可能需要确保构建系统支持可编辑安装。比如,使用poetry或flit时可能需要不同的命令。 还有一个可能性是用户的环境有问题,比如权限问题或者pip版本过旧。建议用户升级pip到最新版本,确保兼容性。 总结步骤应该是:检查目录位置,确认项目结构,生成必要文件,尝试其他安装方法,升级工具。同时,引用相关的用户案例,比如引用[5]中的解决方法,可能对用户有帮助。</think>### 解决方案:缺少`setup.py`或`pyproject.toml`文件的错误处理 当使用`pip install -e`安装Python项目时,若出现`Neither 'setup.py' nor 'pyproject.toml' found`错误,通常是因为项目目录缺少必要的构建文件。以下是分步解决方案: --- #### **1. 确认项目目录结构** 确保当前目录是**项目的根目录**,且包含有效的Python项目文件。 - **正确示例**: ``` my_project/ ├── pyproject.toml # 或 setup.py ├── src/ │ └── my_module/ └── setup.cfg ``` - **错误示例**: 若在子目录(如`tests/`或`docs/`)中运行命令,会因找不到构建文件而报错[^1][^2]。 --- #### **2. 生成缺失的构建文件** 如果项目缺少`setup.py`或`pyproject.toml`,需手动创建: - **使用`setup.py`(传统方式)**: ```python from setuptools import setup, find_packages setup( name="my_project", version="0.1", packages=find_packages(), ) ``` - **使用`pyproject.toml`(现代方式)**: ```toml [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" ``` --- #### **3. 检查项目构建工具** 某些项目依赖特定构建工具(如`poetry`或`flit`),需先安装工具并生成元数据: ```bash # 使用 poetry pip install poetry poetry install # 使用 hatch pip install hatch hatch build ``` --- #### **4. 直接安装本地包(无需可编辑模式)** 若无需开发模式,可尝试直接安装: ```bash # 安装本地目录 pip install . # 安装 .whl 文件(确保路径正确) pip install /path/to/package.whl ``` 参考[^5]中用户通过指定`.whl`文件路径成功安装的案例。 --- #### **5. 升级pip和setuptools** 旧版工具可能导致兼容性问题: ```bash pip install --upgrade pip setuptools ``` --- #### **6. 验证项目完整性** 若项目从Git仓库克隆,检查是否遗漏文件(如`.gitignore`可能排除了构建文件)。 --- ### 总结流程图 ```plaintext 检查目录位置 ↓ 确认存在setup.py/pyproject.toml → 是 → 运行pip install -e . ↓否 手动生成构建文件 → 重试安装 ↓失败 尝试直接安装(非可编辑模式)或检查工具链 ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值