决策树

决策树的实现代码

class DecisionNode(object):
    """One node of a binary decision tree.

    An internal node stores the feature index and threshold it tests, plus
    the subtrees for samples that pass (``true_branch``) or fail
    (``false_branch``) that test. A leaf node stores its prediction in
    ``value``.
    """

    def __init__(self, feature_i=None, threshold=None,
                 value=None, true_branch=None, false_branch=None):
        # What this node tests: column index of the feature and the value it
        # is compared against.
        self.feature_i, self.threshold = feature_i, threshold
        # Prediction returned when this node is a leaf.
        self.value = value
        # Child subtrees for the "test passed" / "test failed" outcomes.
        self.true_branch, self.false_branch = true_branch, false_branch


class DecisionTree(object):
    """Base class for a binary decision tree (ID3 / CART style).

    Concrete trees must assign two callables before ``fit`` builds the tree:

    * ``self._impurity_calculation(y, y_true, y_false)`` returns a split
      score. The candidate split with the largest score is chosen.
    * ``self._leaf_value_calculation(y)`` returns the value stored in a leaf.

    Splitting rule: when a threshold is numeric (Python ``int``/``float`` or
    a NumPy number), a sample goes to the true branch iff
    ``feature >= threshold``. Otherwise it goes there iff
    ``feature == threshold`` (categorical feature).

    Args:
        min_sample_split: a node is only split when it holds at least this
            many samples.
        min_impurity: the best split is kept only when its score is strictly
            greater than this value. Otherwise the node becomes a leaf.
        max_depth: splitting is attempted only while the current depth is
            ``<= max_depth``.
    """

    def __init__(self, min_sample_split=2, min_impurity=1e-7,
                 max_depth=float("inf")):
        self.root = None
        self.min_sample_split = min_sample_split
        self.min_impurity = min_impurity
        self.max_depth = max_depth
        # Function scoring a candidate split (set by subclasses).
        self._impurity_calculation = None
        # Function computing the value of a leaf node (set by subclasses).
        self._leaf_value_calculation = None

    def fit(self, X, y):
        """Build the tree from training data and store it in ``self.root``.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: targets, either 1-D (n_samples,) or 2-D (n_samples, k).

        Raises:
            NotImplementedError: if the impurity or leaf-value function has
                not been set (e.g. ``DecisionTree`` used directly).
        """
        if self._impurity_calculation is None or self._leaf_value_calculation is None:
            raise NotImplementedError(
                "Subclasses must set _impurity_calculation and "
                "_leaf_value_calculation before fitting.")
        self.root = self._build_tree(X, y)

    @staticmethod
    def _is_numeric(value):
        """Return True when *value* is compared with ``>=`` rather than ``==``."""
        return isinstance(value, (int, float, np.number))

    def _split_mask(self, column, threshold):
        """Return a boolean mask of the rows in *column* sent to the true branch."""
        if self._is_numeric(threshold):
            return np.asarray(column >= threshold, dtype=bool)
        return np.asarray(column == threshold, dtype=bool)

    def _build_tree(self, X, y, current_depth=0):
        """Recursively grow a subtree for (X, y) and return its root node."""
        X = np.asarray(X)
        y = np.asarray(y)
        # Keep targets as a column (n_samples, 1) so every subset handed to the
        # impurity / leaf functions is 2-D.
        if y.ndim == 1:
            y = np.expand_dims(y, axis=1)
        n_samples, n_features = np.shape(X)

        largest_impurity = 0
        best_criteria = None   # Feature index and threshold
        best_sets = None       # Subsets of the data

        if n_samples >= self.min_sample_split and current_depth <= self.max_depth:
            for feature_i in range(n_features):
                column = X[:, feature_i]
                for threshold in np.unique(column):
                    # Split X and y separately with the same mask. This avoids
                    # concatenating them, which could change the dtype of X.
                    mask = self._split_mask(column, threshold)
                    # A "split" that sends every sample to one side is useless.
                    if mask.all() or not mask.any():
                        continue

                    impurity = self._impurity_calculation(y, y[mask], y[~mask])

                    if impurity > largest_impurity:
                        largest_impurity = impurity
                        best_criteria = {'feature_i': feature_i,
                                         'threshold': threshold}
                        best_sets = {'leftX': X[mask],
                                     'lefty': y[mask],
                                     'rightX': X[~mask],
                                     'righty': y[~mask]}

        if best_sets is not None and largest_impurity > self.min_impurity:
            true_branch = self._build_tree(best_sets['leftX'], best_sets['lefty'],
                                           current_depth + 1)
            false_branch = self._build_tree(best_sets['rightX'], best_sets['righty'],
                                            current_depth + 1)
            return DecisionNode(feature_i=best_criteria['feature_i'],
                                threshold=best_criteria['threshold'],
                                value=None,
                                true_branch=true_branch,
                                false_branch=false_branch)

        # No worthwhile split was found: this node becomes a leaf.
        return DecisionNode(value=self._leaf_value_calculation(y))

    def predict_value(self, x, tree=None):
        """Walk one sample *x* down the tree and return the reached leaf's value.

        Args:
            x: a single sample, indexable by feature index.
            tree: node to start from. Defaults to ``self.root``.
        """
        node = self.root if tree is None else tree
        # Internal nodes have ``value is None`` and a feature to test.
        while node.value is None and node.feature_i is not None:
            feature_value = x[node.feature_i]
            if self._is_numeric(feature_value):
                goes_true = feature_value >= node.threshold
            else:
                goes_true = feature_value == node.threshold
            node = node.true_branch if goes_true else node.false_branch
        return node.value

    def predict(self, X):
        """Return a list with the predicted value for every row of *X*."""
        return [self.predict_value(x) for x in X]

可以根据所采用的算法（ID3 或 CART）自定义 self._impurity_calculation 函数（以及叶节点取值函数 self._leaf_value_calculation），并继承上述 DecisionTree 类。

class ClassificationTree(DecisionTree):
    """Classification tree that scores splits by information gain (ID3-style).

    Leaves predict the most frequent label among their samples. Entropy is
    computed by a ``calculate_entropy`` helper that is presumably defined
    elsewhere in the module (its body is omitted here).
    """

    def _calculate_information_gain(self, y, y1, y2):
        """Return entropy(y) minus the size-weighted entropies of y1 and y2.

        ``y1`` and ``y2`` are the two partitions of ``y`` produced by a
        candidate split.
        """
        weight_true = len(y1) / len(y)
        weight_false = 1 - weight_true
        parent_entropy = calculate_entropy(y)
        return (parent_entropy
                - weight_true * calculate_entropy(y1)
                - weight_false * calculate_entropy(y2))

    def _majority_vote(self, y):
        """Return the most frequent label in ``y``.

        Labels are visited in ``np.unique`` (sorted) order and the first one
        with the highest count wins ties. Returns None when no label has a
        positive count (e.g. empty ``y``).
        """
        tallies = [(np.count_nonzero(y == label), label) for label in np.unique(y)]
        top_count = max((count for count, _ in tallies), default=0)
        if top_count <= 0:
            return None
        return next(label for count, label in tallies if count == top_count)

    def fit(self, X, y):
        """Install the information-gain and majority-vote hooks, then build the tree."""
        self._leaf_value_calculation = self._majority_vote
        self._impurity_calculation = self._calculate_information_gain
        super(ClassificationTree, self).fit(X, y)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值