import numpy as np
from machine_learning.lib.decision_tree_base import DecisionTreeBase
# 对下标属于集合idx的数据计算标签的方差
def get_var(y,idx):
y_avg=np.average(y[idx])*np.ones(len(idx))
return np.linalg.norm(y_avg-y[idx],2)**2/len(idx)
#定义决策树回归算法DecisionTreeRegressor,继承了DecisionTreeBase基类
class DecisionTreeRegressor(DecisionTreeBase):
def __init__(self,max_depth=0,feature_sasmple_rate=1.0):
super().__init__(
max_depth=max_depth,
feature_sample_rate=feature_sasmple_rate,
get_score=get_var
)
机器学习算法导论代码---decision_tree_regressor
最新推荐文章于 2025-03-18 11:21:19 发布
这篇博客介绍了如何实现决策树回归算法,特别是定义了一个名为`DecisionTreeRegressor`的类,该类继承自`DecisionTreeBase`。算法中使用`get_var`函数来计算指定数据集的标签方差,以辅助决策树节点的分裂过程。通过这种方式,提高了模型对数据变化的敏感性和预测准确性。
2万+

被折叠的 条评论
为什么被折叠?



