Bart模型应用实例及解析(二)————基于泰坦尼克号数据集的分类模型
前言
这里是在实战中使用Bart模型对数据进行建模及分析,如果有读者对如何建模以及建模函数的参数不了解,对建模后的结果里的参数不清楚的话,可以参考学习作者前面两篇文章内容。以便更好地理解模型、建模过程及思想。
R bartMachine包内bartMachine函数参数详解
https://blog.youkuaiyun.com/qq_35674953/article/details/115774921
BartMachine函数建模结果参数解析
https://blog.youkuaiyun.com/qq_35674953/article/details/115804662
提示:以下是本篇文章正文内容
一、数据集
1、数据集的获取
链接:https://pan.baidu.com/s/1TZejG8fZTS35RQctwtTn-Q
提取码:h6sv
数据部分截图:
2、数据集变量名及意义
变量名 | 意义 |
---|---|
Survived | 分类变量,是否死亡。0代表死亡,1代表存活 |
Pclass | 乘客所持票类,有三种值(1,2,3) |
Name | 乘客姓名 |
Sex | 乘客性别 |
Age | 乘客年龄(有缺失) |
SibSp | 乘客兄弟姐妹/配偶的个数(整数值) |
Parch | 乘客父母/孩子的个数(整数值) |
Ticket | 票号(字符串) |
Fare | 乘客所持票的价格(浮点数,0-500不等) |
Cabin | 乘客所在船舱(有缺失) |
Embark | 乘客登船港口:S、C、Q(有缺失)。赋值1,、2、3。 |
3、数据集处理
由实际经验,作者认为自变量乘客姓名(Name)、船票票号(Ticket)、船舱号(Cabin)对因变量是否存活(Survived)没有影响,所以删去这几个变量。对于变量年龄(Age)的缺失值,由于数据集比较大,就删去了有缺失数据。
二、完整代码
代码如下(示例):
options(java.parameters = "-Xmx10g")
library(ggplot2)
library(bartMachine)
library(reshape2)
library(knitr)
library(ggplot2)
library(GGally)
library(scales)
percent((1:5) / 100)
##读取数据
data<-read.csv(file="C:/Users/LHW/Desktop/tt.csv",header=T,sep=",")
head(data)
n=dim(data)
n
da<-melt(data)
#画出数据箱线图
ggplot(da, aes(x=variable, y=value, fill=variable))+ geom_boxplot()+facet_wrap(~variable,scales="free")
#画出数据直方图
ggplot(da, aes(value, fill=variable))+ geom_histogram()+facet_wrap(~variable,scales="free")
cormat <- round(cor(data[,2:8]), 2)
head(cormat)
melted_cormat <- melt(cormat)
head(melted_cormat)
# 把一侧三角形的值转化为NA
get_upper_tri <- function(cormat){
cormat[lower.tri(cormat)]<- NA
return(cormat)
}
upper_tri <- get_upper_tri(cormat)
upper_tri
#转化为矩阵
library(reshape2)
melted_cormat <- melt(upper_tri,na.rm = T)
#作相关系数热力图
ggplot(data = melted_cormat, aes(x=Var2, y=Var1, fill = value)) +
geom_tile(color = "white") +
scale_fill_gradient2(low = "blue", high = "red", mid = "white",
midpoint = 0, limit = c(-1, 1), space = "Lab",
name="Pearson\nCorrelation") +
theme_minimal() +
theme(axis.text.x = element_text(angle = 45, vjust = 1,
size = 12, hjust = 1)) +
coord_fixed() +
geom_text(aes(Var2, Var1, label = value), color = "black", size = 4)
#随机种子
set.seed(1000)
#按照80%和20%比例划分训练集和测试集
index2=sample(x=2,size=nrow(data),replace=TRUE,prob=c(0.8,0.2))
#训练集
train2=data[index2==1,]
head(train2)
x=train2[,-c(1)]
y=train2[,1]
y = factor(y)
#预测集
data2=data[index2==2,]
x.test_data=data2[,-c(1)]
head(data2)
xp=x.test_data
yp=data2[,1]
yp = factor(yp)
#建立Bart模型
res = bartMachine(x,y,prob_rule_class = 0.5)
print(res)
rm<-res$confusion_matrix
#计算精度、查准率、查全率
A=(rm[1,1]+rm[2,2])/length(y)
cat("精度为:",percent(A,accuracy = 0.01), "\n")
P=rm[2,2]/(rm[2,2]+rm[1,2])
cat("查准率为:",percent(P,accuracy = 0.01), "\n")
R=rm[2,2]/(rm[2,2]+