Yu liu edited this page Feb 23, 2023 · 5 revisions

ODRF



ODRF implements the well-known Oblique Decision Tree (ODT) and the ODT-based Random Forest (ODRF), which use linear combinations of predictors as partitioning variables in place of the single predictors used by traditional CART and Random Forest. The implementation adopts a number of modifications, and some new functions are also provided.
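To make the idea of an oblique split concrete, the following base R sketch simulates two classes separated by the line x1 + x2 = 1. It is an illustration only, not ODRF's internal split search: the projection direction w = (1, 1) is chosen by hand, and the helper axis_split_error() is hypothetical. A single split on the linear combination x1 + x2 classifies every point correctly, while the best single split on x1 alone cannot.

```r
# Illustration only: compares one oblique split with one axis-aligned split.
set.seed(1)
X <- matrix(runif(200), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
y <- factor(ifelse(X[, "x1"] + X[, "x2"] > 1, "A", "B"))

# Oblique split: threshold the linear combination w'x (w chosen by hand)
w <- c(1, 1)
proj <- drop(X %*% w)
oblique_pred <- factor(ifelse(proj > 1, "A", "B"), levels = levels(y))
oblique_error <- mean(oblique_pred != y)

# Hypothetical helper: try every cut point on x, assign the majority class
# on each side, and return the lowest misclassification rate found
axis_split_error <- function(x, y) {
  side_errors <- function(labels) length(labels) - max(table(labels))
  errs <- vapply(sort(unique(x)), function(cut) {
    side_errors(y[x <= cut]) + side_errors(y[x > cut])
  }, numeric(1))
  min(errs) / length(y)
}
axis_error <- axis_split_error(X[, "x1"], y)

c(oblique = oblique_error, axis_x1 = axis_error)
```

In an actual ODT the direction w is estimated from the data at each node rather than fixed in advance.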

Overview

The ODRF R package consists of the following main functions:

  • ODT(): classification and regression using an ODT, in which each node is split by a linear combination of predictors.
  • ODRF(): classification and regression using ODRF, an extension of random forest built on ODT() that includes the standard random forest as a special case.
  • online(): online training, which updates an existing ODT or ODRF model with a new data set.
  • prune(): bottom-up pruning of an ODT or ODRF using validation data, based on prediction error.
  • print(), predict() and plot(): methods of these base R generic functions for objects of class ODT and ODRF.

ODRF allows users to define their own functions for finding the projection at each node, a choice that is essential to the performance of the forests. We also provide a comprehensive comparison with, and analysis of, other ODT and ODRF implementations; see vignette("ODRF") for details.
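As one illustration of what a projection-finding function might compute, the sketch below estimates a split direction for a two-class node with Fisher's linear discriminant, w proportional to solve(S_w) %*% (mu2 - mu1), using base R only. The helper lda_direction() is hypothetical: it is not part of ODRF and does not follow the interface ODRF expects for user-defined projection functions (see the package documentation for that interface).

```r
# Hypothetical helper, not an ODRF function: Fisher's linear discriminant
# direction for a node containing exactly two classes.
lda_direction <- function(X, y) {
  y <- droplevels(factor(y))
  stopifnot(nlevels(y) == 2)
  X1 <- X[y == levels(y)[1], , drop = FALSE]
  X2 <- X[y == levels(y)[2], , drop = FALSE]
  # Pooled within-class scatter matrix
  Sw <- cov(X1) * (nrow(X1) - 1) + cov(X2) * (nrow(X2) - 1)
  w <- solve(Sw, colMeans(X2) - colMeans(X1))
  w / sqrt(sum(w^2))  # unit-length projection direction
}

# Two Gaussian classes whose means differ along the direction (1, 1)
set.seed(2)
X <- rbind(matrix(rnorm(400), ncol = 2),
           matrix(rnorm(400, mean = 2), ncol = 2))
y <- factor(rep(c("a", "b"), each = 200))

w <- lda_direction(X, y)
proj <- drop(X %*% w)  # projected values used as the split variable
# Split at the midpoint of the projected class means
threshold <- mean(tapply(proj, y, mean))
node_pred <- ifelse(proj > threshold, "b", "a")
mean(node_pred != y)   # training misclassification rate of this single split
```

The estimated w should point roughly along (1, 1), the direction in which the two class means differ.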

Installation

You can install the development version of ODRF from GitHub with:

# install.packages("devtools")
devtools::install_github("liuyu-star/ODRF")

Usage

The following examples show how to use the ODRF package.

Classification and regression with functions ODT() and ODRF()

Classification with an Oblique Decision Random Forest.

library(ODRF)
#> Loading required package: partykit
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
data(seeds, package = "ODRF")
set.seed(19)
# Hold out 59 of the 209 samples for testing
train <- sample(1:209, 150)
train_data <- data.frame(seeds[train, ])
test_data <- data.frame(seeds[-train, ])
# Split the training data into two equal batches: index and -index
index <- seq_len(floor(nrow(train_data) / 2))

# Column 8 is the response varieties_of_wheat
# Forest trained on all training data
forest <- ODRF(varieties_of_wheat ~ ., train_data,
  split = "gini", parallel = FALSE
)
pred <- predict(forest, test_data[, -8])
e.forest <- mean(pred != test_data[, 8])
# Forest trained on the first batch only
forest1 <- ODRF(varieties_of_wheat ~ ., train_data[index, ],
  split = "gini", parallel = FALSE
)
pred <- predict(forest1, test_data[, -8])
e.forest.1 <- mean(pred != test_data[, 8])
# Forest trained on the second batch only
forest2 <- ODRF(varieties_of_wheat ~ ., train_data[-index, ],
  split = "gini", parallel = FALSE
)
pred <- predict(forest2, test_data[, -8])
e.forest.2 <- mean(pred != test_data[, 8])

# Update forest1 with the second batch
forest.online <- online(
  forest1, train_data[-index, -8],
  train_data[-index, 8]
)
pred <- predict(forest.online, test_data[, -8])
e.forest.online <- mean(pred != test_data[, 8])
# Prune forest1 using the second batch as validation data
forest.prune <- prune(forest1, train_data[-index, -8],
  train_data[-index, 8],
  useOOB = FALSE
)
pred <- predict(forest.prune, test_data[, -8])
e.forest.prune <- mean(pred != test_data[, 8])
# Compare test misclassification rates
print(c(
  forest = e.forest, forest1 = e.forest.1, forest2 = e.forest.2,
  forest.online = e.forest.online, forest.prune = e.forest.prune
))
#>        forest       forest1       forest2 forest.online  forest.prune 
#>    0.10169492    0.10169492    0.10169492    0.06779661    0.10169492
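The examples pass split = "gini" for classification and split = "mse" for the regression tree that follows. Assuming these select the usual Gini impurity and mean-squared-error criteria, the base R sketch below shows how such a criterion scores one candidate split of a projected variable proj at a threshold; a tree would keep the projection and threshold with the lowest score. The helper names are hypothetical and are not ODRF functions.

```r
# Hypothetical helpers (not ODRF functions) scoring a candidate split
# "proj <= threshold" by the size-weighted impurity of the two children.
gini <- function(y) {
  if (length(y) == 0) return(0)
  p <- table(y) / length(y)
  1 - sum(p^2)
}
sse <- function(y) if (length(y) == 0) 0 else sum((y - mean(y))^2)

split_gini <- function(proj, y, threshold) {
  left <- proj <= threshold
  (sum(left) * gini(y[left]) + sum(!left) * gini(y[!left])) / length(y)
}
split_mse <- function(proj, y, threshold) {
  left <- proj <= threshold
  (sse(y[left]) + sse(y[!left])) / length(y)
}

proj <- c(0.1, 0.4, 0.6, 0.9)  # projected values w'x at a node
split_gini(proj, factor(c("a", "a", "b", "b")), threshold = 0.5)  # 0: pure children
split_gini(proj, factor(c("a", "b", "a", "b")), threshold = 0.5)  # 0.5: no gain
split_mse(proj, c(1, 1, 5, 5), threshold = 0.5)                   # 0
```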

Regression with Oblique Decision Tree.

data(body_fat, package = "ODRF")
set.seed(42)
# Hold out 102 of the 252 samples for testing
train <- sample(1:252, 150)
train_data <- data.frame(body_fat[train, ])
test_data <- data.frame(body_fat[-train, ])
# Split the training data into two equal batches: index and -index
index <- seq_len(floor(nrow(train_data) / 2))

# Column 1 is the response Density
# Trees trained on all data, the first batch, and the second batch
tree <- ODT(Density ~ ., train_data, split = "mse")
pred <- predict(tree, test_data[, -1])
e.tree <- mean((pred - test_data[, 1])^2)
tree1 <- ODT(Density ~ ., train_data[index, ], split = "mse")
pred <- predict(tree1, test_data[, -1])
e.tree.1 <- mean((pred - test_data[, 1])^2)
tree2 <- ODT(Density ~ ., train_data[-index, ], split = "mse")
pred <- predict(tree2, test_data[, -1])
e.tree.2 <- mean((pred - test_data[, 1])^2)

# Update tree1 with the second batch, then prune tree1 using it as validation data
tree.online <- online(tree1, train_data[-index, -1], train_data[-index, 1])
pred <- predict(tree.online, test_data[, -1])
e.tree.online <- mean((pred - test_data[, 1])^2)
tree.prune <- prune(tree1, train_data[-index, -1], train_data[-index, 1])
pred <- predict(tree.prune, test_data[, -1])
e.tree.prune <- mean((pred - test_data[, 1])^2)
print(c(
  tree = e.tree, tree1 = e.tree.1, tree2 = e.tree.2,
  tree.online = e.tree.online, tree.prune = e.tree.prune
))
#>         tree        tree1        tree2  tree.online   tree.prune 
#> 2.619833e-05 4.165811e-05 7.627646e-05 3.634177e-05 4.165811e-05

In both examples above, the training data train_data is split into two equal batches. The first batch trains the initial model (forest1 or tree1); the second batch is then used either to update that model with online() or as validation data for prune(). After the online update, the test error is smaller than that of the model trained on the first batch alone (0.068 vs. 0.102 for classification, and 3.63e-05 vs. 4.17e-05 for regression). In these runs, pruning left the test error unchanged.
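The size of these differences can be checked directly from the printed errors. The seeds test set has 209 - 150 = 59 samples, so each classification error rate corresponds to a whole number of misclassified samples. The sketch below converts the rates to counts and computes the relative reduction in test MSE for the regression tree, using values copied from the output above.

```r
# Error values copied from the printed results above
n_test <- 209 - 150  # seeds samples held out for testing
cls_error <- c(forest1 = 0.10169492, forest.online = 0.06779661)
misclassified <- round(cls_error * n_test)  # 6 and 4 test samples

reg_mse <- c(tree1 = 4.165811e-05, tree.online = 3.634177e-05)
# Relative reduction in test MSE after the online update: about 0.128
rel_reduction <- unname((reg_mse["tree1"] - reg_mse["tree.online"]) / reg_mse["tree1"])
```

In the classification run, online updating with the second batch therefore reduced the number of misclassified test samples from 6 to 4.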

Print the tree structure of an ODT object and the out-of-bag error estimate of an ODRF object

data(iris, package = "datasets")
tree <- ODT(Species ~ ., data = iris)
#> Warning in ODT.compute(formula, Call, varName, X, y, split, lambda,
#> NodeRotateFun, : You are creating a forest for classification
tree
#> ============================================================= 
#> Oblique Classification Tree structure 
#> =============================================================
#> 
#> 1) root
#>    node2)# proj1*X < 0.29 -> (leaf1 = setosa)
#>    node3)  proj1*X >= 0.29
#>       node4)# proj2*X < 0.88 -> (leaf2 = versicolor)
#>       node5)# proj2*X >= 0.88 -> (leaf3 = virginica)
forest <- ODRF(Species ~ ., data = iris, parallel = FALSE)
#> Warning in ODRF.compute(formula, Call, varName, X, y, split, lambda,
#> NodeRotateFun, : You are creating a forest for classification
forest
#> 
#> Call:
#>  ODRF.formula(formula = Species ~ ., data = data, parallel = FALSE) 
#>                Type of oblique decision random forest: classification
#>                                       Number of trees: 100
#>                            OOB estimate of error rate: 4%
#> Confusion matrix:
#>            setosa versicolor virginica class_error
#> setosa         50          0         0  0.00000000
#> versicolor      0         47         3  0.05999988
#> virginica       0          3        47  0.05999988

Plot the tree structure of class ODT

plot(tree)

Getting help

If you encounter a clear bug, please file an issue with a minimal reproducible example on GitHub.


Please note that this project is released with a Contributor Code of Conduct. By participating in this project you agree to abide by its terms.