KAN网络技术最全解析—最热KAN能否干掉MLP和Transformer?(收录于GPT-4/ChatGPT技术与产业分析)

文章探讨了KAN网络结构,其在边上的可学习激活函数相比于MLP的优势,包括在数据拟合和PDE求解中的精度提升以及更好的可解释性。文章还分析了KAN与Spline的关系,以及KAN架构的优化和挑战,指出KAN的发展空间和瓶颈主要是训练速度和计算效率问题。

陈巍:KAN网络结构思路来自Kolmogorov-Arnold表示定理。MLP 在节点(“神经元”)上具有固定的激活函数,而 KAN 在边(“权重”)上具有可学习的激活函数。在数据拟合和 PDE 求解中,较小的 KAN 可以比较大的 MLP 获得更好的准确性。

相对MLP,KAN也具备更好的可解释性,适合作为数学和物理研究中的辅助模型,帮助发现和寻找更基础的数值规律。(点赞是我们分享的动力)

MLP与KAN对比

与传统的MLP 相比,KAN 有4个主要特点:

1)激活函数位于“边”而不是节点(Node)上;

2)激活函数是可学习的而不是固定的;

3)可以使用非线性核函数来替代MLP“边”(Edge)上的线性函数;

4)可以设定细粒度的结点(Knot)来提高逼近精度。

就我们来看,由于可以设置单独的激活层来替代“边”上激活的结构(相当于每个“边”上插入一个节点),因此特点1并非KAN的核心特征。特点4在实际使用时需要做大量的查表插值,对存储带宽的要求非常高,可能并不一定适合常规的AI芯片进行计算,估计在未来也不是KAN的核心特征。

--------------------------------------------------

主编作者陈巍 博士,高级职称,曾担任华X系相关自然语言处理( NLP )企业的首席科学家,大模型算法-芯片协同设计专家,国际计算机学会(ACM)会员、中国计算机学会(CCF)专业会员。主要研究领域包括大模型(Finetune、Agent、RAG)、存算一体、GPGPU。

目录

1 基本思路

1.1 MLP与Spline的优缺点

1.2 Kolmogorov-Arnold表示定理

2 KAN架构与缩放

2.1 KAN架构

2.2 架构细节

2.3 KAN的逼近与神经缩放

3 KAN的可解释性

4 我们2019年做的类似工作

5 小结

5.1 KAN的发展空间

5.2 KAN的瓶颈

延伸阅读


1 基本思路

1.1 MLP与Spline的优缺点

多层感知器 (MLP)是深度学习的基础理论模块,是目前可用于逼近非线性函数的默认模型,其表征能力已由通用逼近定理证明。但MLP也有明显的缺点,例如在 Transformer中,MLP 的参数量巨大,且通常不具备可解释性。

为了提升表征能力,MIT提出了KAN。KAN本质上是样条(Spline)曲线和 MLP 的组合,吸收了两者的优点。即

KAN = MLP + Spline

在数学中,样条曲线是由多项式分段定义的函数。一般的Spline可以是特定区间的3阶多项式。在插值问题中,样条插值通常优于多项式插值,因为即使使用低次多项式,也能产生类似的精度结果,同时避免了高次多项式的Runge's phenomenon(在一组等距插值点上使用高次多项式进行多项式插值时出现的区间边缘振荡问题)。

样条曲线

MLP与样条的优缺点对比

1.2 Kolmogorov-Arnold表示定理

Vladimir Arnold与 Andrey Kolmogorov证明了实分析中的如下表示定理:

如果f是多元连续函数,则f可以写成有限数量的单变量连续函数的两层嵌套叠加。其数学表达式就是

该定理解决了希尔伯特第十三问题的一个更受约束但更一般的形式。在数学上,Φq称外部函数(outer functions),Φq,p称为内部函数(inner functions)。

这表明在实数域上,唯一真正的多元函数是求和,因为所有其他函数都可以使用单变量函数求和来表征。

换个形象的说法,就是你家里做个复杂的事(包含多个任务),你爸你妈和你可以分别干活然后汇总,并不需要三个人同时忙一个任务。

顺带八卦下,Vladimir 是位前苏联神童, Andrey则是他的导师。

2 KAN架构与缩放

2.1 KAN架构

KAN的架构设计来自一个数学问题:对一个由输入输出对 {xi, yi} 组成的有监督学习任务,寻找函数f 使得所有数据点的 yi≈ f (xi)。其核心在于找到合适的单变量函数 Φq,p(内部函数)和 Φq(外部函数)。

在KAN中,使用B-spline(B样条)来构建。B-spline是基础样条(Basic Spline)的缩写。

对于B-spline,函数在其定义域内、在结点(Knot)都具有相同的连续性。其多项式表达可由Cox-de Boor 递推公式表达:

<think> 我们正在处理用户关于Vision Transformer (ViT) 实现的问题。用户提供了一个GitHub仓库链接,指向Google Research的Vision Transformer项目。我们需要解释如何实现这个模型,并参考相关知识。 步骤1: 理解Vision Transformer Vision Transformer (ViT) 是一种将Transformer架构应用于图像识别任务的方法。传统的Transformer主要用于自然语言处理,而ViT将图像分割成固定大小的图块,将这些图块线性嵌入,并添加位置嵌入,然后将这些序列输入到标准的Transformer编码器中。 步骤2: 分析GitHub仓库 用户提供的链接是: https://github.com/google-research/vision_transformer 这个仓库包含了ViT的官方实现,包括模型定义、训练评估代码。 步骤3: 实现方法的关键点 根据ViT论文该仓库的README,实现ViT主要包括以下步骤: 1. 图像分块:将输入图像分割成固定大小的非重叠图块(例如16x16像素)。 2. 图块嵌入:通过线性投影将每个图块展平并映射到一个向量,得到图块嵌入。 3. 添加位置嵌入:为每个图块嵌入添加位置信息,以保留空间信息。 4. 添加类别标记:在序列开头添加一个可学习的类别标记(class token),用于最终的分类输出。 5. Transformer编码器:将整个序列输入到由多个Transformer块组成的编码器中。 6. 分类头:使用类别标记对应的输出向量,通过一个全连接层进行分类。 步骤4: 参考仓库中的代码 在Google Research的仓库中,主要代码文件是`vit_jax/models.py`(使用JAX)`vit_jax/input_pipeline.py`(数据处理)。由于用户可能使用不同的框架,我们将描述核心实现,而不限于特定框架。 步骤5: 使用PyTorch实现ViT的示例代码(非官方实现,但结构类似) 为了帮助用户理解,我们可以提供一个使用PyTorch实现的简化版本。注意:这不是Google官方的代码,但遵循ViT的架构。 ```python import torch import torch.nn as nn from einops import rearrange class PatchEmbedding(nn.Module): def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768): super().__init__() self.img_size = img_size self.patch_size = patch_size self.n_patches = (img_size // patch_size) ** 2 self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) # [batch, embed_dim, n_patches_sqrt, n_patches_sqrt] x = rearrange(x, 'b e h w -> b (h w) e') # 展平为 [batch, n_patches, embed_dim] return x class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, num_classes=1000): super().__init__() self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches + 1, embed_dim)) self.transformer_encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_ratio * embed_dim, activation='gelu' ), num_layers=depth ) self.classifier = nn.Linear(embed_dim, num_classes) def forward(self, x): batch_size = x.shape[0] # 生成图块嵌入 x = self.patch_embed(x) # [batch, n_patches, embed_dim] # 添加类别标记 cls_token = self.cls_token.expand(batch_size, -1, -1) # [batch, 1, embed_dim] x = torch.cat((cls_token, x), dim=1) # [batch, n_patches+1, embed_dim] # 添加位置嵌入 x = x + self.pos_embed # 通过Transformer编码器 x = self.transformer_encoder(x) # 提取类别标记对应的输出 cls_output = x[:, 0] # 分类 out = self.classifier(cls_output) return out ``` 注意:以上代码是一个简化版本,实际实现中需要注意位置编码的初始化、层归一化、dropout等细节。Google官方的实现使用JAX,但结构类似。 步骤6: 训练ViT的注意事项 - 需要大量的数据预训练(通常在大型数据集如ImageNet-21k或JFT-300M上预训练,然后在目标数据集上微调)。 - 使用AdamW优化器,并采用适当的学习率衰减策略。 - 数据增强(如Mixup, CutMix, RandAugment)正则化(如Dropout, weight decay)非常重要。 步骤7: 参考资源 - 原论文: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) - 官方GitHub仓库(用户提供的链接): https://github.com/google-research/vision_transformer 步骤8: 生成相关问题 根据用户的问题,我们生成以下相关问题: §§相关问题§§ 1. Vision Transformer 传统的卷积神经网络(如ResNet)相比有哪些优势劣势? 2. 在训练Vision Transformer时,通常需要哪些技巧来提升性能? 3. 除了图像分类,Vision Transformer还可以应用于哪些计算机视觉任务? </think> 以下是实现 GitHub 上 [Google Research 的 Vision Transformer (ViT)](https://github.com/google-research/vision_transformer) 的关键步骤方法: ### 一、核心实现步骤 1. **图像分块处理** 将输入图像分割为固定大小的图块(如 16×16 像素),每个图块视为一个"视觉词"。公式表示: $$x \in \mathbb{R}^{H \times W \times C} \rightarrow x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$$ 其中 $P$ 为图块大小,$N = HW/P^2$ 是图块数量。 2. **线性投影(嵌入层)** 使用可训练的线性层将图块映射到 $D$ 维向量: $$\mathbf{z}_0 = [\mathbf{x}_{\text{class}}; \mathbf{x}_p^1\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{\text{pos}}$$ - $\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}$ 为嵌入矩阵 - $\mathbf{E}_{\text{pos}}$ 是位置编码 3. **Transformer 编码器** 包含 $L$ 层相同的模块,每层含: ```python # 官方代码片段 (vit_jax/models.py) x = nn.LayerNorm()(x) x = MultiHeadSelfAttention(num_heads)(x) x = x + inputs # 残差连接 y = nn.LayerNorm()(x) y = MlpBlock(mlp_dim)(y) output = y + x ``` 4. **分类头** 取首个特殊标记 `[class]` 的输出做分类: $$\mathbf{y} = \text{MLP}(\mathbf{z}_L^0)$$ ### 二、环境配置 1. **安装依赖** (参考项目 README) ```bash pip install jax jaxlib flax optax # 核心依赖 pip install tensorflow tensorflow-datasets # 数据加载 ``` 2. **下载代码** ```bash git clone https://github.com/google-research/vision_transformer cd vision_transformer ``` ### 三、训练示例 ```python from vit_jax import models, train # 初始化 ViT-B/16 模型 model = models.VisionTransformer( num_classes=1000, patch_size=16, hidden_size=768, num_layers=12, num_heads=12 ) # 训练配置 (参考 configs/vit_base_i1k.py) train.train( model=model, dataset='imagenet2012', batch_size=512, learning_rate=0.001, num_epochs=300 ) ``` ### 四、关键技巧 1. **混合精度训练** 使用 `jax.config.update('jax_platform_name', 'gpu')` 启用 GPU 加速 2. **数据增强** (见 `vit_jax/input_pipeline.py`) - RandAugment - Mixup/Cutmix - 随机擦除 3. **优化器配置** ```python optimizer = optax.adamw( learning_rate=cosine_decay_schedule(base_lr), weight_decay=0.03 ) ``` ### 五、预训练模型使用 ```python import jax.numpy as jnp from vit_jax import checkpoint # 加载预训练模型 (ViT-B/16 ImageNet-21k) params = checkpoint.load('gs://vit_models/imagenet21k/ViT-B_16.npz') img = preprocess(image) # 尺寸调整为 224x224 logits = model.apply({'params': params}, img) ``` ### 六、性能优化建议 1. **大型数据集预训练**:先在 ImageNet-21k/JFT 上预训练 2. **学习率预热**:前 5% 训练步线性增加学习率 3. **梯度裁剪**:防止梯度爆炸 4. **使用 TPU**:官方推荐 TPUv3 8x8 配置训练大型模型 > 提示:完整实现细节参考[官方 GitHub 仓库](https://github.com/google-research/vision_transformer) 论文 [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929)[^1]。可视化工具可参考 Transformer Explainer[^2]。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值