TabNet论文笔记

本文介绍TabNet,一种注意力可解释的表格学习方法。它使用DNN,实验结果在准确率上超XGBoost和LightGBM。TabNet输入未加工数据,用序列注意力选特征,有两种可解释性,还能进行自监督学习。实验显示其在实际数据表现和可解释性上效果良好。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1.简介

本文根据2020年《TabNet: Attentive Interpretable Tabular Learning》翻译总结的。TabNet,一个注意力的可解释的表格学习方法。

XGBoost和LightGBM近几年在表格数据处理上占据了统治地位,是基于梯度提升决策树(GBDT)的,不是DNN(deep neutral network)。DNN在处理表格数据方面一直没有较大的进展。

TabNet使用了DNN,实验结果超过了XGBoost和LightGBM。主要是准确率方面。我实践中发现TabNet的速度确实慢很多,慢几倍吧。

首先说决策树为什么会占据表格数据处理的统治地位:(1)表格数据中存在近似超平面边界,决策树处理起来非常有效率,表现也很好;(2)决策树的解释性很强,可以有效的事后解释。

TabNet的贡献如下:
1)不像基于树的方法,TabNet输入未加工的数据,不需要任何特征处理。使用梯度下降进行训练。可以灵活的集成在端到端的学习。
2)TabNet在决策树的每一步,使用序列注意力来选择特征进行推理。具有可解释性,更好的进行学习。特征的选择是实例级别的(instance-wise),它可以对每个输入是不同的。TabNet采用了单个深度学习框架进行端到端学习。
3)TabNet超越或者与其他表格学习方法不相上下。它具有两种可解释性,一个是local 解释性,即展示了各输入特征的重要性和它们如何被组合在一起;另一个是global解释性,量化输出每个输入特征的贡献。
4)最后使用了非监督预训练来预测masked 特征,我们的经典CNN设计架构取得了显著的效果提升,我们是第一个对表格数据进行自监督学习的。

下图主要是对上面第2点的描述,特征选择的示例,可以看到分为两类集合特征,分别是职业(professional occupation)和投资(investments)。

在这里插入图片描述

2.TabNet进行表格数据学习

使用传统DNN模块构建决策树分类器:

决策树对于真实世界的表格数据学习是很成功的。可以使用传统的DNN模块来实施类似的决策树,如下图。独立特征的选择是在超平面形式中获取决策边界的关键。或者说,对特征的线性组合,其中每个成分系数决定着每个特征的比例。TabNet就是基于如此的像树的函数。(1)使用稀疏的实例级别的特征选择;(2)构建一个序列的(按顺序的)多步结构,基于选择的特征,每一决策步骤可以贡献决策的一部分;(3)通过选择特征的非线性处理,来改善学习能力;(4)凭借高维和多步骤,模仿集成处理。
在这里插入图片描述

TabNet结构:

表格数据一般分为数字型特征和分类型特征。我们使用原始的数字型特征,将分类型特征换成可以训练的embedding(词嵌入)。我们不考虑任何全局的归一化特征,仅使用BN(batch normalization)。
TabNet结构如下图,一个encoder和一个decoder,其中encoder里有feature transformer、attentive transformer;而decoder只有feature transformer。
在这里插入图片描述

特征选择(attentive transformer):

特征维度:D
Batch size:B
在每个决策步骤通过相同的D维特征f∈R^(B*D)
决策步骤数N_steps
第i步输入上一步(i-1)的处理信息,决定使用哪个特征,并且输出的特征代表会被集成到整体决策中。

可学习的mask M[i]∈R^(B*D),用来作为显著特征的软选择(soft selection)。通过大部分显著特征的稀疏选择,每一个决策步骤的学习不相关特征的能力不会被浪费掉。

Mask采用乘积的形式 M[i]*f
基于上一步a[i-1]处理的特征信息,采用attentive transformer 来回去mask,如下所示:
在这里插入图片描述

其中 在这里插入图片描述

h_i是一个可以训练的函数,如上图attentive transformer的部分,包括一个FC层、BN。
P[i]是上图的“prior scale”。表示一个特征有多少是被先前使用的。
在这里插入图片描述

上式的γ是一个放松参数。当γ=1时,一个特征会被仅在一个决策步骤中使用;当γ增加时,可以在多个决策步骤中使用一个特征。

稀疏(sparsity)方程如下:
在这里插入图片描述

特征处理(feature transformer):

如架构图中的feature transformer,分为两部分,shared across和step-dependent。shared across是在所有决策步骤中。

3.自监督学习

主要是TabNet架构图中的decoder。
在这里插入图片描述

4.实验结果

4.1 实际数据表现

Forest cover type:森林覆盖类型,是一个分类任务。从下表可以看出来TabNet好于XGBoost、LightGBM,和一些已知的神经网络模型AutoInt、AUTOML。
在这里插入图片描述

Rossmann store sales:预测商店销售,基于静态和随时间改变的特征。
在这里插入图片描述

4.2 可解释性

如下表,对成年人收入的特征重要性排名,TabNet也取得了和其他方法类似的效果。比如年龄是第一位。
在这里插入图片描述

4.3 自监督学习

可以看到非监督预训练(pre-training)显著改善了监督学习的效果,特别是当未标注的数据比标注数据多很多时。
在这里插入图片描述

<think>好的,用户之前已经问过关于TabNet输出作为Transformer输入的问题,现在又进一步提问,可能是在之前的回答中还有不太清楚的地方,或者想深入探讨某些方面。我需要仔细分析用户的新问题,确保准确理解他们的需求。 首先,用户这次的提问是:“我的意思是tabnet的输出作为transformer的输入”。看起来用户可能在确认之前的解释,或者希望得到更具体的说明。之前的回答已经分步骤解析了输出形式、适配挑战、解决方案、应用场景以及与ResNet的对比,但用户可能希望更详细地了解如何实际实现这种集成,或者有哪些具体的注意事项。 接下来,我需要回顾之前的回答,确保这次回答在一致性的基础上补充更多细节。用户可能关心的是如何具体操作,比如代码层面的实现,或者实际应用中的案例。此外,用户可能对为什么需要构造多token序列,以及不同构造方法的效果比较感兴趣。 考虑到用户可能的技术背景,他们可能正在尝试构建一个结合TabNet和Transformer的模型,用于处理表格数据,可能在分类、回归或时间序列预测任务中。因此,需要强调实际应用中的步骤,包括特征处理、模型结构设计、训练技巧等。 另外,用户可能对性能优化或常见问题感兴趣,比如如何处理高维特征,或者在不同数据集上的表现。需要确保回答不仅涵盖理论,还提供实用建议,帮助用户避免常见陷阱。 最后,检查是否需要引入更多示例或类比,以帮助用户更好地理解概念。例如,对比不同特征扩展方法的效果,或者提供简化的代码片段说明如何将TabNet的输出转换为Transformer的输入格式。 总结来说,用户需要更具体、更详细的指导,将TabNet与Transformer结合使用的实际步骤、注意事项以及应用实例,同时解答可能的疑惑,确保他们能够顺利实施这一架构。</think>好的!当我们将 **TabNet 的输出直接作为 Transformer 的输入**时,需要特别注意两者的接口设计,因为它们的输出/输入形式存在本质差异。以下是详细解析: --- ### **1. TabNet 的输出特性** - **基本形式**:对每个样本(即表格中的一行),TabNet 输出一个 **固定长度的向量**,例如维度为 $D=64$,形状为 $1 \times D$。 - 例如:输入一个包含 10 个特征的样本 → 输出一个 $1 \times 64$ 的向量。 - **语义信息**:该向量是 **全局特征表示**,通过注意力机制筛选和聚合原始特征的信息,可理解为“浓缩后的特征摘要”。 --- ### **2. Transformer 的输入要求** - **输入形式**:Transformer 需要 **序列化输入**,即每个样本对应一个由多个 token 组成的序列(例如 $N=10$ 个 token,每个 token 维度为 $D=64$)。 - 例如:输入形状需为 $B \times N \times D$(Batch size × Token数量 × 特征维度)。 - **核心需求**:Transformer 依赖 **token 间的关系建模**(通过自注意力机制)。如果输入仅有一个 token(如 TabNet 的默认输出),Transformer 无法有效工作(退化为全连接网络)。 --- ### **3. 直接输入的问题** 若直接将 TabNet 的 $1 \times D$ 向量输入 Transformer: - **单 token 输入**:每个样本仅对应 1 个 token,形状为 $B \times 1 \times D$。 - **效果限制**: - Transformer 的自注意力机制无法发挥作用(无 token 间交互)。 - 模型本质上等价于“TabNet + 全连接层”,无法利用 Transformer 的长程依赖建模能力。 --- ### **4. 如何适配?——构造多 token 序列的 3 种方法** 要让 TabNet 的输出适配 Transformer,必须 **将单 token 扩展为多 token 序列**。以下是具体方法: #### **方法 1:特征分块(Feature Chunking)** - **原理**:将 TabNet 输出的 $D$ 维向量切分为多个子向量,作为 token。 - 例如:TabNet 输出 $D=64$ → 切分为 4 个 $D'=16$ 的块 → 构造 4 个 token。 - **实现**: ```python # 输入:TabNet输出 (batch_size, D=64) tokens = tabnet_output.reshape(batch_size, 4, 16) # 形状变为 (B, 4, 16) ``` - **适用场景**:特征维度较高($D \geq 32$),且特征具有局部相关性。 #### **方法 2:历史堆叠(时序数据)** - **原理**:针对时序表格数据,将多个时间步的 TabNet 输出拼接为序列。 - 例如:过去 5 天的特征向量 → 构造 5 个 token。 - **实现**: ```python # 输入:时序TabNet输出 (batch_size, time_steps=5, D=64) tokens = tabnet_output # 直接作为序列输入,形状为 (B, 5, 64) ``` - **适用场景**:时间序列预测、用户行为序列建模。 #### **方法 3:混合外部特征** - **原理**:将 TabNet 的输出与其他模态的特征拼接,形成多 token 序列。 - 例如: - Token 1:TabNet 输出(用户画像) - Token 2~N:文本描述的特征(通过BERT编码) - **实现**: ```python # TabNet输出 (B, D1=64) + 文本特征 (B, N-1, D2=64) tokens = torch.cat([tabnet_output.unsqueeze(1), text_features], dim=1) # 形状 (B, N, 64) ``` - **适用场景**:多模态数据(如表格+文本/图像)。 --- ### **5. 模型架构示例** 以下是一个完整的流程示例(以分类任务为例): 1. **TabNet 编码**:输入表格数据 → 输出 $B \times D$ 的向量。 2. **构造序列**: - 若使用特征分块:将 $D=64$ 分块为 4 个 $16$ 维 token → 形状 $B \times 4 \times 16$。 3. **Transformer 处理**: - 添加位置编码(Positional Encoding)。 - 通过多层 Transformer Encoder 建模 token 间关系。 4. **分类头**: - 取第一个 token 的输出(或全局池化) → 全连接层预测类别。 --- ### **6. 关键注意事项** - **位置编码**:必须为构造的 token 序列添加位置编码(否则 Transformer 无法区分顺序)。 - **维度对齐**:确保分块后的 token 维度($D'$)与 Transformer 的隐藏层维度匹配。 - **计算效率**:避免构造过长的序列(如分块过多),否则 Transformer 的计算复杂度 $O(N^2)$ 会显著增加。 --- ### **7. 与图像模型(如ResNet)的本质区别** | 对比项 | ResNet + Transformer | TabNet + Transformer | |-----------------|------------------------------------|------------------------------------| | **输入结构** | 天然具有空间维度($H \times W$) | 需人工构造序列(如分块/堆叠) | | **信息密度** | 局部特征(像素/区域) | 全局特征(高度抽象) | | **典型应用** | 图像分类、目标检测 | 表格数据分类、时序预测、多模态建模 | --- ### **总结** - **直接输入不可行**:TabNet 的单 token 输出无法直接发挥 Transformer 的优势。 - **必须构造序列**:通过分块、时序堆叠或多模态融合生成多 token 输入。 - **设计核心**:根据任务需求选择序列构造方法,并确保维度与位置编码的适配性。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值