通过源码学算法--AdaBoost (CART) -- train.m

本文深入解析了CART算法中用于构建决策树的过程,包括如何通过递归方式划分数据集、选择最优分裂点以及评估分裂效果的方法。介绍了算法的核心概念如错误率计算、节点纯度改进等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

train.m

第一步,划分根节点root

对输入的数据(特征集,分类标签,权重)用CART算法进行划分,得到左右子节点。

传入的node已经设定最大层数为3,所以很快就结束,没有CART通常包含的pruning, optimal selection 及 cross-validation 等步骤

注意:这个Matlab包里所有的传入node都只是用来传递一个空的tree_node节点,相当于构造函数,不会修改也不会返回

max_split = node.max_split;

[left right spit_error] = do_learn_nu(node, dataset, labels, weights);

nodes = {left, right};

这里其实不是很明白:看起来应该分别计算左右节点的错误率,但left_pos不应该是true positive rate 吗?为什么和left_neg在一起取小??

left_pos  = sum((calc_output(left , dataset) == labels) .* weights);
left_neg  = sum((calc_output(left , dataset) == -labels) .* weights);
right_pos = sum((calc_output(right, dataset) == labels) .* weights);
right_neg = sum((calc_output(right, dataset) == -labels) .* weights);

errors = [min(left_pos, left_neg), min(right_pos, right_neg)];

总而言之,这个errors是衡量划分左右节点的标准。从最大的开始遍历

[errors, IDX] = sort(errors);
errors = flipdim(errors,2);
IDX    = flipdim(IDX,2);
nodes  = nodes(IDX);

分叉时用到的临时变量

splits = [];
split_errors = [];
deltas = [];

从errors值最大的叶节点开始遍历

用同样的方法对该node划分,但是只使用属于该节点的数据

重复上面找最大errors的步骤,对同一层的所有节点做同样的操作

比较所有节点的划分结果,最佳分叉是误差值最大的(即减少该节点数据不纯度增幅最大的那种划分)

英文也有点绕:maximises the decrease in node impurity

最后用该节点的左右子节点代替该节点,如果还没到max_split,进入下一轮


for i = 2 : max_split
    for j = 1 : length(errors)
        
        if(length(deltas) >= j)
            continue;
        end
        
        max_node = nodes{j};
        max_node_out = calc_output(max_node, dataset);
       
        mask = find(max_node_out == 1);  
       
        [left right spit_error] = do_learn_nu(node, dataset(:,mask), labels(mask), weights(mask), max_node);
              
        
        left_pos  = sum((calc_output(left , dataset) == labels) .* weights);
        left_neg  = sum((calc_output(left , dataset) == -labels) .* weights);
        right_pos = sum((calc_output(right, dataset) == labels) .* weights);
        right_neg = sum((calc_output(right, dataset) == -labels) .* weights);
        
        splits{end+1} = left;
        splits{end+1} = right;  
        
        if( (right_pos + right_neg) == 0 || (left_pos + left_neg) == 0)
          deltas(end+1) = 0;
        else
          deltas(end+1) = errors(j) - spit_error;
        end
        
        split_errors(end+1) = min(left_pos, left_neg);
        split_errors(end+1) = min(right_pos, right_neg);
    end  
    
    if(max(deltas) == 0)
        return;
    end
    best_split = find(deltas == max(deltas));
    best_split = best_split(1);
    
    cut_vec = [1 : (best_split-1)  (best_split + 1) : length(errors)];
    nodes   = nodes(cut_vec);
    errors  = errors(cut_vec);
    deltas  = deltas(cut_vec);
    
    nodes{end+1} = splits{2 * best_split - 1};
    nodes{end+1} = splits{2 * best_split};
    
    errors(end+1) = split_errors(2 * best_split - 1);
    errors(end+1) = split_errors(2 * best_split);
    
    cut_vec = [1 : 2 * (best_split-1)  2 * (best_split)+1 : length(split_errors)];
    split_errors = split_errors(cut_vec);    
    splits       = splits(cut_vec);

end


这样的话,根节点A第一次划分,nodes就有两个节点{B,C};

其中一个(假设是C)进行第二次划分得到{D,E},这时nodes有{B,D,E}三个节点

再来一次到max_split的时候nodes正好返回四个节点的CART树


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值