探索泰坦尼克号上的生存之道
kaggle-Titanic - Machine Learning from Disaster
使用机器学习创建一个模型,预测哪些乘客在泰坦尼克号沉船事件中幸存下来。
泰坦尼克号沉没是历史上最臭名昭著的海难之一。
1912 年 4 月 15 日,在处女航中,被广泛认为“永不沉没”的皇家邮轮泰坦尼克号与冰山相撞后沉没。不幸的是,救生艇数量不足以容纳船上所有人,导致 2224 名乘客和船员中有 1502 人丧生。
尽管生存有一定的运气因素,但似乎有些人比其他人群更有可能生存下来。
在这个挑战中,我们要求您使用乘客数据(即姓名、年龄、性别、社会经济阶层等)建立一个预测模型,回答这个问题:“什么样的人更有可能生存?”
1 简介
1. 1 加载并检查数据
# Load packages
library('ggplot2') # visualization
library('ggthemes') # visualization
library('scales') # visualization
library('dplyr') # data manipulation
library('mice') # imputation
library('randomForest') # classification algorithm
现在我们的包已加载,让我们读入并查看数据。
train <- read.csv('../input/train.csv', stringsAsFactors = F)
test <- read.csv('../input/test.csv', stringsAsFactors = F)
full <- bind_rows(train, test) # bind training & test data
# check data
str(full)
## 'data.frame': 1309 obs. of 12 variables:
## $ PassengerId: int 1 2 3 4 5 6 7 8 9 10 ...
## $ Survived : int 0 1 1 1 0 0 0 0 1 1 ...
## $ Pclass : int 3 1 3 1 3 3 1 3 3 2 ...
## $ Name : chr "Braund, Mr. Owen Harris" "Cumings, Mrs. John Bradley (Florence Briggs Thayer)" "Heikkinen, Miss. Laina" "Futrelle, Mrs. Jacques Heath (Lily May Peel)" ...
## $ Sex : chr "male" "female" "female" "female" ...
## $ Age : num 22 38 26 35 35 NA 54 2 27 14 ...
## $ SibSp : int 1 1 0 1 0 0 0 3 0 1 ...
## $ Parch : int 0 0 0 0 0 0 0 1 2 0 ...
## $ Ticket : chr "A/5 21171" "PC 17599" "STON/O2. 3101282" "113803" ...
## $ Fare : num 7.25 71.28 7.92 53.1 8.05 ...
## $ Cabin : chr "" "C85" "" "C123" ...
## $ Embarked : chr "S" "C" "S" "S" ...
我们已经了解了我们的变量、它们的类类型以及每个变量的前几个观察结果。我们知道我们正在处理 12 个变量的 1309 个观测值。由于一些变量名称并不是 100% 具有说明性,为了使事情更加明确,我们需要处理以下内容:
Variable Name | Description |
---|---|
Survived | Survived (1) or died (0) |
Pclass | Passenger’s class |
Name | Passenger’s name |
Sex | Passenger’s sex |
Age | Passenger’s age |
SibSp | Number of siblings/spouses aboard |
Parch | Number of parents/children aboard |
Ticket | Ticket number |
Fare | Fare |
Cabin | Cabin |
Embarked | Port of embarkation |
2 特征工程
2.1 名字意味着什么
引起我注意的第一个变量是乘客姓名,因为我们可以将其分解为其他有意义的变量,这些变量可以提供预测或用于创建其他新变量。例如,乘客头衔包含在乘客姓名变量中,我们可以使用姓氏来代表家庭。让我们来做一些特征工程吧!
# Grab title from passenger names
full$Title <- gsub('(.*, )|(\\..*)', '', full$Name)
# Show title counts by sex
table(full$Sex, full$Title)
##
## Capt Col Don Dona Dr Jonkheer Lady Major Master Miss Mlle Mme
## female 0 0 0 1 1 0 1 0 0 260 2 1
## male 1 4 1 0 7 1 0 2 61 0 0 0
##
## Mr Mrs Ms Rev Sir the Countess
## female 0 197 2 0 0 1
## male 757 0 0 8 1 0
# Titles with very low cell counts to be combined to "rare" level
rare_title <- c('Dona', 'Lady', 'the Countess','Capt', 'Col', 'Don',
'Dr', 'Major', 'Rev', 'Sir', 'Jonkheer')
# Also reassign mlle, ms, and mme accordingly
full$Title[full$Title == 'Mlle'] <- 'Miss'
full$Title[full$Title == 'Ms'] <- 'Miss'
full$Title[full$Title == 'Mme'] <- 'Mrs'
full$Title[full$Title %in% rare_title] <- 'Rare Title'
# Show title counts by sex again
table(full$Sex, full$Title)
##
## Master Miss Mr Mrs Rare Title
## female 0 264 0 198 4
## male 61 0 757 0 25
# Finally, grab surname from passenger name
full$Surname <- sapply(full$Name,
function(x) strsplit(x, split = '[,.]')[[1]][1])
cat(paste('We have <b>', nlevels(factor(full$Surname)), '</b> unique surnames. I would be interested to infer ethnicity based on surname --- another time.'))
2.2 家庭是同沉同浮吗?
现在我们已经将乘客姓名拆分为一些新变量,我们可以更进一步,创建一些新的家庭变量。首先,我们将根据兄弟姐妹/配偶的数量(也许有人有多个配偶?)和孩子/父母的数量来创建家庭规模变量。
# Create a family size variable including the passenger themselves
full$Fsize <- full$SibSp + full$Parch + 1
# Create a family variable
full$Family <- paste(full$Surname, full$Fsize, sep='_')
我们的家庭规模变量是什么样的?为了帮助我们理解它与生存的关系,让我们在训练数据中绘制它。
# Use ggplot2 to visualize the relationship between family size & survival
ggplot(full[1:891,], aes(x = Fsize, fill = factor(Survived))) +
geom_bar(stat='count', position='dodge') +
scale_x_continuous(breaks=c(1:11)) +
labs(x = 'Family Size') +
theme_few()
我们可以看到,单身人士和家庭规模超过 4 人的人会受到生存惩罚。我们可以将此变量分解为三个级别,这将很有帮助,因为大家庭相对较少。让我们创建一个离散的家庭规模变量。
# Discretize family size
full$FsizeD[full$Fsize == 1] <- 'singleton'
full$FsizeD[full$Fsize < 5 & full$Fsize > 1] <- 'small'
full$FsizeD[full$Fsize > 4] <- 'large'
# Show family size by survival using a mosaic plot
mosaicplot(table(full$FsizeD, full$Survived), main='Family Size by Survival', shade=TRUE)
马赛克图显示,我们保留了我们的规则,即单身人士和大家庭会受到生存惩罚,但小家庭的乘客会受益。我想对我们的年龄变量做进一步的处理,但是 263 行缺少年龄值,所以我们必须等到我们解决缺失问题之后。
2.3 处理更多变量
还剩下什么?客舱变量中可能有一些潜在有用的信息,包括有关其甲板的信息。我们来看一下。
# This variable appears to have a lot of missing values
full$Cabin[1:28]
## [1] "" "C85" "" "C123" ""
## [6] "" "E46" "" "" ""
## [11] "G6" "C103" "" "" ""
## [16] "" "" "" "" ""
## [21] "" "D56" "" "A6" ""
## [26] "" "" "C23 C25 C27"
# The first character is the deck. For example:
strsplit(full$Cabin[2], NULL)[[1]]
## [1] "C" "8" "5"
# Create a Deck variable. Get passenger deck A - F:
full$Deck<-factor(sapply(full$Cabin, function(x) strsplit(x, NULL)[[1]][1]))
这里还可以做更多的事情,包括查看列出了多个房间的小屋(例如,第 28 行:“C23 C25 C27”),但考虑到该列的稀疏性,我们将在此停止。
3 缺失值
现在我们准备开始探索缺失的数据并通过插补来纠正它。我们可以通过多种不同的方式来做到这一点。鉴于数据集较小,我们可能不应该选择删除整个观测值(行)或包含缺失值的变量(列)。我们可以选择用给定数据分布的合理值(例如均值、中位数或众数)替换缺失值。最后,我们可以进行预测。我们将使用后两种方法,并且我将依靠一些数据可视化来指导我们的决策。
3.1 合理值估算
# Passengers 62 and 830 are missing Embarkment
full[c(62, 830), 'Embarked']
## [1] "" ""
我们将根据我们认为可能相关的现有数据推断其登机价值:乘客舱位和票价。我们看到他们分别支付了 $80 和 $NA ,他们的班级是 1 和 NA 。那么他们从哪里出发呢?
# Get rid of our missing passenger IDs
embark_fare <- full %>%
filter(PassengerId != 62 & PassengerId != 830)
# Use ggplot2 to visualize embarkment, passenger class, & median fare
ggplot(embark_fare, aes(x = Embarked, y = Fare, fill = factor(Pclass))) +
geom_boxplot() +
geom_hline(aes(yintercept=80),
colour='red', linetype='dashed', lwd=2) +
scale_y_continuous(labels=dollar_format()) +
theme_few()
从沙堡 (‘C’) 出发的头等舱乘客的票价中位数与我们登机不足的乘客支付的 80 美元非常吻合。我认为我们可以安全地将 NA 值替换为“C”。
# Since their fare was $80 for 1st class, they most likely embarked from 'C'
full$Embarked[c(62, 830)] <- 'C'
我们即将修复一些 NA 值。第 1044 行的乘客的票价值为 NA。
# Show row 1044
full[1044, ]
## PassengerId Survived Pclass Name Sex Age SibSp Parch
## 1044 1044 NA 3 Storey, Mr. Thomas male 60.5 0 0
## Ticket Fare Cabin Embarked Title Surname Fsize Family FsizeD
## 1044 3701 NA S Mr Storey 1 Storey_1 singleton
## Deck
## 1044 <NA>
这是从南安普敦(‘S’)出发的三等舱乘客。让我们想象一下所有其他共享舱位和登机情况的票价 (n = 494)。
ggplot(full[full$Pclass == '3' & full$Embarked == 'S', ],
aes(x = Fare)) +
geom_density(fill = '#99d6ff', alpha=0.4) +
geom_vline(aes(xintercept=median(Fare, na.rm=T)),
colour='red', linetype='dashed', lwd=1) +
scale_x_continuous(labels=dollar_format()) +
theme_few()
从该可视化结果来看,将 NA 票价值替换为其舱位和登机时间的中位数(即 8.05 美元)似乎相当合理。
# Replace missing fare value with median fare for class/embarkment
full$Fare[1044] <- median(full[full$Pclass == '3' & full$Embarked == 'S', ]$Fare, na.rm = TRUE)
3.2 预测插补
最后,正如我们之前指出的,我们的数据中缺少相当多的年龄值。我们将更加花哨地估算缺失的年龄值。为什么?因为我们可以。我们将创建一个基于其他变量预测年龄的模型。
# Show number of missing Age values
sum(is.na(full$Age))
## [1] 263
我们绝对可以使用 rpart(回归的递归分区)来预测缺失的年龄,但我将使用 mouse 包来完成此任务,只是为了做一些不同的事情。您可以在此处阅读有关使用 r 中的链式方程进行多重插补的更多信息 (PDF)。由于我们还没有这样做,我将首先对因子变量进行因式分解,然后进行小鼠插补。
# Make variables factors into factors
factor_vars <- c('PassengerId','Pclass','Sex','Embarked',
'Title','Surname','Family','FsizeD')
full[factor_vars] <- lapply(full[factor_vars], function(x) as.factor(x))
# Set a random seed
set.seed(129)
# Perform mice imputation, excluding certain less-than-useful variables:
mice_mod <- mice(full[, !names(full) %in% c('PassengerId','Name','Ticket','Cabin','Family','Surname','Survived')], method='rf')
##
## iter imp variable
## 1 1 Age Deck
## 1 2 Age Deck
## 1 3 Age Deck
## 1 4 Age Deck
## 1 5 Age Deck
## 2 1 Age Deck
## 2 2 Age Deck
## 2 3 Age Deck
## 2 4 Age Deck
## 2 5 Age Deck
## 3 1 Age Deck
## 3 2 Age Deck
## 3 3 Age Deck
## 3 4 Age Deck
## 3 5 Age Deck
## 4 1 Age Deck
## 4 2 Age Deck
## 4 3 Age Deck
## 4 4 Age Deck
## 4 5 Age Deck
## 5 1 Age Deck
## 5 2 Age Deck
## 5 3 Age Deck
## 5 4 Age Deck
## 5 5 Age Deck
# Save the complete output
mice_output <- complete(mice_mod)
让我们将得到的结果与乘客年龄的原始分布进行比较,以确保没有任何问题。
# Plot age distributions
par(mfrow=c(1,2))
hist(full$Age, freq=F, main='Age: Original Data',
col='darkgreen', ylim=c(0,0.04))
hist(mice_output$Age, freq=F, main='Age: MICE Output',
col='lightgreen', ylim=c(0,0.04))
事情看起来不错,所以让我们用小鼠模型的输出替换原始数据中的年龄向量。
# Replace Age variable from the mice model.
full$Age <- mice_output$Age
# Show new number of missing Age values
sum(is.na(full$Age))
## [1] 0
我们现在已经完成了我们关心的所有变量的输入值!现在我们有了完整的 Age 变量,我只想做一些收尾工作。我们可以使用 Age 来做更多的特征工程……
3.3 第二轮特征工程
现在我们知道了每个人的年龄,我们可以创建几个新的与年龄相关的变量:儿童和母亲。孩子只是 18 岁以下的人,母亲是乘客,且满足以下条件:1) 女性,2) 年满 18 岁,3) 有超过 0 个孩子(不是开玩笑!),4) 没有头衔’错过’。
# First we'll look at the relationship between age & survival
ggplot(full[1:891,], aes(Age, fill = factor(Survived))) +
geom_histogram() +
# I include Sex since we know (a priori) it's a significant predictor
facet_grid(.~Sex) +
theme_few()
# Create the column child, and indicate whether child or adult
full$Child[full$Age < 18] <- 'Child'
full$Child[full$Age >= 18] <- 'Adult'
# Show counts
table(full$Child, full$Survived)
##
## 0 1
## Adult 484 274
## Child 65 68
我们将通过创建 Mother 变量来完成我们的特征工程。
# Adding Mother variable
full$Mother <- 'Not Mother'
full$Mother[full$Sex == 'female' & full$Parch > 0 & full$Age > 18 & full$Title != 'Miss'] <- 'Mother'
# Show counts
table(full$Mother, full$Survived)
##
## 0 1
## Mother 16 39
## Not Mother 533 303
# Finish by factorizing our two new factor variables
full$Child <- factor(full$Child)
full$Mother <- factor(full$Mother)
我们关心的所有变量都应该得到照顾,并且不应丢失数据。我要仔细检查一下以确保:
md.pattern(full)
## Warning in data.matrix(x): NAs introduced by coercion
## Warning in data.matrix(x): NAs introduced by coercion
## Warning in data.matrix(x): NAs introduced by coercion
## PassengerId Pclass Sex Age SibSp Parch Fare Embarked Title Surname
## 150 1 1 1 1 1 1 1 1 1 1
## 61 1 1 1 1 1 1 1 1 1 1
## 54 1 1 1 1 1 1 1 1 1 1
## 511 1 1 1 1 1 1 1 1 1 1
## 30 1 1 1 1 1 1 1 1 1 1
## 235 1 1 1 1 1 1 1 1 1 1
## 176 1 1 1 1 1 1 1 1 1 1
## 92 1 1 1 1 1 1 1 1 1 1
## 0 0 0 0 0 0 0 0 0 0
## Fsize Family FsizeD Child Mother Ticket Survived Deck Name Cabin
## 150 1 1 1 1 1 1 1 1 0 0 2
## 61 1 1 1 1 1 1 0 1 0 0 3
## 54 1 1 1 1 1 0 1 1 0 0 3
## 511 1 1 1 1 1 1 1 0 0 0 3
## 30 1 1 1 1 1 0 0 1 0 0 4
## 235 1 1 1 1 1 1 0 0 0 0 4
## 176 1 1 1 1 1 0 1 0 0 0 4
## 92 1 1 1 1 1 0 0 0 0 0 5
## 0 0 0 0 0 352 418 1014 1309 1309 4402
我们终于完成了泰坦尼克号数据集中所有相关缺失值的处理,其中包括对小鼠的一些奇特的插补。我们还成功创建了几个新变量,希望它们能够帮助我们建立一个可靠预测生存的模型。尼克号数据集中所有相关缺失值的处理,其中包括对小鼠的一些奇特的插补。我们还成功创建了几个新变量,希望它们能够帮助我们建立一个可靠预测生存的模型。
4 预测
最后,我们准备根据我们精心策划和处理缺失值的变量来预测泰坦尼克号乘客中的幸存者。为此,我们将依靠随机森林分类算法
4.1 分为训练集和测试集
我们的第一步是将数据拆分回原始测试集和训练集。
# Split the data back into a train set and a test set
train <- full[1:891,]
test <- full[892:1309,]
4.2 建立模型
然后我们在训练集上使用 randomForest 构建模型。
# Set a random seed
set.seed(754)
# Build the model (note: not all possible variables are used)
rf_model <- randomForest(factor(Survived) ~ Pclass + Sex + Age + SibSp + Parch +
Fare + Embarked + Title +
FsizeD + Child + Mother,
data = train)
# Show model error
plot(rf_model, ylim=c(0,0.36))
legend('topright', colnames(rf_model$err.rate), col=1:3, fill=1:3)
黑线显示总体错误率低于 20%。红线和绿线分别显示“死亡”和“幸存”的错误率。我们可以看到,现在我们预测死亡比预测生存要成功得多。我想知道这对我来说意味着什么?
4.3 变量重要性
让我们通过绘制所有树计算的基尼系数的平均下降来看看相对变量的重要性。
# Get importance
importance <- importance(rf_model)
varImportance <- data.frame(Variables = row.names(importance),
Importance = round(importance[ ,'MeanDecreaseGini'],2))
# Create a rank variable based on importance
rankImportance <- varImportance %>%
mutate(Rank = paste0('#',dense_rank(desc(Importance))))
# Use ggplot2 to visualize the relative importance of variables
ggplot(rankImportance, aes(x = reorder(Variables, Importance),
y = Importance, fill = Importance)) +
geom_bar(stat='identity') +
geom_text(aes(x = Variables, y = 0.5, label = Rank),
hjust=0, vjust=0.55, size = 4, colour = 'red') +
labs(x = 'Variables') +
coord_flip() +
theme_few()
在我们所有的预测变量中,它具有最高的相对重要性。
4.4 预测
我们已准备好进行最后一步 - 做出预测!当我们完成这里时,我们可以迭代前面的步骤,进行调整,或者使用不同的模型拟合数据或使用不同的变量组合来实现更好的预测。但这现在对我来说是一个很好的起点(和终点)。
# Predict using the test set
prediction <- predict(rf_model, test)
# Save the solution to a dataframe with two columns: PassengerId and Survived (prediction)
solution <- data.frame(PassengerID = test$PassengerId, Survived = prediction)
# Write the solution to file
write.csv(solution, file = 'rf_mod_Solution.csv', row.names = F)