# -*- coding: utf-8 -*-
"""
Created on Tue Sep 5 16:18:15 2017
@author: piaodexin
"""
from sklearn import cross_validation
from sklearn import datasets
from sklearn import metrics #可以展示混淆矩阵,
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Load the Iris dataset: `data.data` is the feature matrix and `data.target`
# holds the integer class labels.
data = datasets.load_iris()
x = data.data
y = data.target

# Hold out 25% of the samples for testing. `stratify=y` keeps the class
# proportions the same in both splits and `random_state=0` makes the split
# repeatable. (`sklearn.cross_validation` was deprecated in scikit-learn 0.18
# and removed in 0.20; `model_selection.train_test_split` is its replacement.)
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.25, random_state=0, stratify=y
)

# Define the model: a CART decision tree with default hyper-parameters.
# `random_state` is fixed because scikit-learn randomly permutes the features
# at each split; without a seed, ties between equally good splits may be
# broken differently from run to run, making the results irreproducible.
cart = DecisionTreeClassifier(random_state=0)
# Defaults as printed by the scikit-learn version this script was written
# against (newer releases have dropped `min_impurity_split` and `presort`):
#   DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
#                          max_features=None, max_leaf_nodes=None,
#                          min_impurity_split=1e-07, min_samples_leaf=1,
#                          min_samples_split=2, min_weight_fraction_leaf=0.0,
#                          presort=False, random_state=None, splitter='best')

# Train the model.
cart.fit(x_train, y_train)

# Predict the held-out set once and reuse the predictions for every report.
y_pred = cart.predict(x_test)

# Show the results: overall accuracy (what `cart.score` returns for a
# classifier), per-class precision/recall/F1, and the confusion matrix
# (rows = true class, columns = predicted class).
print("accuracy:", metrics.accuracy_score(y_test, y_pred))
print(metrics.classification_report(y_test, y_pred))
print(metrics.confusion_matrix(y_test, y_pred))

# Sample output from an earlier run, made with an unseeded tree on an older
# scikit-learn (the single "avg / total" row predates 0.20). Exact numbers
# may differ from the seeded run above:
#                precision    recall  f1-score   support
#             0       1.00      1.00      1.00        13
#             1       0.93      1.00      0.96        13
#             2       1.00      0.92      0.96        12
#   avg / total       0.98      0.97      0.97        38
#
#   [[13  0  0]
#    [ 0 13  0]
#    [ 0  1 11]]