牛客题解 | 决策树学习

题目

题目链接

决策树是一个用于分类和回归的模型,它通过将数据集分割成更小的子集来构建树形结构。每个内部节点代表一个特征的测试,每个分支代表测试结果,而每个叶子节点则表示最终的输出类别或值。

通俗点说,就是把一堆数据按照某个特征的某个阈值去分成两份或者多份子节点,然后递归执行这种分裂直到达到某种要求。
在本题中,只需要粗暴地对节点内部的所有特征进行尝试分裂,通过计算熵与信息增益来决定使用哪个阈值进行分裂,然后重复执行这一过程即可。
这里给出熵和信息增益的公式:
H ( S ) = − ∑ i = 1 c p i log ⁡ 2 ( p i ) H(S) = -\sum_{i=1}^{c} p_i \log_2(p_i) H(S)=i=1cpilog2(pi)
其中, H ( S ) H(S) H(S) 是熵, c c c 是类别的数量, p i p_i pi 是属于类别 i i i 的样本比例。

I G ( S , A ) = H ( S ) − ∑ v ∈ V a l u e s ( A ) ∣ S v ∣ ∣ S ∣ H ( S v ) IG(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v) IG(S,A)=H(S)vValues(A)SSvH(Sv)
其中, I G ( S , A ) IG(S, A) IG(S,A) 是信息增益, V a l u e s ( A ) Values(A) Values(A) 是属性 A A A 的所有可能取值, S v S_v Sv 是在属性 A A A 取值为 v v v 时的样本子集。

标准代码如下:

import math
from collections import Counter

def calculate_entropy(labels):
    label_counts = Counter(labels)
    total_count = len(labels)
    entropy = -sum(
        (count / total_count) * math.log2(count / total_count)
        for count in label_counts.values()
    )
    return entropy


def calculate_information_gain(examples, attr, target_attr):
    total_entropy = calculate_entropy([example[target_attr] for example in examples])
    values = set(example[attr] for example in examples)
    attr_entropy = 0
    for value in values:
        value_subset = [
            example[target_attr] for example in examples if example[attr] == value
        ]
        value_entropy = calculate_entropy(value_subset)
        attr_entropy += (len(value_subset) / len(examples)) * value_entropy
    return total_entropy - attr_entropy


def majority_class(examples, target_attr):
    return Counter([example[target_attr] for example in examples]).most_common(1)[0][0]


def learn_decision_tree(examples, attributes, target_attr):
    if not examples:
        return "No examples"
    if all(example[target_attr] == examples[0][target_attr] for example in examples):
        return examples[0][target_attr]
    if not attributes:
        return majority_class(examples, target_attr)

    gains = {
        attr: calculate_information_gain(examples, attr, target_attr)
        for attr in attributes
    }
    best_attr = max(gains, key=gains.get)
    tree = {best_attr: {}}

    for value in set(example[best_attr] for example in examples):
        subset = [example for example in examples if example[best_attr] == value]
        new_attributes = attributes.copy()
        new_attributes.remove(best_attr)
        subtree = learn_decision_tree(subset, new_attributes, target_attr)
        tree[best_attr][value] = subtree

    return tree

def print_tree(tree):
    outs = []
    for key, value in sorted(tree.items()):
        outs.append(f"{key}:{print_tree(value) if isinstance(value, dict) else value}")
    return "{" + ",".join(outs) + "}"

if __name__ == "__main__":
    examples = eval(input())
    attributes = eval(input())
    target_attr = eval(input())
    print(print_tree(learn_decision_tree(examples, attributes, target_attr)))
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值