Spark2.3 源码解析 之 梯度提升树 gradient boosting tree

本文深入解析Spark2.3中的梯度提升树算法,涵盖理论基础和源码实现。文章讨论了boosting的概念,重点介绍了gradient boosting的原理,包括损失函数定义、Logloss推导以及Spark中使用的损失函数。在源码部分,详细阐述了训练过程中的决策树构建,以及预测阶段的策略,如加权求和与投票机制。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Spark2.3 源码解析 之 梯度提升树 gradient boosting tree

一、理论

理论部分源自 Machine Learning-A Probabilistic Perspective(MLAPP)Elements of Statistical Machine Learning(ESML)

1、boosting

boosting是一种greedy算法,书中也称作一种adaptive basis-function model (ABM),形式如下:

φm称作弱学习器,可以是任意算法模型,常用的是CART。boosting由多个弱学习器组成,每个弱学习器的目的都是关注上一轮分类错误的样本,从而进一步减少模型的loss。其中权重w的目的是增加错分点的权重,使得越错误的点越容易被关注。最后所有弱学习器的加权或者投票结果就形成了boosting的模型输出。

2、gradient boosting

     gradient boosting的伪代码如下(摘自MLAPP):

梯度提升(gradient boosting)是一种改进的boosting方法。boosting算法的目标是让模型输出逼近真实值,即最小化Loss=L(y,f)。那么boosting每一步的目标就是:通过 弱学习器φm 来改进f,使得loss逐步减少
     那么问题就来了:f 应该向什么方向改进呢?即φm取什么值才能让loss下降最快呢?这就离梯度gradient很近了,因为loss下降最快的方向就是对f的一阶导数 即梯度gradient。因此根据gradient梯度来更新f就是gradient boosting的目的。
    那么具体怎么更新呢:每次根据gradient(伪代码中 r )拟合一个弱学习器φ,并累加给f(乘以学习速率)。即gradient boosting模型中每个弱学习器都是对gradient的拟合。

(1)损失函数定义loss

根据任务类型不同,损失函数Loss的定义也不同,具体见下图(MLAPP 556页 table16.1):

回归问题主要用squared error,分类问题主要用logloss

(2)Logloss推导

平方损失和绝对值损失不详细介绍,Expoential指数损失过于关注错误样本,因此受到错误样本的干扰很大(比如非常离谱的数据),这里也不做介绍。
此处主要讨论gradient boosting的Logloss如何得来的,不啰嗦,直接赋上ESML的介绍:

公式中Y=1的概率P=1.0 / (1.0 + math.exp(-2.0 * margin)),margin即为公式中的f(x)。训练完成后,即可根据该公式计算概率。
其中,根据p(x)推导出损失函数 l(Y,p(x))不难,根据p(x)也可以推导出f(x)=1/2的log-odds
但问题是:p(x)如何而来?为什么不是LR一样的sigmoid函数?

注意:这里的Loss是 y与f之间的loss,其中导数也是对f求导(因为每个f是一个弱分类器,目的是f拟合梯度gradient),这里的loss和gradient都与x无关
(3)spark中损失函数
    在spark采用的损失函数为:
其中Logloss是在(2)中logloss 2倍࿰
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值