GBDT-回归树的建立

GBDT中的树是回归树。

最近才开始看GBDT的内容,发现基础是回归树,抓起统计学习方法(p69 算法5.5)就开始看,发现书上的那些式子很晦涩,翻阅了很多的博客,大致了解了回归树的建立方法。之后会总结GBDT的回归和分类模型。

一。算法思想,回归树递归地将每个区域划分为两个子区域并决定每个子区域的输出值。

二。步骤:

1.遍历选定的变量j,对固定的切分变量j扫描切分点s,选择使下式最小的最小切分变量j和切分点s。

2.用选定的(j,s)划分出区域R1,R2,并决定它们的输出值,也就是区域内的均值。

3.重复1,2步骤,直到满足结束条件

4.生成回归树,也可以理解为是分段函数

 

三。代码实现:

此代码是看了此作者的博客写成。参考:https://blog.youkuaiyun.com/xiaoxiao_wen/article/details/54098015

 

import numpy as np

#参考 https://blog.youkuaiyun.com/xiaoxiao_wen/article/details/54098015

def cart_regression_tree(start,end,y):
    if start!=end:#终止条件
        m = []
        for i in range(start,end):
            c1 = np.average(y[start:i+1])
            c2 = np.average(y[i+1:end+1])#注意这里需要end+1,要不然最后一个元素取不到
            y1 = y[start:i+1]
            y2 = y[i+1:end+1]
            m.append((sum(pow((y1-c1),2))+sum(pow((y2-c2),2))))

        index = m.index(min(m))+start#注意这里需要加start,否则会死循环
        print("切分点为:",index)
        print("大于",index,"的输出值为",np.average(y[start:index+1]))
        print("小于",index,"的输出值为",np.average(y[index+1:end+1]))
        cart_regression_tree(start,index,y)
        cart_regression_tree(index+1,end,y)
    else:
        return None

if __name__ == '__main__':
    x = np.arange(0,10)
    y = [4.5,4.75,4.91,5.34,5.80,7.05,7.90,8.23,8.70,9.00]
    cart_regression_tree(0,9,y)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值