import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
from sklearn import cross_validation
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.metrics import f1_score,precision_score,recall_score,accuracy_score,classification_report
导入鸢尾花（Iris）数据集并划分训练集/测试集：
def load_data(test_size=0.25, random_state=0):
    """Load the Iris dataset and split it into stratified train/test sets.

    Args:
        test_size: Fraction of samples held out for the test set
            (defaults to 0.25, the original hard-coded value).
        random_state: Seed for the shuffle so the split is reproducible
            (defaults to 0, the original hard-coded value).

    Returns:
        The 4-tuple ``(X_train, X_test, y_train, y_test)`` produced by
        ``train_test_split``.
    """
    # sklearn.cross_validation was deprecated in 0.18 and removed in 0.20;
    # sklearn.model_selection provides the same train_test_split. Fall back
    # to the old module only for very old scikit-learn installations.
    try:
        from sklearn.model_selection import train_test_split
    except ImportError:
        train_test_split = cross_validation.train_test_split

    iris = datasets.load_iris()
    # These are the full feature matrix and label vector, not a training
    # subset; the split happens below.
    X = iris.data
    y = iris.target
    # stratify=y keeps the class proportions equal in both partitions.
    return train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y,
    )
def test_DecisionTreeClassifier(*data):
    """Train a default decision tree on a pre-split dataset and report results.

    Args:
        *data: Exactly four items, in the order ``load_data`` returns them:
            training features, test features, training labels, test labels.

    Prints the ``score`` of the fitted tree on the training split and on the
    test split, then the ``classification_report`` for the test predictions.
    """
    feats_tr, feats_te, labels_tr, labels_te = data

    tree = DecisionTreeClassifier()
    tree.fit(feats_tr, labels_tr)
    predictions = tree.predict(feats_te)

    train_score = tree.score(feats_tr, labels_tr)
    print('training:%f' % train_score)

    test_score = tree.score(feats_te, labels_te)
    print('test:%f' % test_score)

    report = classification_report(labels_te, predictions)
    print(report)
测试上一段定义的函数：