CART算法中的分类树采用基尼系数的方法来划分特征。而回归树则采用最小二乘法,生成最小二乘回归树。
一:如何选择最优切分点?
- 对每一个特征中相邻的数据取均值,作为候选切分点。假设特征有a个取值,则有a - 1 个候选切分点。
- 然后针对每个切分点,将该特征的数据分成两部分,r1和r2。
- 计算两部分中数据的均值c1和c2。
- 对两部分做最小二乘。损失函数为为(y - 均值)^2,再求和。
- 将两部分最小二乘的结果相加,得到每个候选切分店的损失函数。
- 找出损失函数最小的切分点,作为该特征的切分点。
- 以此类推,进行特征切分。直到结束。
回归树把整个空间切分成不相交的子区域,每个区域预估为这个区域的平均值。
二:实例
为了便于理解,下面举一个简单实例。训练数据见下表,目标是得到一棵最小二乘回归树。
x | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
y | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 | 8.9 | 8.7 | 9 | 9.05 |
- 选择最优切分变量j与最优切分点s。
在本数据集中,只有一个变量,因此最优切分变量自然是x。
接下来我们考虑9个切分点[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5]
损失函数定义为平方损失函数 Loss(y,f(x))=(f(x)−y)2,将上述9个切分点依次 代入下面的公式,其中 c_m为每个部分所有数据的均值。
例如,取 s=1.5。此时 R1={1},R2={2,3,4,5,6,7,8,9,10},这两个区域的输出值分别为:
c1=5.56,c2=(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)/9=7.50。得到下表:
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 | |
---|---|---|---|---|---|---|---|---|---|---|
c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 | 6.24 | 6.62 | 6.88 | 7.11 | |
c2 | 7.5 | 7.73 | 7.99 | 8.25 | 8.54 | 8.91 | 8.92 | 9.03 | 9.05 |
把c1,c2的值代入到上式,如:m(1.5)=0+15.72=15.72。同理,可获得下表:
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | .5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
m(s) | 15.72 | 12.07 | 8.36 | 5.78 | 3.91 | 1.93 | 8.01 | 11.73 | 15.74 |
显然取 s=6.5时,m(s)最小。因此,第一个划分变量j=x,s=6.5
用选定的(j,s)划分区域,并决定输出值
两个区域分别是:R1={1,2,3,4,5,6},R2={7,8,9,10}输出值1=6.24,c2=8.91
对R1继续进行划分:
x | 1 | 2 | 3 | 4 | 5 | 6 | |
---|---|---|---|---|---|---|---|
y | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 |
取切分点[1.5,2.5,3.5,4.5,5.5],则各区域的输出值c如下表
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 |
c2 | 6.37 | 6.54 | 6.75 | 6.93 | 7.05 |
计算m(s):
s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
m(s) | 1.3087 | 0.754 | 0.2771 | 0.4368 | 1.0644 |
s=3.5时m(s)最小。
之后的过程不再赘述。
假设在生成3个区域之后停止划分,那么最终生成的回归树形式如下:
T = 5.72(x < 3.5)
T = 6.75(3.5<=x<6.5)
T = 8.91(x > 6.5))
回归树实际上是一个自顶向下的贪婪式切分方法。从所有样本开始,不断把当前样本划分到两个分支里。每一次的划分,只考虑当前最优,不考虑整体最优。
选择切分的维度(特征) x j x_j xj以及切点s使得划分后树的RSS最小:
R 1 ( j , s ) = { x ∣ x j < s } R 2 ( j , s ) = { x ∣ x j ≥ s } R S S = ∑ x i ∈ R 1 ( j