Scikit-Learn 机器学习笔记 -- 决策树

本文通过Scikit-Learn库实现决策树分类器,并利用鸢尾花数据集进行训练和预测。文中详细展示了如何加载数据、构建模型、绘制决策树及进行预测。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Scikit-Learn 机器学习笔记 – 决策树


参考文档: handson-ml


import numpy as np


# 加载鸢尾花数据集
def load_dataset():
    from sklearn import datasets
    iris = datasets.load_iris()
    # print(iris)
    # 使用第3和第4个特征
    X = iris['data'][:, (2, 3)]
    # bool类型转为数值型
    y = iris['target']
    return X, y, iris


# 决策树分类器
def tree_classify(X, y):
    from sklearn.tree import DecisionTreeClassifier
    tree_clf = DecisionTreeClassifier(max_depth=2)
    tree_clf.fit(X, y)
    print(tree_clf.tree_)
    return tree_clf


# 绘制决策树图
def draw_tree(model, iris):
    from sklearn.tree import export_graphviz
    export_graphviz(
        model,
        out_file="iris_tree.dot",
        feature_names=iris.feature_names[2:],
        class_names=iris.target_names,
        rounded=True,
        filled=True
    )
    # import pydotplus
    # dot_data = export_graphviz(model, out_file=None)
    # graph = pydotplus.graph_from_dot_data(dot_data)
    # graph.write_pdf("iris.pdf")


# 预测
def tree_predict(model, sample):
    # 预测类别
    predict = model.predict(sample)
    # 属于各类别的概率
    predict_prob = model.predict_proba(sample)
    print('决策树预测类别为:', predict, '属于各类别的概率为:', predict_prob)


if __name__ == '__main__':
    # 加载数据集
    X, y, iris = load_dataset()
    # 创建决策树分类器
    tree_clf = tree_classify(X, y)
    # 绘制决策树
    # draw_tree(tree_clf, iris)
    # 预测
    tree_predict(tree_clf, [[5, 1.5]])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值