Plotting trees from Random Forest models with ggraph
Today, I want to show how I use Thomas Lin Pedersen’s awesome ggraph package to plot decision trees from Random Forest models.
I am very much a visual person, so I try to plot as many of my results as possible because it helps me get a better feel for what is going on with my data.
A nice aspect of using tree-based machine learning, like Random Forest models, is that they are more easily interpreted than, e.g., neural networks, because they are based on decision trees. So, when I am using such models, I like to plot final decision trees (if they aren’t too large) to get a sense of which decisions underlie my predictions.
There are a few very convenient ways to plot the outcome if you are using the randomForest package, but I like to have as much control as possible over the layout, colors, labels, etc. And because I didn’t find a solution I liked for caret models, I developed the following little function (below you may find information about how I built the model):
As input, it takes the final model from model_rf <- caret::train(... "rf" ...), i.e. model_rf$finalModel, whose forest component (model_rf$finalModel$forest) holds all trees of the final model. From these trees, you can specify which one to plot by index.
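To see what the function receives, here is a minimal sketch that assumes a small classification forest fit directly with randomForest on the built-in iris data (an illustrative stand-in, not the model described later in this post). caret::train(..., method = "rf") keeps an object of this same randomForest class in $finalModel, so getTree() handles both the same way. The object names rf_fit and is_leaf are hypothetical.

```r
library(randomForest)

set.seed(42)
# Illustrative assumption: small classification forest on the built-in iris data.
# caret::train(..., method = "rf") stores an object of this same class in $finalModel.
rf_fit <- randomForest(Species ~ ., data = iris, ntree = 10)

# All trees live in the $forest component; ntree gives their number
rf_fit$ntree

# Extract tree number 3 as a data frame with one row per node
tree <- getTree(rf_fit, k = 3, labelVar = TRUE)
head(tree)

# Leaf nodes list 0 as both daughters; this placeholder is the vertex "0"
# that the function removes with delete_vertices("0")
is_leaf <- tree$`left daughter` == 0
table(is_leaf)

# Leaves carry the predicted class, internal nodes carry the split variable and value
head(tree[is_leaf, c("prediction", "status")])
head(tree[!is_leaf, c("split var", "split point")])
```

Passing rf_fit (or, for a caret model, model_rf$finalModel) together with an index such as 3 to the function should then plot this same tree. Leaves are identified here by their 0 daughter ids, the same convention the function relies on when it builds and prunes the graph.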
library(dplyr)
library(ggraph)
library(igraph)
tree_func <- function(final_model,
                      tree_num) {
  # get tree by index
  tree <- randomForest::getTree(final_model,
                                k = tree_num,
                                labelVar = TRUE) %>%
    tibble::rownames_to_column() %>%
    # set split points of leaf nodes to NA, so the 0s won't get plotted
    mutate(`split point` = ifelse(is.na(prediction), `split point`, NA))
  # prepare data frame for graph
  graph_frame <- data.frame(from = rep(tree$rowname, 2),
                            to = c(tree$`left daughter`, tree$`right daughter`))
  # convert to graph and delete node "0", the placeholder daughter of leaf nodes
  graph <- graph_from_data_frame(graph_frame) %>%
    delete_vertices("0")
  # set node labels
  V(graph)$node_label <- gsub("_", " ", as.character(tree$`split var`))
  V(graph)$leaf_label <- as.character(tree$prediction)
  V(graph)$split <- as.character(round(tree$`split point`, digits = 2))
  # plot
  plot <- ggraph(graph, 'dendrogram') +
    theme_bw() +
    geom_edge_link() +
    geom_node_point() +
    geom_node_text(aes