决策树二分类Demo&画树

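本文给出一个用 scikit-learn 决策树(DecisionTreeClassifier)做二分类的完整示例:从 TSV 文件读取训练集和测试集,训练模型并输出训练集、测试集上的 AUC,再借助 Graphviz(pydotplus)把决策树画成 PDF;另附函数 tree_to_code,可将训练好的树导出为嵌套 if/else 形式的规则。机器学习是信息技术领域的重要分支,可用于数据挖掘、预测分析等。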

1.

# -*-  coding: utf-8 -*-
from sklearn.externals.six import StringIO
import pydotplus
from sklearn import tree
from sklearn.tree import _tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_auc_score
import pandas as pd
import joblib
import sys
import os


# Running leaf counter used by tree_to_code(): each leaf written to the rule
# file receives the current value, then the counter is incremented. Nothing in
# this file resets it, so repeated tree_to_code() calls continue the numbering.
node_id = 0
# In this file, only used to build the rule-file name below.
# NOTE(review): the classifier in __main__ is built with max_leaf_nodes=10, so
# the "100" in the file name does not describe that model -- confirm intent.
max_leaf_nodes = 100
# Rule file opened at import time in append mode, so rules accumulate across
# runs. An explicit UTF-8 encoding keeps non-ASCII feature names writable no
# matter what the platform's default locale encoding is.
fw = open('/data/liupg/rule_max_leaves_' + str(max_leaf_nodes) + '.txt', 'a',
          encoding='utf-8')


def draw_tree(model, name, features):
    """Render a fitted decision tree to ``<name>.pdf`` via Graphviz.

    Args:
        model: a fitted scikit-learn tree estimator (e.g. DecisionTreeClassifier).
        name: output path without the ``.pdf`` extension.
        features: feature names, in the same column order used for fitting.
    """
    # With out_file=None, export_graphviz returns the DOT source as a string,
    # so no intermediate StringIO buffer is required.
    dot_source = tree.export_graphviz(model, out_file=None, feature_names=features,
                                      rounded=True, filled=True, proportion=True,
                                      precision=6)
    graph = pydotplus.graph_from_dot_data(dot_source)
    # pydotplus hands the DOT source to Graphviz to produce the PDF.
    graph.write_pdf(name + ".pdf")


def tree_to_code(tree, feature_names):
    """Print a fitted tree as nested if/else pseudo-code and log it to ``fw``.

    The tree is walked depth-first from the root (node 0). Each split becomes an
    ``if <feature> <= <threshold>:`` line followed by its left subtree, then an
    ``else:`` line followed by its right subtree.

    Args:
        tree: a fitted scikit-learn tree estimator; its low-level ``tree_``
            arrays (feature, threshold, children_left/right, value) are read.
        feature_names: names indexed by the feature ids in ``tree.tree_.feature``;
            also listed in the emitted ``def tree(...)`` header.

    Side effects:
        Prints to stdout, writes lines to the module-level file ``fw``, and
        advances the module-level ``node_id`` counter once per leaf.
    """
    inner = tree.tree_
    undefined = _tree.TREE_UNDEFINED
    # Per-node feature name; leaves carry TREE_UNDEFINED as their feature id.
    names_per_node = [
        "undefined!" if idx == undefined else feature_names[idx]
        for idx in inner.feature
    ]
    print('feature_name:', names_per_node)
    header = "def tree({}):".format(", ".join(feature_names))
    print(header)
    fw.write(header + '\n')

    def walk(node, depth):
        global node_id
        pad = "  " * depth
        if inner.feature[node] == undefined:
            # Leaf: stdout shows the node's value array, while the file records
            # the running leaf number instead.
            # NOTE(review): presumably the file numbers are rule ids -- confirm.
            print("{}return {}".format(pad, inner.value[node]))
            fw.write("{}return {}".format(pad, node_id) + '\n')
            node_id += 1
            return
        split_name = names_per_node[node]
        cut = inner.threshold[node]
        if_line = "{}if {} <= {}:".format(pad, split_name, cut)
        print(if_line)
        fw.write(if_line + '\n')
        walk(inner.children_left[node], depth + 1)
        else_line = "{}else:  # if {} > {}".format(pad, split_name, cut)
        print(else_line)
        fw.write(else_line + '\n')
        walk(inner.children_right[node], depth + 1)

    walk(0, 1)


# Input columns selected (in this order) from train.tsv / test.tsv.
# These look like placeholder names -- replace with the real column names.
feature_names = [
    'feature1',
    'feature2',
    'feature3',
    'feature4',
]


if __name__ == '__main__':
    # Binary classifier. A node needs at least 10000 rows to be split and every
    # leaf must keep at least 10000 rows, which keeps the tree small.
    # NOTE(review): the rule file opened at module level is named with
    # max_leaf_nodes = 100, but this model uses max_leaf_nodes=10 -- confirm.
    dt = DecisionTreeClassifier(max_leaf_nodes=10, max_depth=10,
                                min_samples_split=10000, min_samples_leaf=10000)
    # Alternative configuration: entropy criterion plus a fixed random_state
    # for reproducible splits.
    # dt = DecisionTreeClassifier(max_leaf_nodes=10,
    #     max_depth=10,
    #     min_samples_split=10000,
    #     min_samples_leaf=10000,
    #     criterion='entropy',
    #     random_state=0)
    try:
        # Tab-separated inputs; each must contain the feature columns and 'label'.
        train = pd.read_csv('train.tsv', sep='\t')
        test = pd.read_csv('test.tsv', sep='\t')
        X_tr, X_te = train.loc[:, feature_names], test.loc[:, feature_names]
        y_tr, y_te = train.loc[:, 'label'], test.loc[:, 'label']
        dt.fit(X_tr, y_tr)
        # dt.fit(X_tr, y_tr, sample_weight=weights)
        # predict_proba columns follow dt.classes_; column 1 is the score of the
        # second class (the positive class for 0/1 labels), as AUC expects.
        tr_score = dt.predict_proba(X_tr)[:, 1]
        te_score = dt.predict_proba(X_te)[:, 1]
        print("train auc = %.4f" % roc_auc_score(y_true=y_tr, y_score=tr_score))
        print("test auc = %.4f" % roc_auc_score(y_true=y_te, y_score=te_score))
        figure_path = '../../reports/figures/cart'
        # Writing the PDF fails if the target directory is missing; create it.
        os.makedirs(os.path.dirname(figure_path), exist_ok=True)
        draw_tree(dt, figure_path, feature_names)
        # joblib.dump(dt, dt_path)
        # dt = joblib.load(dt_path)
        # tree_to_code(dt, feature_names)
    finally:
        # Close the module-level rule file even if loading or training fails.
        fw.close()

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值