# A very good collection of mlr3 examples. Original source:
# https://mlr3gallery.mlr-org.com/posts/2020-05-04-moneyball/
# The code is simply copied here so it is easy to find and reuse later.
library("mlr3")
library("mlr3learners")
library("mlr3pipelines")
requireNamespace("mlr3measures")
library("mlr3data")
# Inspect the data set, with a focus on missing values.
# Load `moneyball` explicitly so the script does not depend on the data set
# being lazy-loaded when mlr3data is attached.
data("moneyball", package = "mlr3data")
# skim() is provided by skimr, which this script never attaches. Call it via
# its namespace and fall back to base summary() when skimr is not installed.
if (requireNamespace("skimr", quietly = TRUE)) {
  skimr::skim(moneyball)
} else {
  summary(moneyball)
}

# Impute missing values according to column type:
# "imputehist" is restricted to integer/numeric columns,
# "imputeoor" is restricted to factor columns.
imp_num <- po(
  "imputehist",
  param_vals = list(affect_columns = selector_type(c("integer", "numeric")))
)
imp_fct <- po(
  "imputeoor",
  param_vals = list(affect_columns = selector_type("factor"))
)

# Chain both imputers into one preprocessing graph and draw it.
graph <- imp_num %>>% imp_fct
graph$plot()

# Regression task on the moneyball data with `rs` as the target column.
task <- TaskRegr$new(id = "moneyball", backend = moneyball, target = "rs")
# Number of missing values per column of the task.
task$missings()
# Stand-alone ranger regression learner, created only to display its
# properties; it is not used for training below.
test_lrn <- LearnerRegrRanger$new()
test_lrn$properties

# Wrap the ranger learner in a PipeOp so it can be appended to the imputation
# graph like any other pipeline step.
polrn <- PipeOpLearner$new(mlr_learners$get("regr.ranger"))

# Set hyperparameters entry by entry. Assigning a whole new list to
# `param_set$values` would replace every value already stored there,
# including any the learner presets on construction.
# 1000 trees; permutation importance is read out at the end of the script.
polrn$param_set$values$num.trees <- 1000
polrn$param_set$values$importance <- "permutation"

# The final learner is a graph: imputation steps followed by the ranger learner.
lrn <- GraphLearner$new(graph = graph %>>% polrn)
# Hold-out split: 95% of the rows for training, the rest for testing.
# Row ids are taken from the task itself instead of assuming they are 1..nrow,
# and the training size is rounded down explicitly rather than relying on
# sample() truncating a non-integer size.
# NOTE: no RNG seed is set here, so the split differs between runs.
n_train <- floor(0.95 * task$nrow)
train_set <- sample(task$row_ids, n_train)
test_set <- setdiff(task$row_ids, train_set)

# Fit the graph learner on the training rows only.
lrn$train(task, row_ids = train_set)

# Predict the held-out rows and show the predictions.
preds <- lrn$predict(task, row_ids = test_set)
print(preds)
# Resampling check: 10-fold cross-validation of the full graph learner.
cv10 <- rsmp("cv", folds = 10)
r <- resample(task = task, learner = lrn, resampling = cv10)

# Per-fold scores for mean absolute error and mean squared error.
per_fold_measures <- msrs(c("regr.mae", "regr.mse"))
scores <- r$score(per_fold_measures)
scores

# Mean absolute error aggregated over all folds.
mae_measure <- msr("regr.mae")
r$aggregate(mae_measure)
# Second pipeline for comparison: only factor imputation, followed by a
# feature selection step.
# A fresh "imputeoor" PipeOp is created for this graph.
imp_fct <- po(
  "imputeoor",
  param_vals = list(affect_columns = selector_type("factor"))
)
graph2 <- as_graph(imp_fct)

# Keep the task features that have no missing values, plus the two rank
# columns "rankseason" and "rankplayoffs". Every other feature with missing
# values is dropped.
# vapply() guarantees a logical vector regardless of the input shape.
has_na <- vapply(moneyball, anyNA, logical(1))
feature_names <- colnames(moneyball)[!has_na]
feature_names <- c(
  feature_names[feature_names %in% task$feature_names],
  "rankseason", "rankplayoffs"
)

# Select exactly those features after imputation.
na_select <- po("select")
na_select$param_set$values$selector <- selector_name(feature_names)
graph2 <- graph2 %>>% na_select
graph2$plot()

# Same ranger PipeOp and the same CV definition as the first model, so the
# aggregated MAE can be compared with the first model's.
lrn2 <- GraphLearner$new(graph = graph2 %>>% polrn)
r2 <- resample(task, lrn2, cv10)
r2$aggregate(msr("regr.mae"))
# Variable importance of the first graph learner, largest first.
# Pull the fitted ranger object out of the trained graph learner's state,
# then sort its importance vector in decreasing order.
ranger_fit <- lrn$model[["regr.ranger"]]$model
sort(ranger_fit$variable.importance, decreasing = TRUE)