机器学习案例详解的直播互动平台——
机器学习训练营(入群联系qq:2279055353)
下期直播案例预告:大数据预测商品的销售量波动趋势
案例简介
本案例要求根据乘客的旅程属性,建立一个模型预测纽约市出租车的行程时间,相关数据集来自Google云平台。该案例使用R语言编码。
我们的解决方案将分成以下三步进行:
-
可视化数据集,加工新特征,检查离群点。
-
增加外部数据集
-
XGBoost分类模型
数据描述
数据由1.5M的训练观测train.csv
和630K的检验观测test.csv
组成。每行观测代表一个乘车旅程。
介绍
加载R包和函数
首先,我们加载必需的R包。
library('ggplot2') # visualisation
library('scales') # visualisation
library('grid') # visualisation
library('RColorBrewer') # visualisation
library('corrplot') # visualisation
library('alluvial') # visualisation
library('dplyr') # data manipulation
library('readr') # input/output
library('data.table') # data manipulation
library('tibble') # data wrangling
library('tidyr') # data wrangling
library('stringr') # string manipulation
library('forcats') # factor manipulation
library('lubridate') # date and time
library('geosphere') # geospatial locations
library('leaflet') # maps
library('leaflet.extras') # maps
library('maps') # maps
library('xgboost') # modelling
library('caret') # modelling
然后,我们定义一个多图函数,该函数将在可视化时使用。
# Define multiple plot function
#
# ggplot objects can be passed in ..., or to plotlist (as a list of ggplot objects)
# - cols: Number of columns in layout
# - layout: A matrix specifying the layout. If present, 'cols' is ignored.
#
# If the layout is something like matrix(c(1,2,3,3), nrow=2, byrow=TRUE),
# then plot 1 will go in the upper left, 2 will go in the upper right, and
# 3 will go all the way across the bottom.
#
multiplot <- function(..., plotlist=NULL, file, cols=1, layout=NULL) {
# Make a list from the ... arguments and plotlist
plots <- c(list(...), plotlist)
numPlots = length(plots)
# If layout is NULL, then use 'cols' to determine layout
if (is.null(layout)) {
# Make the panel
# ncol: Number of columns of plots
# nrow: Number of rows needed, calculated from # of cols
layout <- matrix(seq(1, cols * ceiling(numPlots/cols)),
ncol = cols, nrow = ceiling(numPlots/cols))
}
if (numPlots==1) {
print(plots[[1]])
} else {
# Set up the page
grid.newpage()
pushViewport(viewport(layout = grid.layout(nrow(layout), ncol(layout))))
# Make each plot, in the correct location
for (i in 1:numPlots) {
# Get the i,j matrix positions of the regions that contain this subplot
matchidx <- as.data.frame(which(layout == i, arr.ind = TRUE))
print(plots[[i]], vp = viewport(layout.pos.row = matchidx$row,
layout.pos.col = matchidx$col))
}
}
}
加载数据
这里,我们使用data.table
包的fread
函数,加快数据的读取。
train <- as.tibble(fread('../input/nyc-taxi-trip-duration/train.csv'))
test <- as.tibble(fread('../input/nyc-taxi-trip-duration/test.csv'))
查看数据
让我们来观察一下训练集和检验集的数据分布和变量类型等信息。以训练集为例:
summary(train)