SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

1. 问题背景

表格数据具有以下的特殊性:(1)表格数据包含表示连续、分类和有序值等多个特征,其中值可以是独立的,也可以是相关的。(2)表格数据中列的顺序可以是任意的,没有固定的位置信息

由于表格数据的特殊性,表格模型必须处理来自多个离散和连续分布的特征,并且在不依赖位置信息的情况下发现相关性。

下面作者介绍了专门处理表格数据的模型SAINT, SAINT用了很多方法来克服在表格数据上训练的困难。SAINT对连续特征也进行了类embedding的处理,把每个numerical的feature 用一个1Xd的dense层+relu 直接投影到d维的embedding空间中,category部分就直接走embedding层,这些新向量作为token传入到Transformer的编码器中,该编码器通过以下两种方式使用注意力:

第一是自注意力,它关注每个数据样本内的个体特征。

第二是样本间注意力,它将一行(即一个数据样本)与其他行相关联。

另外作者还利用了对比式无监督预训练来提高半监督问题的性能。

2. 相关工作

2.1 表格模型

传统处理表格数据的模型是TabTransformer,它使用了transformer的编码器在分类特征上学习上下文嵌入,连续特征被接到嵌入特征中,并送入MLP。该模型的主要问题是连续数据不经过自注意力块。这意味着任何关于分类特征和连续特征之间相关性的信息都会丢失。

在作者的模型中,通过将连续特征和分类特征投影到高维嵌入空间,并送入transformer块来解决这个问题。作者提出了一种新型的注意力机制,允许数据点相互关注,以获得更好的表示。

2.2 轴向注意力

作者提出了样本间注意力,首先进行一个样本点的特征交互,然后样本点与其他样本相互交互。

2.3 自监督学习

在语言和视觉领域,通过在无标记数据上执行 “前置任务 ”并在有标记数据上进行微调来实现自监督。对表格数据进行自监督的前置任务有掩码、去噪和替换令牌检测。

  • 掩码(或掩码语言模型MLM)是当个体特征被掩盖时,模型去填充它们的值。
  • 去噪会向数据中注入各种类型的噪声,目的是恢复原始值。
  • 替换标记检测(RTD)在特征向量中插入随机值,并试图找到这些随机值的位置。 
  • 对比式预训练,即在最大化两个不同点之间的距离的同时,最小化同一点的两个视图之间的距离。

作者将对比式预训练与去噪相结合,在不同数据集上进行预训练,结果表明该方法优于传统的boosting方法。

3. Self-Attention and Intersample Attention Transformer (SAINT)

假设x_{i}是第i个样本点的长度为n的特征向量。y_{i}是第i个样本点的标签。共m个样本点。我们向每个样本追加一个CLS。这样x_{i}=[CLS, f_{i}^{\{1\}}, f_{i}^{\{2\}}, ...,f_{i}^{\{n\}}]。E表示嵌入层,将每个特征嵌入到d维空间中。这样x_{i}的形状是n+1,E(x_{i})的形状是(n+1)*d

在表格数据中,不同特征的分布可能是不同的,这就需要一种异构的嵌入方式。作者提出将分类特征、连续特征投影到d维空间,然后再通过Transformer编码器。

3.1 模型结构

 SAINT由L个stage组成,每个stage由一个自注意力块和一个样本间注意力块组成。自注意力块由多头自注意力层(h个头)、具有GELU的全连接前馈层、跳跃连接和层归一化组成。样本间注意力块与自注意力块类似,只是将自注意力层换成了样本间注意力层。

有1个stage,一批b个输入的时候,具体公式如下:

3.2 样本间注意力

样本间注意力是一种行注意力,注意力是通过不同数据点之间计算的。具体来说,将单个样本的每个特征串联起来,然后计算各样本间的注意力。这使我们能够通过检查其他点来改进给定点的表示。当某行中某特征缺失或有噪声时,样本间注意力使SAINT能够从其他相似样本中借用相应的特征。

在单个头部中计算样本间注意力的流程如下:

在多头的情况下,将q、k、v投影到 d/h 维上,其中h是头的个数。然后将新向量拼接起来,得到长度为d的向量。

4. 预训练&微调

4.1 数据增强

VIME的作者使用mixup作为数据增强的方法,但这仅限于连续数据。作者在输入空间使用CutMix增强样本,在嵌入空间中使用mixup。

数据增强的公式如下,该公式中,m是一个01掩码向量,α是mixup的参数。在cutmix中,对每一个样本点x_{i}随机选择另一个样本点x_{a}, 利用掩码向量m,第i个样本的部分特征被第a个随机样本对应位上的特征替换掉,得到每个样本点的CutMix版本x_{i}^{'}。在Mixup中,选择一个新的样本点x_{b}^{'}, 执行混合。

4.2 SAINT和投影头

投影头是由一个隐藏层和一个ReLU组成的。使用投影头可以降低维数,并提高效果。

4.3 损失函数

 在预训练阶段,损失包含两部分:

(1)对比损失,它鼓励同一数据点的两个视图(zi 和 z′i)接近,鼓励不同点(zi 和 zj,i != j)相互远离。这里作者使用了InfoNCE损失。

(2)去噪损失,我们试图从噪声视图中预测原始数据样本。我们得到 r′i,我们将输入重构为 x′′ i,以最小化原始数据 xi 与重建数据 x''i 之间的差异。这里使用交叉熵损失(特征是离散的)或均方误差(特征是连续的)。

公式中的MLP是由一个隐藏层和一个ReLU组成。共有n个,每个特征对应一个MLP。\lambda _{pt}\tau是超参数。

4.4 微调

SAINT使用未标记数据进行预训练,使用标记数据进行微调。在最后的预测步骤中,我们只将与 [CLS] 标记相对应的token通过一个简单的 MLP,该 MLP 有一个带 ReLU 激活的单隐层,从而得到最终输出。最后使用交叉熵或均方误差进行评估。

5. 实验评价

模型变体:SAINT-s变体只有自我注意力,SAINT-i只有样本间注意力。

监督环境中,平均而言,SAINT变体每一个都比所有的基线模型表现更好。

半监督环境中,预训练的SAINT模型(具有自我和样本间的双重注意)表现最好。

嵌入连续数据是重要的,可以显著提高模型的性能。

模型选择:当特征数量很大时,SAINT-i始终优于其他变体。当训练数据点较少且特征较多时,SAINT-i的性能明显优于SAINT-s的。另外,虽然SAINT-i的参数数量远高于SAINT-s,但执行速度比SAINT-s快。

鲁棒性:SAINT和SAINT-i模型比SAINT-s模型更稳健。表明使用行注意力提高了模型对噪声训练数据的鲁棒性。当数据的90%缺失时,SAINT 在对损坏的训练数据进行训练时是可靠的。

批次的大小改变时,我们发现SAINT-i的性能变化较小,与没有样本间注意力成分的SAINT-s相当。

6. 参考

https://zhuanlan.zhihu.com/p/2672155027

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值