GBDT中的树是回归树。
最近才开始看GBDT的内容,发现基础是回归树,抓起统计学习方法(p69 算法5.5)就开始看,发现书上的那些式子很晦涩,翻阅了很多的博客,大致了解了回归树的建立方法。之后会总结GBDT的回归和分类模型。
一。算法思想,回归树递归地将每个区域划分为两个子区域并决定每个子区域的输出值。
二。步骤:
1.遍历选定的变量j,对固定的切分变量j扫描切分点s,选择使下式最小的最小切分变量j和切分点s。
2.用选定的(j,s)划分出区域R1,R2,并决定它们的输出值,也就是区域内的均值。
3.重复1,2步骤,直到满足结束条件
4.生成回归树,也可以理解为是分段函数
三。代码实现:
此代码是看了此作者的博客写成。参考:https://blog.youkuaiyun.com/xiaoxiao_wen/article/details/54098015
import numpy as np
#参考 https://blog.youkuaiyun.com/xiaoxiao_wen/article/details/54098015
def cart_regression_tree(start,end,y):
if start!=end:#终止条件
m = []
for i in range(start,end):
c1 = np.average(y[start:i+1])
c2 = np.average(y[i+1:end+1])#注意这里需要end+1,要不然最后一个元素取不到
y1 = y[start:i+1]
y2 = y[i+1:end+1]
m.append((sum(pow((y1-c1),2))+sum(pow((y2-c2),2))))
index = m.index(min(m))+start#注意这里需要加start,否则会死循环
print("切分点为:",index)
print("大于",index,"的输出值为",np.average(y[start:index+1]))
print("小于",index,"的输出值为",np.average(y[index+1:end+1]))
cart_regression_tree(start,index,y)
cart_regression_tree(index+1,end,y)
else:
return None
if __name__ == '__main__':
x = np.arange(0,10)
y = [4.5,4.75,4.91,5.34,5.80,7.05,7.90,8.23,8.70,9.00]
cart_regression_tree(0,9,y)