randomForest() model
Highlights & Limitations
- Uses randomForest::getTree() to parse each decision path.
- In-line functions in the formulas are not supported:
  - OK - wt ~ mpg + am
  - OK - mutate(mtcars, newam = paste0(am)) and then wt ~ mpg + newam
  - Not OK - wt ~ mpg + as.factor(am)
  - Not OK - wt ~ mpg + as.character(am)
- Interval functions are not supported: tidypredict_interval() & tidypredict_sql_interval()
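Because in-line calls are not supported, one way to stay within this limitation is to check a formula before fitting and to precompute any transformed column. The sketch below uses only base R. has_inline_calls() is a hypothetical helper, not part of tidypredict, and the set of operators it treats as plain formula syntax (+, - and .) is an assumption.

```r
# Hypothetical helper (not part of tidypredict): returns TRUE when the
# right-hand side of a formula contains a function call such as as.factor().
has_inline_calls <- function(f) {
  rhs <- f[[length(f)]]                           # right-hand side of the formula
  calls <- setdiff(all.names(rhs), all.vars(rhs)) # names that are not variables
  # Assumption: only +, - and . count as plain formula syntax
  any(!calls %in% c("+", "-", "."))
}

has_inline_calls(wt ~ mpg + am)                 # FALSE
has_inline_calls(wt ~ mpg + as.factor(am))      # TRUE

# The "OK" workaround from the list, written in base R:
# create the converted column first, then reference it by name
mtcars2 <- transform(mtcars, newam = paste0(am))
f <- wt ~ mpg + newam
has_inline_calls(f)                             # FALSE
```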
How it works
Here is a simple randomForest() model using the iris dataset:
library(randomForest)
model <- randomForest(Species ~ ., data = iris, ntree = 100, proximity = TRUE)
The SQL translation returns a single SQL CASE WHEN operation. Each decision path is a WHEN statement.
library(tidypredict)
tidypredict_sql(model, dbplyr::simulate_mssql())
## <SQL> CASE
## WHEN (((`Petal.Length`) <= 2.5)) THEN ('setosa')
## WHEN ((((`Petal.Length`) > 5.05) AND ((`Petal.Length`) > 2.5))) THEN ('virginica')
## WHEN (((((`Petal.Width`) > 1.9) AND ((`Petal.Length`) > 2.5)) AND ((`Petal.Length`) <= 5.05))) THEN ('virginica')
## WHEN ((((((`Petal.Length`) > 2.5) AND ((`Sepal.Length`) <= 4.95)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('virginica')
## WHEN (((((((`Sepal.Length`) > 4.95) AND ((`Petal.Length`) > 2.5)) AND ((`Petal.Width`) <= 1.75)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('versicolor')
## WHEN ((((((((`Petal.Width`) > 1.75) AND ((`Sepal.Length`) > 4.95)) AND ((`Petal.Length`) > 2.5)) AND ((`Sepal.Width`) <= 3.0)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('virginica')
## WHEN ((((((((`Sepal.Width`) > 3.0) AND ((`Petal.Width`) > 1.75)) AND ((`Sepal.Length`) > 4.95)) AND ((`Petal.Length`) > 2.5)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('versicolor')
## END
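To make the "one WHEN per decision path" idea concrete, here is a minimal base R sketch that assembles a CASE WHEN string from a list of paths. The three paths describe a hypothetical toy tree, not the model fitted above, and to_case_when() is an illustrative helper; unlike tidypredict_sql(), it does not quote column names.

```r
# Hypothetical toy tree (not the model above): each path lists its
# conditions from the root down and the class it predicts.
paths <- list(
  list(conditions = "Petal.Length <= 2.5", prediction = "setosa"),
  list(conditions = c("Petal.Length > 2.5", "Petal.Width <= 1.75"),
       prediction = "versicolor"),
  list(conditions = c("Petal.Length > 2.5", "Petal.Width > 1.75"),
       prediction = "virginica")
)

# Illustrative helper: one WHEN clause per path, conditions joined with AND.
# Identifiers are left unquoted.
to_case_when <- function(paths) {
  whens <- vapply(paths, function(p) {
    sprintf("WHEN (%s) THEN ('%s')",
            paste(p$conditions, collapse = " AND "), p$prediction)
  }, character(1))
  paste(c("CASE", whens, "END"), collapse = "\n")
}

sql <- to_case_when(paths)
cat(sql, "\n")
```

The printed statement starts with CASE, has one WHEN line per path, and ends with END, the same shape as the tidypredict_sql() output.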
Alternatively, use tidypredict_to_column() if the results are to be used or previewed in dplyr.
library(dplyr)
iris %>%
tidypredict_to_column(model) %>%
head(10)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species fit
## 1 5.1 3.5 1.4 0.2 setosa setosa
## 2 4.9 3.0 1.4 0.2 setosa setosa
## 3 4.7 3.2 1.3 0.2 setosa setosa
## 4 4.6 3.1 1.5 0.2 setosa setosa
## 5 5.0 3.6 1.4 0.2 setosa setosa
## 6 5.4 3.9 1.7 0.4 setosa setosa
## 7 4.6 3.4 1.4 0.3 setosa setosa
## 8 5.0 3.4 1.5 0.2 setosa setosa
## 9 4.4 2.9 1.4 0.2 setosa setosa
## 10 4.9 3.1 1.5 0.1 setosa setosa
Under the hood
The parser is based on the output from the randomForest::getTree() function. It returns one decision path for each non-NA row in the prediction field.
getTree(model, labelVar = TRUE) %>%
head()
## left daughter right daughter split var split point status prediction
## 1 2 3 Petal.Length 2.50 1 <NA>
## 2 0 0 <NA> 0.00 -1 setosa
## 3 4 5 Petal.Length 5.05 1 <NA>
## 4 6 7 Petal.Width 1.90 1 <NA>
## 5 0 0 <NA> 0.00 -1 virginica
## 6 8 9 Sepal.Length 4.95 1 <NA>
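The path extraction can be sketched as a recursive walk over a table shaped like the getTree() output: for numeric splits, values at or below the split point go to the left daughter, and rows with a non-NA prediction are terminal nodes. The five-row tree and collect_paths() below are hypothetical, the column names swap the spaces in getTree()'s names for underscores, and getTree() itself describes only one tree per call (the first, k = 1, unless another k is requested).

```r
# Hypothetical toy tree laid out like getTree(model, labelVar = TRUE);
# values are illustrative and NOT taken from the model above.
tree <- data.frame(
  left_daughter  = c(2, 0, 4, 0, 0),
  right_daughter = c(3, 0, 5, 0, 0),
  split_var      = c("Petal.Length", NA, "Petal.Width", NA, NA),
  split_point    = c(2.5, 0, 1.75, 0, 0),
  prediction     = c(NA, "setosa", NA, "versicolor", "virginica"),
  stringsAsFactors = FALSE
)

# Walk from the root (row 1) and collect one condition vector per terminal node
collect_paths <- function(tree, node = 1, conds = character(0)) {
  row <- tree[node, ]
  if (!is.na(row$prediction)) {
    # Terminal node: emit the accumulated conditions and the predicted class
    return(list(list(conditions = conds, prediction = row$prediction)))
  }
  left  <- sprintf("%s <= %s", row$split_var, row$split_point)  # left daughter
  right <- sprintf("%s > %s",  row$split_var, row$split_point)  # right daughter
  c(
    collect_paths(tree, row$left_daughter,  c(conds, left)),
    collect_paths(tree, row$right_daughter, c(conds, right))
  )
}

paths <- collect_paths(tree)
length(paths)  # 3, one path per non-NA prediction row
```

The sketch lists conditions from the root down; the order of the ANDed conditions does not change which rows a path selects.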
The parsed model contains one row for each path. The field, operator and split_point fields list every step of the path in a concatenated character variable, with the steps separated by {:}.
parse_model(model)
## # A tibble: 8 x 7
## labels vals type estimate field operator split_point
## <chr> <chr> <chr> <dbl> <chr> <chr> <chr>
## 1 path-1 setosa path 0 Petal.Len~ left 2.5
## 2 path-2 virginica path 0 Petal.Len~ right{:}r~ 5.05{:}2.5
## 3 path-3 virginica path 0 Petal.Wid~ right{:}l~ 1.9{:}5.05{~
## 4 path-4 virginica path 0 Sepal.Len~ left{:}le~ 4.95{:}1.9{~
## 5 path-5 versicolor path 0 Petal.Wid~ left{:}ri~ 1.75{:}4.95~
## 6 path-6 virginica path 0 Sepal.Wid~ left{:}ri~ 3{:}1.75{:}~
## 7 path-7 versicolor path 0 Sepal.Wid~ right{:}r~ 3{:}1.75{:}~
## 8 model randomForest variable NA <NA> <NA> <NA>
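Because each path is stored as concatenated strings, recovering the individual steps only requires splitting on the {:} separator. In this sketch the three strings are illustrative, chosen to match the path-2 WHEN clause in the SQL output (the tibble print truncates the stored values). The mapping of left to <= and right to > follows the getTree() convention that values at or below the split point go left.

```r
# Illustrative values for one parsed path (assumed; chosen to match the
# path-2 WHEN clause, since the tibble print truncates the real strings)
field       <- "Petal.Length{:}Petal.Length"
operator    <- "right{:}right"
split_point <- "5.05{:}2.5"

# Split each concatenated column on the literal "{:}" separator
steps <- data.frame(
  field       = strsplit(field, "{:}", fixed = TRUE)[[1]],
  operator    = strsplit(operator, "{:}", fixed = TRUE)[[1]],
  split_point = as.numeric(strsplit(split_point, "{:}", fixed = TRUE)[[1]]),
  stringsAsFactors = FALSE
)

# "left" means at or below the split point, "right" means above it
steps$sql_op <- ifelse(steps$operator == "left", "<=", ">")
conditions <- sprintf("(%s %s %s)", steps$field, steps$sql_op, steps$split_point)
paste(conditions, collapse = " AND ")
# "(Petal.Length > 5.05) AND (Petal.Length > 2.5)"
```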
The output from parse_model() is transformed into a dplyr, a.k.a. Tidy Eval, formula. The entire decision tree becomes one dplyr::case_when() statement.
tidypredict_fit(model)
## case_when(((Petal.Length) <= 2.5) ~ "setosa", (((Petal.Length) >
## 5.05) & ((Petal.Length) > 2.5)) ~ "virginica", ((((Petal.Width) >
## 1.9) & ((Petal.Length) > 2.5)) & ((Petal.Length) <= 5.05)) ~
## "virginica", (((((Petal.Length) > 2.5) & ((Sepal.Length) <=
## 4.95)) & ((Petal.Width) <= 1.9)) & ((Petal.Length) <= 5.05)) ~
## "virginica", ((((((Sepal.Length) > 4.95) & ((Petal.Length) >
## 2.5)) & ((Petal.Width) <= 1.75)) & ((Petal.Width) <= 1.9)) &
## ((Petal.Length) <= 5.05)) ~ "versicolor", (((((((Petal.Width) >
## 1.75) & ((Sepal.Length) > 4.95)) & ((Petal.Length) > 2.5)) &
## ((Sepal.Width) <= 3)) & ((Petal.Width) <= 1.9)) & ((Petal.Length) <=
## 5.05)) ~ "virginica", (((((((Sepal.Width) > 3) & ((Petal.Width) >
## 1.75)) & ((Sepal.Length) > 4.95)) & ((Petal.Length) > 2.5)) &
## ((Petal.Width) <= 1.9)) & ((Petal.Length) <= 5.05)) ~ "versicolor")
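case_when() checks its conditions in order and returns the value of the first one that is TRUE, so each row receives the prediction of the first matching path. The base R sketch below mimics that first-match behavior with quoted conditions evaluated against a data frame. first_match() and its three rules are hypothetical and describe a toy tree rather than the model above.

```r
# Hypothetical rules for a toy tree (not the model above), in path order
rules <- list(
  list(cond = quote(Petal.Length <= 2.5), label = "setosa"),
  list(cond = quote(Petal.Length > 2.5 & Petal.Width <= 1.75), label = "versicolor"),
  list(cond = quote(Petal.Length > 2.5 & Petal.Width > 1.75), label = "virginica")
)

# Illustrative stand-in for case_when(): the first TRUE condition wins and
# rows that match no rule stay NA
first_match <- function(data, rules) {
  out <- rep(NA_character_, nrow(data))
  for (r in rules) {
    hit <- is.na(out) & eval(r$cond, data)  # only rows not yet assigned
    out[which(hit)] <- r$label              # which() drops NA comparisons
  }
  out
}

first_match(iris[c(1, 51, 150), ], rules)
# "setosa" "versicolor" "virginica"
```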
From there, the Tidy Eval formula can be used anywhere it can be evaluated. tidypredict provides three paths:
- Use it directly inside dplyr: mutate(iris, !! tidypredict_fit(model))
- Use tidypredict_to_column(model) in a piped command set
- Use tidypredict_sql(model, dbplyr::simulate_mssql()) to retrieve the SQL statement
How it performs
Currently, the formula matches 147 out of 150 predictions of the test model. The threshold argument in tidypredict_test() is an integer indicating how many records are allowed to differ from the baseline prediction returned by predict().
test <- tidypredict_test(model, iris, threshold = 5)
test
## tidypredict test results
##
## Success, test is under the set threshold of: 5
## Predictions that did not match predict(): 3
test$raw_results %>%
filter(predict != tidypredict)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species predict
## 1 4.9 2.4 3.3 1.0 versicolor versicolor
## 2 6.0 2.7 5.1 1.6 versicolor versicolor
## 3 6.0 2.2 5.0 1.5 virginica virginica
## tidypredict
## 1 virginica
## 2 virginica
## 3 versicolor
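The pass/fail logic behind the threshold comes down to counting disagreements and comparing that count with the limit. This sketch applies a hypothetical check_predictions() helper to the three rows printed above. Whether tidypredict_test() treats a count equal to the threshold as passing is not stated here, so the <= comparison is an assumption.

```r
# Predictions for the three rows printed above: predict() vs tidypredict
baseline  <- c("versicolor", "versicolor", "virginica")
candidate <- c("virginica", "virginica", "versicolor")

# Hypothetical helper: count disagreements and compare with the threshold.
# Assumes no missing values; a count equal to the threshold passes here.
check_predictions <- function(baseline, candidate, threshold) {
  mismatches <- sum(baseline != candidate)
  list(mismatches = mismatches, passed = mismatches <= threshold)
}

res <- check_predictions(baseline, candidate, threshold = 5)
res$mismatches  # 3
res$passed      # TRUE
```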