决策树-泰坦尼克号生还预测

本文介绍如何使用决策树模型预测泰坦尼克号乘客生存情况。通过处理数据缺失值、类别特征编码等步骤,实现模型训练并评估其准确度。
LR和SVM都在某种程度上要求被学习的数据特征和目标之间遵照线性假设。然后许多现实场景下,这种假设不存在。
比如根据年龄预测流感的死亡率,如果用线性模型假设,那只有两个可能:年龄越大/越小,死亡率越高。根据经验,青壮年更不容易因患流感而死亡。年龄和因流感的死亡不存在线性关系。
在机器学习模型中,决策树是描述非线性关系的不二之选。
信用卡申请的审核,涉及多项特征,是典型的决策树模型。对于是否同意申请,是二分类决策任务,只有yes/no两种分类结果。
使用多种不同特征组合搭建多层决策树的情况,模型在学习的时候需要考虑特征节点的选取顺序。常用的方式包括信息熵(Information Gain)和基尼不纯性(Gini Impurity)。本文不做讨论。sklearn中默认配置的决策树模型使用的是Gini impurity作为排序特征的度量指标。
虽然很难获取信用卡的客户资料,但有类似的借助客户档案进行二分类的任务。
本文进行泰坦尼克号的乘客的生还预测,许多专家尝试通过计算机模拟和分析找出隐藏在数据背后的生还逻辑。

Python源码:
#coding=utf-8
import pandas as pd
#-------------data split
from sklearn.cross_validation import train_test_split
#-------------feature transfer
from sklearn.feature_extraction import DictVectorizer
#-------------
from sklearn.tree import DecisionTreeClassifier
#-------------
from sklearn.metrics import classification_report

#-------------download data
titanic=pd.read_csv('http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt')
print titanic.head()
#transfer to dataFrame format by pandas,use info() to show statistics of data
print titanic.info()
#-------------feature selection
X=titanic[['pclass','age','sex']]
y=titanic['survived']

print 'bf processing\n',X.info()
#-------------feature processing
X['age'].fillna(X['age'].mean(),inplace=True)
print 'af processing\n',X.info
#-------------data split
#75% training set,25% testing set
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.25,random_state=33)
#-------------feature transfer  from String to int
vec=DictVectorizer(sparse=False)
X_train=vec.fit_transform(X_train.to_dict(orient='record'))
#print vec.feature_names  60
#AttributeError: 'DictVectorizer' object has no attribute 'feature_names'
print vec.get_feature_names()
X_test=vec.transform(X_test.to_dict(orient='record'))
#-------------training
#initialize
dtc=DecisionTreeClassifier()
dtc.fit(X_train,y_train)
y_predict=dtc.predict(X_test)
#-------------performance
print 'The Accuracy is',dtc.score(X_test,y_test)
print classification_report(y_test,y_predict,target_names=['died','survived'])

Result:
   row.names pclass  survived  \
0          1    1st         1
1          2    1st         0
2          3    1st         0
3          4    1st         0
4          5    1st         1


                                              name      age     embarked  \
0                     Allen, Miss Elisabeth Walton  29.0000  Southampton
1                      Allison, Miss Helen Loraine   2.0000  Southampton
2              Allison, Mr Hudson Joshua Creighton  30.0000  Southampton
3  Allison, Mrs Hudson J.C. (Bessie Waldo Daniels)  25.0000  Southampton
4                    Allison, Master Hudson Trevor   0.9167  Southampton


                         home.dest room      ticket   boat     sex
0                     St Louis, MO  B-5  24160 L221      2  female
1  Montreal, PQ / Chesterville, ON  C26         NaN    NaN  female
2  Montreal, PQ / Chesterville, ON  C26         NaN  (135)    male
3  Montreal, PQ / Chesterville, ON  C26         NaN    NaN  female
4  Montreal, PQ / Chesterville, ON  C22         NaN     11    male
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1313 entries, 0 to 1312
Data columns (total 11 columns):
row.names    1313 non-null int64
pclass       1313 non-null object
survived     1313 non-null int64
name         1313 non-null object
age          633 non-null float64
embarked     821 non-null object
home.dest    754 non-null object
room         77 non-null object
ticket       69 non-null object
boat         347 non-null object
sex          1313 non-null object
dtypes: float64(1), int64(2), object(8)
memory usage: 112.9+ KB
None
bf processing
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1313 entries, 0 to 1312
Data columns (total 3 columns):
pclass    1313 non-null object
age       633 non-null float64
sex       1313 non-null object
dtypes: float64(1), object(2)
memory usage: 30.8+ KB
None
/Users/mac/workspace/conda/anaconda/lib/python2.7/site-packages/pandas/core/generic.py:3660: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame


See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  self._update_inplace(new_data)
af processing
<bound method DataFrame.info of      pclass        age     sex
0       1st  29.000000  female
1       1st   2.000000  female
2       1st  30.000000    male
3       1st  25.000000  female
4       1st   0.916700    male
5       1st  47.000000    male
6       1st  63.000000  female
7       1st  39.000000    male
8       1st  58.000000  female
9       1st  71.000000    male
10      1st  47.000000    male
11      1st  19.000000  female
12      1st  31.194181  female
13      1st  31.194181    male
14      1st  31.194181    male
15      1st  50.000000  female
16      1st  24.000000    male
17      1st  36.000000    male
18      1st  37.000000    male
19      1st  47.000000  female
20      1st  26.000000    male
21      1st  25.000000    male
22      1st  25.000000    male
23      1st  19.000000  female
24      1st  28.000000    male
25      1st  45.000000    male
26      1st  39.000000    male
27      1st  30.000000  female
28      1st  58.000000  female
29      1st  31.194181    male
...     ...        ...     ...
1283    3rd  31.194181  female
1284    3rd  31.194181    male
1285    3rd  31.194181    male
1286    3rd  31.194181    male
1287    3rd  31.194181    male
1288    3rd  31.194181    male
1289    3rd  31.194181    male
1290    3rd  31.194181    male
1291    3rd  31.194181    male
1292    3rd  31.194181    male
1293    3rd  31.194181  female
1294    3rd  31.194181    male
1295    3rd  31.194181    male
1296    3rd  31.194181    male
1297    3rd  31.194181    male
1298    3rd  31.194181    male
1299    3rd  31.194181    male
1300    3rd  31.194181    male
1301    3rd  31.194181    male
1302    3rd  31.194181    male
1303    3rd  31.194181    male
1304    3rd  31.194181  female
1305    3rd  31.194181    male
1306    3rd  31.194181  female
1307    3rd  31.194181  female
1308    3rd  31.194181    male
1309    3rd  31.194181    male
1310    3rd  31.194181    male
1311    3rd  31.194181  female
1312    3rd  31.194181    male

[1313 rows x 3 columns]>
['age', 'pclass=1st', 'pclass=2nd', 'pclass=3rd', 'sex=female', 'sex=male']
The Accuracy is 0.781155015198
             precision    recall  f1-score   support

       died       0.78      0.91      0.84       202
   survived       0.80      0.58      0.67       127

avg / total       0.78      0.78      0.77       329


该数据共有1313条乘客信息,有些特征数据是缺失的,有些是数值类型,有些是字符串。
预处理环节中特征的选择十分重要,需要一些背景知识,根据对事故的了解,sex,age,pclass都可能是关键因素。
需要完成的数据处理任务:
1.初始的数据中,age列只有633个需要补充完整,一般,使用平均数或者中位数都是对模型偏离造成最小影响的策略。
2.sex和pclass列的值是列别型,需转化为数值特征,用0/1代替
算法特点:
相对于其它的模型,决策树在模型描述上有巨大的优势,推断逻辑非常直观,具有清晰的可解释性,也方便了模型的可视化。这些特征同时也保证使用该模型时,无需考虑对数据的量化甚至标准化的。与KNN不同,DT仍然属于有参数模型,需花费更多的时间在训练数据上。


### 使用ID3算法进行泰坦尼克号生存预测 #### ID3算法简介 ID3(Iterative Dichotomiser 3)是一种基于熵和信息增益的决策树分类算法。它通过计算每个属性的信息增益来选择最佳分裂节点,从而构建一棵用于分类的决策树。 对于泰坦尼克号生存预测问题,可以利用乘客的数据集(如年龄、性别、船舱等级等),并将其作为输入特征训练一个ID3决策树模型。以下是实现的具体方法: --- #### 数据预处理 由于原始数据可能包含缺失值或非数值型变量,在应用ID3算法前需先对数据进行清洗和转换: - **填充缺失值**:例如,“Age”列中的空缺可以用平均值或其他统计量替代。 - **编码类别变量**:将字符串类型的变量(如“Sex”,“Embarked”)转化为数值形式以便于后续计算。 ```python import pandas as pd from sklearn.preprocessing import LabelEncoder # 加载数据 data = pd.read_csv('titanic.csv') # 填充缺失值 data['Age'].fillna(data['Age'].mean(), inplace=True) # 编码类别变量 label_encoder = LabelEncoder() data['Sex'] = label_encoder.fit_transform(data['Sex']) ``` 上述代码片段展示了如何处理部分常见字段[^1]。 --- #### 构建ID3决策树 Python 中虽然没有内置支持直接创建ID3树的功能模块,但可以通过 `sklearn` 的 DecisionTreeClassifier 来模拟这一过程。默认情况下该类采用的是CART算法而非严格意义上的ID3;不过设置参数 criterion='entropy' 可让其更接近ID3的工作原理——即依据信息熵来进行分割判断而不是基尼指数[Gini impurity]. 下面给出一段简单的例子展示怎样运用 scikit-learn 库建立类似的结构: ```python from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score # 准备特征与目标标签 features = ['Pclass', 'Sex', 'Age'] X = data[features] y = data['Survived'] # 划分训练集测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 初始化决策树分类器(模仿ID3) clf_id3_like = DecisionTreeClassifier(criterion="entropy") # 训练模型 clf_id3_like.fit(X_train, y_train) # 预测结果 & 输出准确率 predictions = clf_id3_like.predict(X_test) print(f"Accuracy: {accuracy_score(y_test, predictions):.2f}") ``` 此脚本定义了一个类似于ID3的行为模式下的决策树,并评估了它的表现效果[^3]。 注意这里仅选取了几项基础要素做示范用途实际操作时应考虑更多维度以及深入探索各类因素间的关系比如借助热力图查看相互关联程度等等[^2]: ![Heatmap](https://i.imgur.com/yourimage.png) *(假设这是张代表数理关系强度图表)* --- #### 结果解释 最终得出的结果会告诉你所设计系统的精确度水平。如果发现某些特定条件下存活几率特别高或者低,则可进一步挖掘背后隐藏的故事线索进而优化整个流程体系。 --- ### 总结 综上所述,尽管标准库并不完全提供原生版本的支持,但我们依然能够巧妙借用现有工具达成近似目的—即针对给定历史记录资料执行有效的二元判定任务!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值