Scikit-Learn与回归树

本文介绍了CART算法在回归任务中的应用,包括最小剩余方差法来构建决策树,并讨论了模型树的概念。通过Scikit-Learn库实现CART回归树,展示了如何利用该算法进行预测分析。

回归算法原理

CART(Classification and Regression Tree)算法是目前决策树算法中最为成熟的一类算法,应用范围也比较广泛。它既可以用于分类。
西方预测理论一般都是基于回归的,CART是一种通过决策树方法实现回归的算法,它具有很多其他全局回归算法不具有的特性。
在创建回归模型时,样本的取值分为观察值和输出值两种,观察值和输出值都是连续的,不像分类函数那样有分类标签,只有根据数据集的数据特征来创建一个预测的模型,反映曲线的变化趋势。在这种情况下,原有分类树的最优划分规则就不再起作用了。在预测中,CART使用最小剩余方差(Squared Residuals Minimization)来判定回归树的最优划分,这个准则期望划分之后的子树与样本点的误差方差最小。这样决策树将数据集划分成很多子模型数据,然后利用线性回归技术来建模。如果每次切分后的数据子集仍然难以拟合,就继续切分。在这种切分方式下创建出的预测树,每个叶子节点都是一个线性回归模型。这些线性回归模型反映了样本集合(观测集合)中蕴含的模式,也被称为模型树。因此,CART不仅支持正体预测,也支持局部模式的预测,并有能力从整体中找到模式,或根据模式组合成一个整体。整体与模式之间的相互结合,对于预测分析有重要价值。因此CART决策树算法在预测中的应用非常广泛。
下面介绍CART的算法流程:
(1)决策树主函数:决策树的主函数是一个递归函数。该函数的主要功能是按照CART的规则生长出决策树的每个分支节点,并根据终止条件结束算法。
a.输入需要分类的数据集和类别标签。
b.使用最小剩余方差判定回归树的最优划分,并创建特征的划分节点——最小剩余方差子函数。
c.在划分节点划分数据集为两部分——二分数据集子函数。
d.根据二分数据的结果构建出新的左右节点,作为树生长出的两个分支。
e.检验是否符合递归的终止条件。
f.将划分的新节点包含的数据集和类别标签作为输入,递归执行上述步骤。
(2)使用最小剩余方差子函数,计算数据集各列的最优划分方差、划分列、划分值
(3)二分数据集:根据给定的分隔列和分隔值将数据集一分为二,分别返回。

最小剩余方差法

在回归树中,数据集均为连续性。连续数据的处理方法与离散数据不同,离散数据是按每个特征的取值划分,而连续特征则要计算出最优划分点。但在连续数据集上计算线性相关度非常简单,算法思想来源于最小二乘法。
最小剩余方差法,首先求取划分数据列的均值和总方差。总方差的计算方法有两种
求取均值std,计算每个数据点与std的方差,然后将n个点求和
求取方差var,然后var_sum = var*n,n为数据集数据数目。
那么,每次最佳分支特征的选取过程如下。
(1)先令最佳方差为无限大 bestVar = inf。
(2)此次遍历所有特征列及每个特征列的所有样本点(这是一个二循环),在每个样本点上二分数据集。
(3)计算二分数据集后的总方差currentVar,如果currentVar < bestVar,则bestVar = currentVar。
返回计算的最优分支特征列、分支特征值(连续特征则为划分点的值)以及左右分支子数据集到主程序。

模型树

使用CART进行预测是把叶子节点设定为一系列的分段线性函数,这些分段线性函数是对源数据曲线的一种模拟,每个线性函数都被称为一颗模型树。模型树具有很多优秀的性质,它包含了如下特征。
一般而言,样本总体的重复性不会很高,但局部模式经常重复,也就是所说的历史不会简单的重复,但会重演。模型比总体对未来的预测而言更有用。
模型给出了数据的范围,它可能是一个时间范围,也可能是一个空间范围;而且模型还给出了变化的趋势,可以是曲线,也可以是直线,这依赖于使用的回归算法。这些因素使模型具有很强的可解释性。
传统的回归方法,无论是线性回归还是非线性回归,都不如模型树包含的信息丰富,因此模型树具有更高的预测准确度。

Scikit-Learn实现

#!/usr/bin/python
# created by lixin 20161118
import numpy as np
from numpy import *
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt


def plotfigure(X,X_test,y,yp):
        plt.figure()
        plt.scatter(X,y,c="k",label="data")
        plt.plot(X_test,yp,c="r",label="max_depth=5",linewidth=2)
        plt.xlabel("data")
        plt.ylabel("target")
        plt.title("Decision Tree Regression")
        plt.legend(loc='upper right')
        plt.show()
        #plt.savefig('./res.png', format='png')


x = np.linspace(-5,5,200)
siny = np.sin(x)
X = mat(x).T
y = siny + np.random.rand(1,len(siny))*1.5
y = y.tolist()[0]
clf = DecisionTreeRegressor(max_depth=4)
clf.fit(X,y)

X_test = np.arange(-5.0,5.0,0.05)[:,np.newaxi
yp = clf.predict(X_test)

plotfigure(X,X_test,y,yp)

图1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值