数据集下载地址为:www.kaggle.com/c/titanic.
本文通过泰坦尼克号数据集及使用决策树模型来熟悉sklearn相关类的使用,并给出以下例子:
1. 首先将数据集进行数据清洗,然后训练决策树模型并可视化该决策树。
2. 分析不同深度、不同阈值对决策树的影响。
3. 使用GridSearchCV类来选择决策树的最佳参数
1. 数据预处理
这里最值得学习的是,乘客登船的港口Embarked原本为字母“S、C、Q”,这里需要将S、C、Q序列化为数值类型,例如S=0,C=1,Q=2
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def read_dataset(fname):
# 指定第一列作为行索引
data = pd.read_csv(fname, index_col=0)
# 丢弃无用的数据
data.drop(['Name', 'Ticket', 'Cabin'], axis=1, inplace=True)
# 处理性别数据
data['Sex'] = (data['Sex'] == 'male').astype('int')
# 处理登船港口数据
labels = data['Embarked'].unique().tolist()
data['Embarked'] = data['Embarked'].apply(lambda n: labels.index(n))
# 处理缺失数据
data = data.fillna(0)
return data
train = read_dataset('datasets/titanic/train.csv')
train.head()
2.模型训练
这里可以看到模型对训练样本评分高,但对测试集评分低,这是过拟合的现象。
# 划分数据集
from sklearn.model_selection import train_test_split
y = train['Survived'].values
X = train.drop(['Survived'], axis=1).values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
print('train dataset: {0}; test dataset: {1}'.format(
X_train.shape, X_test.shape))
'''
train dataset: (712, 7); test dataset: (179, 7)
'''
# 模型训练
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
train_score = clf.score(X_train, y_train)
test_score = clf.score(X_test, y_test)
print('train score: {0}; test score: {1}'.format(train_score, test_score))
'''
train score: 0.984550561798; test score: 0.77094972067
'''
3. 可视化决策树
from sklearn.tree import export_graphviz
import pydotplus
with open("titanic.dot", 'w') as f:
f = export_graphviz(clf, out_file=f)
# 画图,保存到pdf文件或png文件
# 设置图像参数
dot_data = export_graphviz(clf, out_file=None,
# feature_names的顺序必须按照放入模型时数据集特征x1,x2,...,xn的
feature_names=train.drop(['Survived'], axis=1).columns.values.tolist(),
# class_name 的顺序必须按照数据集标签的升序排列,这里[unsurvied=0,survived=1]
class_names=['unsurvived','survived'],
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
# 保存图像到pdf文件
graph.write_pdf("tree.pdf")
# 保存图像到png文件
# graph.write_png("tree.png")
由于决策树图片过大,这里只截取一部分
4. 不同深度对决策树的影响
决策树模型里面有一个参数max_depth控制决策树的最大深度,不同深度将会对模型产生不同的影响。
下面通过遍历不同的深度参数来训练决策树,并画出模型的socre曲线。
# 参数选择 max_depth
def cv_score(d):
clf = DecisionTreeClassifier(max_depth=d)
clf.fit(X_train, y_train)
tr_score = clf.score(X_train, y_train)
cv_score = clf.score(X_test, y_test)
return (tr_score, cv_score)
depths = range(2, 15)
scores = [cv_score(d) for d in depths]
tr_scores = [s[0] for s in scores]
cv_scores = [s[1] for s in scores]
# 选出最佳深度及socre
best_score_index = np.argmax(cv_scores)
best_score = cv_scores[best_score_index]
best_param = depths[best_score_index]
print('best param: {0}; best score: {1}'.format(best_param, best_score))
# 画出不同深度对应的score
plt.figure(figsize=(10, 6), dpi=144)
plt.grid()
plt.xlabel('max depth of decision tree')
plt.ylabel('score')
plt.plot(depths, cv_scores, '.g-', label='cross-validation score')
plt.plot(depths, tr_scores, '.r--', label='training score')
plt.legend()
5. 不同阈值对决策树的影响
如果节点的分裂导致的不纯度的下降程度大于或者等于这个节点的值,那么这个节点将会被分裂。
决策树的参数min_impurity_split用于指定信息熵或基尼不纯度的阈值,当决策树分裂后,其信息增益低于这个阈值时,则不再分裂。
# 训练模型,并计算评分
def cv_score(val):
clf = DecisionTreeClassifier(criterion='gini', min_impurity_split=val)
clf.fit(X_train, y_train)
tr_score = clf.score(X_train, y_train)
cv_score = clf.score(X_test, y_test)
return (tr_score, cv_score)
# 指定参数范围,分别训练模型,并计算评分
values = np.linspace(0, 0.5, 50)
scores = [cv_score(v) for v in values]
tr_scores = [s[0] for s in scores]
cv_scores = [s[1] for s in scores]
# 找出评分最高的模型参数
best_score_index = np.argmax(cv_scores)
best_score = cv_scores[best_score_index]
best_param = values[best_score_index]
print('best param: {0}; best score: {1}'.format(best_param, best_score))
# 画出模型参数与模型评分的关系
plt.figure(figsize=(10, 6), dpi=144)
plt.grid()
plt.xlabel('threshold of gini')
plt.ylabel('score')
plt.plot(values, cv_scores, '.g-', label='cross-validation score')
plt.plot(values, tr_scores, '.r--', label='training score')
plt.legend()
6. 使用GridSearchCV类来寻找最佳参数
DecisionTreeClassifier()默认使用criterion=’gini’。
在criterion=’gini’的情况下,使用GridSearchCV类来寻找最佳min_impurity_split。
def plot_curve(train_sizes, cv_results, xlabel):
train_scores_mean = cv_results['mean_train_score']
train_scores_std = cv_results['std_train_score']
test_scores_mean = cv_results['mean_test_score']
test_scores_std = cv_results['std_test_score']
plt.figure(figsize=(10, 6), dpi=144)
plt.title('parameters turning')
plt.grid()
plt.xlabel(xlabel)
plt.ylabel('score')
plt.fill_between(train_sizes,
train_scores_mean - train_scores_std,
train_scores_mean + train_scores_std,
alpha=0.1, color="r")
plt.fill_between(train_sizes,
test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std,
alpha=0.1, color="g")
plt.plot(train_sizes, train_scores_mean, '.--', color="r",
label="Training score")
plt.plot(train_sizes, test_scores_mean, '.-', color="g",
label="Cross-validation score")
plt.legend(loc="best")
from sklearn.model_selection import GridSearchCV
thresholds = np.linspace(0, 0.5, 50)
# Set the parameters by cross-validation
param_grid = {'min_impurity_split': thresholds}
clf = GridSearchCV(DecisionTreeClassifier(), param_grid, cv=5) # DecisionTreeClassifier()默认使用criterion='gini'
clf.fit(X, y)
print("best param: {0}\nbest score: {1}".format(clf.best_params_,
clf.best_score_))
plot_curve(thresholds, clf.cv_results_, xlabel='gini thresholds')
使用GridSearchCV类同时寻找最佳的min_impurity_split、criterion。
from sklearn.model_selection import GridSearchCV
entropy_thresholds = np.linspace(0, 1, 50)
gini_thresholds = np.linspace(0, 0.5, 50)
# Set the parameters by cross-validation
param_grid = [{'criterion': ['entropy'],
'min_impurity_split': entropy_thresholds},
{'criterion': ['gini'],
'min_impurity_split': gini_thresholds},
{'max_depth': range(2, 10)},
{'min_samples_split': range(2, 30, 2)}]
clf = GridSearchCV(DecisionTreeClassifier(), param_grid, cv=5)
clf.fit(X, y)
print("best param: {0}\nbest score: {1}".format(clf.best_params_,
clf.best_score_))
'''
best param: {'min_impurity_split': 0.55102040816326525, 'criterion': 'entropy'}
best score: 0.822671156004
'''