FTTransformer,一个很能打的模型

FTTransformer,是一个BERT模型架构在结构化数据集上的迁移变体。和BERT一样,它非常能打。

它可能是少数能够在大多数结构化数据集上取得超过或者匹配LightGBM结果的深度模型。

本范例我们将应用它在来对Covertype植被覆盖数据集进行一个多分类任务。

我们在测试集取得了91%的准确率,相比之下LightGBM只有83%的准确率。

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源码和所用Covertype数据集下载链接。

〇,原理讲解

FTTransformer是一个可以用于结构化(tabular)数据的分类和回归任务的模型。

FT 即 Feature Tokenizer的意思,把结构化数据中的离散特征和连续特征都像单词一样编码成一个向量。

从而可以像对text数据那样 应用 Transformer对 Tabular数据进行特征抽取。

值得注意的是,它对Transformer作了一些微妙的改动以适应 Tabular数据。

例如:去除第一个Transformer输入的LayerNorm层,仿照BERT的设计增加了output token(CLS token) 与features token 一起进行进入Transformer参与注意力计算。

一,准备数据

 
 
import numpy as np 
import pandas as pd 
from sklearn.model_selection import train_test_split


file_path = "covertype.parquet"
dfdata = pd.read_parquet(file_path)
...


dftmp, dftest_raw = train_test_split(dfdata, random_state=42, test_size=0.2)
dftrain_raw, dfval_raw = train_test_split(dftmp, random_state=42, test_size=0.2)


print("len(dftrain) = ",len(dftrain_raw))
print("len(dfval) = ",len(dfval_raw))
print("len(dftest) = ",len(dftest_raw))
dfdata.shape =  (581012, 13)
target_col =  Cover_Type
cat_cols =  ['Wilderness_Area', 'Soil_Type']
num_cols =  ['Elevation', 'Aspect', 'Slope', '...']
len(dftrain) =  371847
len(dfval) =  92962
len(dftest) =  116203
 
 
from torchkeras.tabular import TabularPreprocessor
from sklearn.preprocessing import OrdinalEncoder


#特征工程
...


dftest = pipe.transform(dftest_raw.drop(target_col,axis=1))
dftest[target_col] = encoder.transform(
    dftest_raw[target_col].values.reshape(-1,1)).astype(np.int32)
 
 
from torchkeras.tabular import TabularDataset
from torch.utils.data import Dataset,DataLoader 


def get_dataset(dfdata):
    return TabularDataset(
                data = dfdata,
                task = 'classification',
                target = [target_col],
                con
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值