Efficient-KAN 源码详解

Efficient-KAN源码链接

Efficient-KAN (GitHub)

改进细节

1.内存效率提升

KAN网络的原始实现的性能问题主要在于它需要扩展所有中间变量以执行不同的激活函数。对于具有in_features个输入和out_features个输出的层,原始实现需要将输入扩展为shape为(batch_size, out_features, in_features)的tensor以执行激活函数。然而,所有激活函数都是一组固定基函数(3阶B样条)的线性组合。鉴于此,拟将计算重新表述为不同的基函数激活输入,然后将它们线性组合。这种重新表述可以显著减少内存消耗,并使计算变得更加简单的矩阵乘法,自然地适用于前向和后向传递。

2.正则化方法的改变

稀疏化被认为对KAN的可解释性至关重要。作者提出了一种定义在输入样本上的L1正则化,它需要对**(batch_size, out_features, in_features)** tensor进行非线性操作,因此与重新表述不兼容。拟改为对权重进行L1正则化,这在NN中更为常见,并且与重新表述兼容。

3.激活函数缩放选项

除了可学习的激活函数(B样条),原始实现还包括对每个激活函数的可学习缩放 ( w s ) (w_s) (ws)。拟提供一个名为enable_standalone_scale_spline的选项,默认情况下为True,以包含此功能。禁用它会使模型更高效,但可能会影响结果。这需要更多实验验证。

4.参数初始化的改变

为了解决在MNIST数据集上的性能问题,该代码修改了参数的初始化方式,采用Kaiming初始化

KAN_fast.py解析

基本参数和类定义

import torch
import torch.nn.functional as F
import math

class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,  # 网格大小,默认为 5
        spline_order=3, # 分段多项式的阶数,默认为 3
        scale_noise=0.1,  # 缩放噪声,默认为 0.1
        scale_base=1.0,   # 基础缩放,默认为 1.0
        scale_spline=1.0,    # 分段多项式的缩放,默认为 1.0
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,  # 基础激活函数,默认为 SiLU(Sigmoid Linear Unit)
        grid_eps=0.02,
        grid_range=[-1, 1],  # 网格范围,默认为 [-1, 1]
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size # 设置网格大小和分段多项式的阶数
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size   # 计算网格步长

生成网格

 grid = ( # 生成网格
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0] 
            )
            .expand(in_features, -1)
            .contiguous()
        )
self.register_buffer("grid", grid)  # 将网格作为缓冲区注册

1. torch.arange(-spline_order, grid_size + spline_order + 1)
  • **torch.arange(start, end)**:生成一个从 startend-1 的整数序列(左闭右开区间)。
  • **-spline_order**:从负的 spline_order 开始。
  • **grid_size + spline_order + 1**:终止于 grid_size + spline_order(不包括 +1)。

这个序列的长度是 grid_size + 2 * spline_order + 1,用于涵盖所有需要的网格点,包括两端的扩展区域。

2. * h
  • 这一步将生成的整数序列乘以步长 h,将索引序列转换为实际的网格位置。
3. + grid_range[0]
  • 这一步将整个网格位置进行平移,使得网格的起始点与 grid_range[0] 对齐。

如果 grid_range[0] = -1,则每个位置都会减去 1

4. .expand(in_features, -1)
  • **.expand()**:将这个网格复制 in_features 次,以适应输入特征的维度。具体来说,它将原本的一维网格向量扩展成一个 in_features × (grid_size + 2 * spline_order + 1) 的二维张量。 其中每一行都是相同的网格向量。
5. .contiguous()
  • **.contiguous()**:确保扩展后的张量在内存中是连续存储的,方便后续的计算和操作。虽然在大多数情况下这个操作是可选的,但它可以提高计算效率并避免潜在的问题。

最终效果:

这段代码生成了一个二维张量 grid,它的形状为 [in_features, grid_size + 2 * spline_order + 1],其中每一行都是相同的、覆盖整个 grid_range 并适当扩展的网格点序列。这个网格用于模型中的 B 样条或其他基函数计算,使得模型可以在输入数据范围内执行灵活的插值和拟合操作。

初始化可训练参数

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) # 初始化基础权重和分段多项式权重
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:  # 如果启用独立的分段多项式缩放,则初始化分段多项式缩放参数
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

1. self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) ( w b ) (w_b) (wb)

  • **torch.Tensor(out_features, in_features)**:创建一个形状为 (out_features, in_features) 的未初始化张量,用于存储基础线性层的权重。这个张量的元素初始时没有具体的数值,通常在后续的 reset_parameters() 方法中进行初始化。
  • **torch.nn.Parameter**:将这个张量封装成 torch.nn.Parameter 对象。这意味着这个张量会被视为模型的可训练参数,PyTorch 会自动将其包含在模型的参数列表中,并在反向传播时更新其值。
  • **self.base_weight**:这个属性存储的是基础线性变换的权重矩阵。这个矩阵将在前向传播过程中被用来对输入特征进行线性变换。

2. self.spline_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order)) ( c i ) (c_i) (ci)

  • **torch.Tensor(out_features, in_features, grid_size + spline_order)**:创建一个形状为 (out_features, in_features, grid_size + spline_order) 的未初始化张量,用于存储分段多项式的权重。这些权重将用于 B 样条或其他类似方法的计算。
  • **torch.nn.Parameter**:同样地,将这个张量封装成 torch.nn.Parameter,使其成为模型的可训练参数。
  • **self.spline_weight**:这个属性存储的是与分段多项式相关的权重。这些权重决定了如何将输入特征映射到输出特征,特别是在使用 B 样条等非线性激活函数时。
为什么 spline_weight 的形状是 (out_features, in_features, grid_size + spline_order)
  • **out_features****in_features**:与 base_weight 类似,表示输出和输入的特征数量。
  • **grid_size + spline_order**:这个维度表示在 B 样条或其他分段多项式方法中,每个输入特征需要使用的基函数的数量。通过这些基函数的线性组合,可以生成灵活的非线性激活。

3. if enable_standalone_scale_spline:

  • 这个条件语句检查 enable_standalone_scale_spline 是否为 True。如果为 True,则会为每个分段多项式激活函数引入一个独立的缩放参数。

4. self.spline_scaler = torch.nn.Parameter(torch.Tensor(out_features, in_features)) ( w s ) (w_s) (ws)

  • **torch.Tensor(out_features, in_features)**:创建一个形状为 (out_features, in_features) 的张量,用于存储独立的分段多项式缩放参数。
  • **torch.nn.Parameter**:将张量封装成 torch.nn.Parameter,使其成为可训练参数。
  • **self.spline_scaler**:这个属性存储的是分段多项式的缩放参数。每个 spline_weight 都有一个对应的缩放参数,可以单独调整其幅度,从而提供更大的灵活性。

其他实例属性

        self.scale_noise = scale_noise # 保存缩放噪声、基础缩放、分段多项式的缩放、是否启用独立的分段多项式缩放、基础激活函数和网格范围的容差
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()  # 重置参数

Kaiming初始化权重(reset_parameters)

def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)# 使用 Kaiming 均匀初始化基础权重
        with torch.no_grad():
            noise = (# 生成缩放噪声
                (
                    torch.rand(self.grid_size + 1, self
<think>好的,我现在需要帮助用户查找与KAN图片相关的资源或信息,特别是关于格式或查看器的内容。首先,我需要回顾用户提供的引用内容,看看是否有相关的信息。 用户提供的引用中有四个来源。引用[1]提到了KAN 2.0,即Kolmogorov-Arnold网络,可能涉及相关的科学应用,但没有直接提到图片格式或查看器。引用[2]和[3]涉及到PyTorch模型的使用和项目结构,其中引用[3]的项目地址可能包含相关资源。引用[4]是一个百度论坛的链接,可能包含社区讨论的信息。 接下来,我需要考虑用户的问题:“KAN 图片 格式 OR KAN 图片 查看器”。这里的KAN可能指代不同的内容,比如Kolmogorov-Arnold Networks相关的图片,或者是某种特定格式的图片文件。用户可能在使用KAN模型时遇到了需要查看相关图片的情况,或者有特定格式的图片需要处理。 首先,我需要确认KAN是否指代某种图片格式。通常,常见的图片格式如PNG、JPEG等,但KAN可能是一个特定项目或工具使用的专有格式。根据引用[3]中的项目地址,可能存在与KAN相关的自定义图片处理工具或格式。因此,建议用户查看该项目的文档或代码仓库,了解是否有相关的图片格式说明或查看器工具。 其次,引用[2]中提到了使用PIL库处理图片,可能KAN项目在图像处理中使用特定的预处理步骤,但并未提及专门的查看器。如果用户需要处理模型训练中的图片,可能需要按照引用[2]中的步骤使用PIL和transforms进行处理,但这可能不是用户所需的查看器。 另外,用户可能需要的是KAN模型生成的图片,比如可视化结果。在这种情况下,建议用户参考项目文档,查看是否有生成或可视化图片的工具。例如,许多机器学习项目会提供脚本或工具来可视化网络结构或输出结果,这可能涉及特定格式的图片。 最后,用户提到的百度论坛链接(引用[4])可能包含社区成员讨论的KAN相关资源,可能有其他用户分享过查看器工具或处理KAN图片的方法。建议用户访问该链接,搜索相关讨论。 总结可能的解决方案:1. 检查KAN项目的官方文档或代码仓库(引用[3])获取图片格式和工具;2. 使用通用图片查看器尝试打开,若为专有格式可能需要转换;3. 查阅社区论坛(引用[4])寻求帮助;4. 根据引用[2]中的方法处理图片,但可能不涉及查看器。 在回答时,需要将这些步骤清晰地列出,并确保引用正确的来源。同时,生成相关问题以帮助用户进一步探索,如询问项目结构、处理KAN图片的方法等。</think>根据您提出的需求,以下是关于KAN图片资源及查看方式的系统化解决方案: ### 一、KAN图片资源获取渠道 1. **官方项目仓库** 访问引用[3]中提供的项目地址(https://gitcode.com/gh_mirrors/to/torch-conv-kan),查看`docs`或`examples`目录下的可视化案例。部分科研项目会在仓库中提供网络结构图、训练过程可视化结果等图片资源。 2. **论文配套材料** 查阅KAN 2.0相关论文的补充材料(如引用[1]),科研论文通常会提供网络架构图、性能对比图等高质量示意图,这些图片多采用PDF/SVG矢量格式或PNG格式。 3. **开发者社区** 在引用[4]的百度开发者论坛中搜索关键词"KAN 可视化"或"KAN 图片",查看是否有开发者分享过可视化工具或资源链接。 ### 二、图片格式处理建议 1. **通用格式支持** 若图片为标准格式(PNG/JPG等),可直接使用: - 系统默认图片查看器 - 专业工具:Adobe Photoshop/GIMP - 科研工具:Matplotlib、ImageJ 2. **特殊格式处理** 若遇到`.kanviz`等专用格式: ```python # 示例:通过项目代码解析专有格式 from kan.visualization import load_kan_image image_data = load_kan_image("network.kanviz") image_data.render() # 生成可交互可视化界面 ``` *需确认具体项目是否提供类似API(参考引用[3]的代码结构)* ### 三、关键操作注意事项 1. **格式转换规范** 使用引用[2]中的图像处理流程时: ```python from PIL import Image img = Image.open("input.kanimg").convert("RGB") # 格式转换 img.save("output.png") # 转换为通用格式 ``` 2. **可视化开发指引** - 网络结构可视化推荐工具:Netron、TensorBoard - 特征图可视化需通过hook机制捕获中间层输出(引用[3]项目可能已包含相关示例)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值