【译】R包介绍:Online Random Forest

本文介绍了一种名为在线随机森林(ORF)的机器学习方法,该方法由Amir Saffari等人提出,并由Arthur Lui使用Python实现。作者顾全在Python实现的基础上,进一步用R语言重构了代码,并使其支持增量学习和批量学习。文中提供了详细的安装指南和使用示例,包括分类任务和回归任务的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

作者:顾全,浙江大学软件工程硕士,现任桃树科技算法工程师

地址:

https://github.com/ZJUguquan/OnlineRandomForest

参与:Cynthia

翻译:本文为天善智能编译,未经容许,禁止转载


介绍

Online Random Forest(ORF) 是由Amir Saffari等人最先提出。之后,Arthur Lui使用Python实现了算法。非常感谢他们的工作。在论文内容和Lui的算法的基础上,我通过R和R6包重构了代码。此外,ORF在此包中的实现,与randomForest结合,使它同时支持增量学习和批量学习,例如:在ORF的基础上构建树,然后通过ORF进行更新。通过这种方法,它将比以前快得多。

安装

if(!require(devtools)) install.packages("devtools")
devtools::install_github("ZJUguquan/OnlineRandomForest")

快速启动

  • 最小举例:增量学习

library(OnlineRandomForest)
param <- list('minSamples'= 1, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= dataRange(iris[1:4]))
orf <- ORF$new(param, numTrees = 10)
for (i in 1:150) orf$update(iris[i, 1:4], as.integer(iris[i, 5]))
cat("Mean depth of trees in the forest is:", orf$meanTreeDepth(), "\n")
orf$forest[[2]]$draw()

## Mean depth of trees in the forest is: 3

## Root X4 < 1.21
## |----L: X3 < 2.38
##      |----L: Leaf 1
##      |----R: Leaf 2
## |----R: X4 < 2.15
##      |----L: X1 < 4.92
##           |----L: Leaf 3
##           |----R: Leaf 3
##      |----R: Leaf 3

  • 分类举例

library(OnlineRandomForest)

# data preparation
dat <- iris; dat[,5] <- as.integer(dat[,5])
x.rng <- dataRange(dat[1:4])
param <- list('minSamples'= 2, 'minGain'= 0.2, 'numClasses'= 3, 'x.rng'= x.rng)
ind.gen <- sample(1:150,30) # for generate ORF
ind.updt <- sample(setdiff(1:150, ind.gen), 100) # for uodate ORF
ind.test <- setdiff(setdiff(1:150, ind.gen), ind.updt) # for test

# construct ORF and update
rf <- randomForest::randomForest(factor(Species) ~ ., data = dat[ind.gen, ], maxnodes = 2, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "Species")
cat("Mean size of trees in the forest is:", orf$meanTreeSize(), "\n")


## Mean size of trees in the forest is: 3


for (i in ind.updt) {
 orf$update(dat[i, 1:4], dat[i, 5])
}
cat("After update, mean size of trees in the forest is:", orf$meanTreeSize(), "\n")


## After update, mean size of trees in the forest is: 11.9


# predict
orf$confusionMatrix(dat[ind.test, 1:4], dat[ind.test, 5], pretty = T)


##
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |-------------------------|
##
##  
## Total Observations in Table:  20
##
##  
##              | actual
##   prediction |         1 |         2 |         3 | Row Total |
## -------------|-----------|-----------|-----------|-----------|
##            1 |         4 |         0 |         0 |         4 |
##              |     1.000 |     0.000 |     0.000 |     0.200 |
##              |     1.000 |     0.000 |     0.000 |           |
## -------------|-----------|-----------|-----------|-----------|
##            2 |         0 |         9 |         2 |        11 |
##              |     0.000 |     0.818 |     0.182 |     0.550 |
##              |     0.000 |     1.000 |     0.286 |           |
## -------------|-----------|-----------|-----------|-----------|
##            3 |         0 |         0 |         5 |         5 |
##              |     0.000 |     0.000 |     1.000 |     0.250 |
##              |     0.000 |     0.000 |     0.714 |           |
## -------------|-----------|-----------|-----------|-----------|
## Column Total |         4 |         9 |         7 |        20 |
##              |     0.200 |     0.450 |     0.350 |           |
## -------------|-----------|-----------|-----------|-----------|
##
##


# compare
table(predict(rf, newdata = dat[ind.test,]) == dat[ind.test, 5])


## FALSE  TRUE
##     9    11


table(orf$predicts(X = dat[ind.test,]) == dat[ind.test, 5])


## FALSE  TRUE
##     2    18


  • 回归举例

# data preparation
if(!require(ggplot2)) install.packages("ggplot2")
data("diamonds", package = "ggplot2")
dat <- as.data.frame(diamonds[sample(1:53000,1000), c(1:6,8:10,7)])
for (col in c("cut","color","clarity")) dat[[col]] <- as.integer(dat[[col]]) # Don't forget this
x.rng <- dataRange(dat[1:9])
param <- list('minSamples'= 10, 'minGain'= 1, 'maxDepth' = 10, 'x.rng'= x.rng)
ind.gen <- sample(1:1000, 800)
ind.updt <- sample(setdiff(1:1000, ind.gen), 100)
ind.test <- setdiff(setdiff(1:1000, ind.gen), ind.updt)


# construct ORF
rf <- randomForest::randomForest(price ~ ., data = dat[ind.gen, ], maxnodes = 20, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "price")
orf$meanTreeSize()


## [1] 39


# and update
for (i in ind.updt) {
 orf$update(dat[i, 1:9], dat[i, 10])

}
orf$meanTreeSize()


## [1] 105.7


# predict and compare
if(!require(Metrics)) install.packages("Metrics")
preds.rf <- predict(rf, newdata = dat[ind.test,])
Metrics::rmse(preds.rf, dat$price[ind.test])


## [1] 988.8055


preds <- orf$predicts(dat[ind.test, 1:9])
Metrics::rmse(preds, dat$price[ind.test]) # make progress


## [1] 869.9613


其他用途

  • 在 Tree 类中

ta <- Tree$new("abc", NULL, NULL)
tb <- Tree$new(1, Tree$new(36), Tree$new(3))
tc <- Tree$new(89, tb, ta)
tc$draw()

# update tc
tc$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$right$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$draw()


  • 通过random Forest包配置一个Online random Tree,并升级

# data preparation
library(randomForest)
dat1 <- iris; dat1[,5] <- as.integer(dat1[,5])
rf <- randomForest(factor(Species) ~ ., data = dat1, maxnodes = 3)
treemat1 <- getTree(rf, 1, labelVar=F)
treemat1 <- cbind(treemat1, node.ind = 1:nrow(treemat1))
x.rng1 <- dataRange(dat1[1:4])
param1 <- list('minSamples'= 5, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= x.rng1)
ind.gen <- sample(1:150,50) # for generate ORT
ind.updt <- setdiff(1:150, ind.gen) # for update ORT

# origin
ort2 <- ORT$new(param1)
ort2$draw()


## Root 1
##  Leaf 1


# generate a tree


ort2$generateTree(treemat1, df.node = dat1[ind.gen,])
ort2$draw()


## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
##      |----L: Leaf 2
##      |----R: Leaf 3


# update this tree
for(i in ind.updt) {
 ort2$update(dat1[i,1:4], dat1[i,5])
}
ort2$draw()


## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
##      |----L: Leaf 2
##      |----R: X4 < 2.19
##           |----L: X2 < 3.68
##                |----L: X1 < 7.12
##                     |----L: X3 < 4.06
##                          |----L: Leaf 1
##                          |----R: Leaf 3
##                     |----R: Leaf 3
##                |----R: Leaf 1
##           |----R: Leaf 3


大家都在看

2017年R语言发展报告(国内)

R语言中文社区历史文章整理(作者篇)

R语言中文社区历史文章整理(类型篇)


公众号后台回复关键字即可学习

回复 R                  R语言快速入门及数据挖掘 
回复 Kaggle案例  Kaggle十大案例精讲(连载中)
回复 文本挖掘      手把手教你做文本挖掘
回复 可视化          R语言可视化在商务场景中的应用 
回复 大数据         大数据系列免费视频教程 
回复 量化投资      张丹教你如何用R语言量化投资 
回复 用户画像      京东大数据,揭秘用户画像
回复 数据挖掘     常用数据挖掘算法原理解释与应用
回复 机器学习     人工智能系列之机器学习与实践
回复 爬虫            R语言爬虫实战案例分享

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值