CART(Classification and Regression Trees)算法,python快速实现

文章介绍了如何用Python实现CART算法,包括计算熵、信息增益以及构建决策树的过程。通过递归地选择最佳特征和阈值来分割数据集,最终建立决策树模型。该模型可用于分类任务,fit方法用于训练,predict方法用于预测。

CART(Classification and Regression Trees)算法,python快速实现

import numpy as np

class DecisionTree:
    def __init__(self):
        self.tree = {}

    def entropy(self, y):
        # 计算给定数据集的熵
        class_counts = np.bincount(y)
        probabilities = class_counts / len(y)
        entropy = np.sum([-p * np.log2(p) for p in probabilities if p > 0])
        return entropy

    def split(self, X, y, feature_index, threshold):
        # 根据给定特征和阈值拆分数据集
        left_mask = X[:, feature_index] <= threshold
        right_mask = ~left_mask
        return X[left_mask], X[right_mask], y[left_mask], y[right_mask]

    def information_gain(self, X, y, feature_index, threshold):
        # 计算信息增益
        parent_entropy = self.entropy(y)
        left_X, right_X, left_y, right_y = self.split(X, y, feature_index, threshold)

        # 计算左右子集的权重
        left_weight = len(left_y) / len(y)
        right_weight = len(right_y) / len(y)

        # 计算左右子集的熵
        left_entropy = self.entropy(left_y)
        right_entropy = self.entropy(right_y)

        # 计算信息增益
        information_gain = parent_entropy - (left_weight * left_entropy + right_weight * right_entropy)
        return information_gain

    def best_split(self, X, y):
        # 寻找最佳拆分点
        best_gain = 0
        best_feature_index = None
        best_threshold = None

        # 遍历所有特征
        for feature_index in range(X.shape[1]):
            unique_values = np.unique(X[:, feature_index])

            # 计算特征值的中点作为候选阈值
            thresholds = (unique_values[:-1] + unique_values[1:]) / 2

            # 在候选阈值中寻找最佳拆分点
            for threshold in thresholds:
                gain = self.information_gain(X, y, feature_index, threshold)

                # 更新最佳拆分点
                if gain > best_gain:
                    best_gain = gain
                    best_feature_index = feature_index
                    best_threshold = threshold

        return best_feature_index, best_threshold

    def build_tree(self, X, y):
        # 构建决策树
        if len(np.unique(y)) == 1:
            # 如果只有一个类别,返回叶节点
            return {'class': y[0]}

        if X.shape[1] == 0:
            # 如果没有特征可用,返回叶节点,使用多数投票决定类别
            class_counts = np.bincount(y)
            return {'class': np.argmax(class_counts)}

        # 寻找最佳拆分点
        best_feature_index, best_threshold = self.best_split(X, y)

        if best_feature_index is None or best_threshold is None:
            # 如果无法找到最佳拆分点,返回叶节点,使用多数投票决定类别
            class_counts = np.bincount(y)
            return {'class': np.argmax(class_counts)}

        # 拆分数据集
        left_X, right_X, left_y, right_y = self.split(X, y, best_feature_index, best_threshold)

        # 递归构建左右子树
        left_subtree = self.build_tree(left_X, left_y)
        right_subtree = self.build_tree(right_X, right_y)

        # 构建当前节点
        return {
            'feature_index': best_feature_index,
            'threshold': best_threshold,
            'left': left_subtree,
            'right': right_subtree
        }

    def fit(self, X, y):
        # 训练决策树模型
        self.tree = self.build_tree(X, y)

    def predict_instance(self, x, node):
        # 预测单个样本的类别
        if 'class' in node:
            return node['class']

        feature_value = x[node['feature_index']]
        if feature_value <= node['threshold']:
            return self.predict_instance(x, node['left'])
        else:
            return self.predict_instance(x, node['right'])

    def predict(self, X):
        # 预测数据集的类别
        return [self.predict_instance(x, self.tree) for x in X]

DecisionTree类实现了决策树分类算法。算法使用熵作为衡量不确定性的指标,并通过计算信息增益来选择最佳特征和阈值进行数据集拆分。递归地构建决策树,直到满足终止条件(例如,数据集只包含一个类别或没有可用特征)。fit方法用于训练模型,predict方法用于对新样本进行预测。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值