Chapter 8 Part 2

Decision Trees

Tyler George

Cornell College
STA 362 Spring 2024 Block 8

Setup

library(tidyverse)
library(tidymodels)
library(ISLR2)
library(rpart.plot)
library(partykit)
library(rattle)
#install.packages('ranger')

Graphs

Plotting decision trees

There are several R packages that assist with tree plotting

  • rpart.plot
  • partykit
  • rattle

Where to find out more about packages

  1. Vignettes
  2. Journal Article (R Journal is great)
  3. RStudio Community
  4. StackOverflow
  5. Twitter

rpart.plot

We’re going to focus on rpart.plot, but feel free to try the others!

install.packages("rpart.plot")
library(rpart.plot)

rpart.plot

tree_spec <- decision_tree(
  cost_complexity = 0.1,
  tree_depth = 10,
  mode = "regression") |>
  set_engine("rpart")

wf <- workflow() |>
  add_recipe(
    recipe(Salary ~ Hits + Years + PutOuts + RBI + Walks + Runs,
                   data = baseball)
  ) |>
  add_model(tree_spec)

model <- fit(wf, baseball)
# model$fit$fit$fit digs out the underlying rpart object;
# extract_fit_engine(model) is the tidier equivalent
rpart.plot(model$fit$fit$fit,
           roundint = FALSE)


Classification Trees

Classification Trees

  • Very similar to regression trees, except used to predict a qualitative response rather than a quantitative one

  • We predict that each observation belongs to the most commonly occurring class of the training observations in its region

Fitting classification trees

  • We use recursive binary splitting to grow the tree

  • Instead of RSS, we can use:

  • Gini index: \(G = \sum_{k=1}^K \hat{p}_{mk}(1-\hat{p}_{mk})\)

  • This is a measure of total variance across the \(K\) classes. If all of the \(\hat{p}_{mk}\) values are close to zero or one, this will be small

  • The Gini index is a measure of node purity: small values indicate that a node contains predominantly observations from a single class

  • In R, this can be estimated using the gain_capture() function; the formula itself is easy to compute by hand, as the sketch below shows.
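As a quick check on the formula, here is a tiny helper (a sketch, not part of the original slides) that computes the Gini index from a vector of class proportions:

# Gini index for a single node, given the class proportions p_mk
gini <- function(p) sum(p * (1 - p))

gini(c(0.5, 0.5))    # maximally impure two-class node: 0.5
gini(c(0.95, 0.05))  # nearly pure node: 0.095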

Classification tree - Heart Disease Example

  • Classifying whether 303 patients have heart disease based on 13 predictors (Age, Sex, Chol, etc.)

1. Split the data into cross-validation folds

heart_cv <- vfold_cv(heart, v = 5)

How many folds do I have?

2. Create a model specification that tunes based on complexity, \(\alpha\)

tree_spec <- decision_tree(
  cost_complexity = tune(), 
  tree_depth = 10,
  mode = "classification") %>% 
  set_engine("rpart")

wf <- workflow() |>
  add_recipe(
    recipe(HD ~ Age + Sex + ChestPain + RestBP + Chol + Fbs + 
                     RestECG + MaxHR + ExAng + Oldpeak + Slope + Ca + Thal,
    data = heart
    )
  ) |>
  add_model(tree_spec)

3. Fit the model on the cross-validation folds

grid <- expand_grid(cost_complexity = seq(0.01, 0.05, by = 0.01))
model <- tune_grid(wf,
                   grid = grid,
                   resamples = heart_cv,
                   metrics = metric_set(gain_capture, accuracy)) 

What \(\alpha\)s am I trying?
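Before choosing, you can inspect how each \(\alpha\) did across the folds with tune's standard helpers (a quick sketch):

collect_metrics(model)  # mean gain_capture and accuracy per cost_complexity
autoplot(model)         # the same comparison as a plot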

4. Choose the \(\alpha\) that gives the purest nodes (select_best() picks the highest gain_capture)

best <- model %>%
  select_best(metric = "gain_capture")

5. Fit the final model

final_wf <- wf |>
  finalize_workflow(best)

final_model <- fit(final_wf, data = heart)

6. Examine how the final model does on the full sample

final_model %>%
  predict(new_data = heart) %>%
  bind_cols(heart) %>%
  conf_mat(truth = HD, estimate = .pred_class) %>%
  autoplot(type = "heatmap")

Decision trees

Pros

  • simple
  • easy to interpret

Cons

  • not often competitive in terms of predictive accuracy
  • Next we will discuss how to combine multiple trees to improve accuracy

Try Classification Trees

  • Fit a classification tree to predict species in the penguins data from the palmerpenguins package. A starter sketch follows.
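A minimal starter sketch, assuming palmerpenguins is installed (the predictors shown are just one reasonable choice):

library(palmerpenguins)

penguins_wf <- workflow() |>
  add_recipe(
    recipe(species ~ bill_length_mm + bill_depth_mm + flipper_length_mm,
           data = penguins)
  ) |>
  add_model(decision_tree(mode = "classification") |> set_engine("rpart"))

# drop rows with missing values before fitting
penguins_fit <- fit(penguins_wf, data = drop_na(penguins))
rpart.plot(extract_fit_engine(penguins_fit), roundint = FALSE)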

Bagging

Bagging

  • bagging is a general-purpose procedure for reducing the variance of a statistical learning method (outside of just trees)

  • It is particularly useful and frequently used in the context of decision trees

  • Also called bootstrap aggregation

Bagging

  • Mathematically, why does this work? Let’s go back to intro to stat!

  • If you have a set of \(n\) independent observations: \(Z_1, \dots, Z_n\), each with a variance of \(\sigma^2\), what would the variance of the mean, \(\bar{Z}\) be?

  • The variance of \(\bar{Z}\) is \(\sigma^2/n\)

  • In other words, averaging a set of observations reduces the variance; the short simulation below checks this numerically

  • This is not usually practical, though, because we rarely have multiple training sets
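A quick simulation (not from the slides) checks the \(\sigma^2/n\) fact numerically:

# Each simulated Z-bar is the mean of n = 25 draws with sigma = 2,
# so Var(Z-bar) should be close to sigma^2 / n = 4 / 25 = 0.16
set.seed(8)
zbars <- replicate(10000, mean(rnorm(25, sd = 2)))
var(zbars)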

Bagging

Averaging a set of observations reduces the variance, but this is usually impractical because we typically have only one training set.

What can we do?

  • Bootstrap! We can take repeated samples from the single training data set.

Bagging process

  • generate \(B\) different bootstrapped training sets

  • Train our method on the \(b\)th bootstrapped training set to get \(\hat{f}^{*b}(x)\), the prediction at point \(x\)

  • Average all predictions to get: \(\hat{f}_{bag}(x)=\frac{1}{B}\sum_{b=1}^B\hat{f}^{*b}(x)\)

  • This is bagging!

Bagging regression trees

  • generate \(B\) different bootstrapped training sets
  • Fit a regression tree on the \(b\)th bootstrapped training set to get \(\hat{f}^{*b}(x)\), the prediction at point \(x\)
  • Average all predictions to get: \(\hat{f}_{bag}(x)=\frac{1}{B}\sum_{b=1}^B\hat{f}^{*b}(x)\) (a hand-rolled version is sketched below)
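Here is a hand-rolled version of that loop, purely for intuition. It assumes the cleaned baseball data from earlier; in practice you would let rand_forest() (next section) do this for you:

# B bootstrap samples -> B regression trees -> average the B predictions
set.seed(1)
B <- 100
boots <- bootstraps(baseball, times = B)

preds <- map(boots$splits, function(s) {
  tree <- rpart::rpart(Salary ~ Hits + Years + Walks, data = analysis(s))
  predict(tree, newdata = baseball)
})
bagged_pred <- reduce(preds, `+`) / B  # averaged predictions for every row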

Bagging classification trees

  • for each test observation, record the class predicted by each of the \(B\) trees

  • Take a majority vote: the overall prediction is the most commonly occurring class among the \(B\) predictions (a toy version below)
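The vote itself is a one-liner. A toy sketch, with made-up predictions from \(B = 3\) trees for a single test observation:

votes <- c("Yes", "No", "Yes")   # hypothetical predictions from 3 trees
names(which.max(table(votes)))   # overall prediction: "Yes"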

Out-of-bag Error Estimation

  • You can estimate the test error of a bagged model

  • The key to bagging is that trees are repeatedly fit to bootstrapped samples of the observations

  • On average, each bagged tree makes use of about 2/3 of the observations, since the chance a given observation is left out of a bootstrap sample is \((1-1/n)^n \approx e^{-1} \approx 1/3\) (you can prove this if you'd like, though it's not required for this course!)

  • The remaining 1/3 of observations not used to fit a given bagged tree are the out-of-bag (OOB) observations

Out-of-bag Error Estimation

You can predict the response for the \(i\)th observation using each of the trees in which that observation was OOB

How many predictions do you think this will yield for the \(i\)th observation?

  • This will yield about \(B/3\) predictions for the \(i\)th observation, which we can average! (ranger handles this bookkeeping automatically; see the sketch below)

  • This estimate is essentially the LOOCV error for bagging as long as \(B\) is large 🎉
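A minimal sketch of seeing this in practice: ranger reports the OOB error on its fitted object (this assumes the heart data with HD stored as a factor):

# na.omit() because ranger will not accept missing values
rf <- ranger::ranger(HD ~ ., data = na.omit(heart), num.trees = 500)
rf$prediction.error  # OOB misclassification rate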

Bagging (vs Boosting) Video

https://www.youtube.com/watch?v=tjy0yL1rRRU&t=4s

Describing Bagging

See if you can draw a diagram to describe the bagging process to someone who has never heard of this before.


Random Forests

Do you ❤️ all of the tree puns?

If we are using bootstrap samples, how similar do you think the trees will be?

  • Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees

  • By decorrelating the trees, this reduces the variance even more when we average the trees!

Random Forest process

  • Like bagging, build a number of decision trees on bootstrapped training samples

  • Each time the tree is split, instead of considering all predictors (like bagging), a random selection of \(m\) predictors is chosen as split candidates from the full set of \(p\) predictors

  • The split is allowed to use only one of those \(m\) predictors

  • A fresh selection of \(m\) predictors is taken at each split

  • typically we choose \(m \approx \sqrt{p}\); in tidymodels, \(m\) can also be tuned, as sketched below
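In tidymodels, tuning \(m\) just means mtry = tune() in the spec; a sketch (the tuning itself would follow the same tune_grid() pattern as the classification tree earlier):

rf_tune_spec <- rand_forest(
  mode = "classification",
  mtry = tune()) |>
  set_engine("ranger")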

Choosing m for Random Forest

Let’s say you have a dataset with 100 observations and 9 variables. If you were fitting a random forest, what would a good \(m\) be?

The heart disease example

We are predicting whether a patient has heart disease from 13 predictors

1. Randomly divide the data in half: 149 training observations, 148 testing

set.seed(77)
heart_split <- initial_split(heart, prop = 0.5)
heart_train <- training(heart_split)

2. Create model specification

model_spec <- rand_forest(
  mode = "classification",
  mtry = ---  # fill in: how many predictors should each split consider?
) |> 
  set_engine("ranger")

mtry here is \(m\). If we are doing bagging, what do you think we set this to?

2. Create bagging specification

bagging_spec <- rand_forest(
  mode = "classification",
  mtry = 13  # consider all 13 predictors at every split: that's bagging
) |> 
  set_engine("ranger")

What would we change mtry to if we are doing a random forest?

2. Create Random Forest specification

rf_spec <- rand_forest(
  mode = "classification",
  mtry = 3  # m = floor(sqrt(13)) = 3
) |> 
  set_engine("ranger")

  • The default for rand_forest is floor(sqrt(# predictors)) (so 3 in this case)

3. Create the workflow

wf <- workflow() |>
  add_recipe(
    recipe(
      HD ~ Age + Sex + ChestPain + RestBP + Chol + Fbs + 
               RestECG + MaxHR + ExAng + Oldpeak + Slope + Ca + Thal,
             data = heart_train
    )
  ) |>
  add_model(rf_spec)

4. Fit the model

model <- fit(wf, data = heart_train)

5. Examine how it looks in the test data

heart_test <- testing(heart_split)
model |>
  predict(new_data = heart_test) |>
  bind_cols(heart_test) |>
  conf_mat(truth = HD, estimate = .pred_class) |>
  autoplot(type = "heatmap")
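If you want a single number to go with the heatmap, the same pipeline feeds yardstick's accuracy() (a sketch):

model |>
  predict(new_data = heart_test) |>
  bind_cols(heart_test) |>
  accuracy(truth = HD, estimate = .pred_class)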

Trade Off

What is our final tree?

  • With both bagging and random forests, we have traded interpretability for performance.

  • These approaches predict better, but we no longer have a single tree to display.

  • Even if we wanted to pick the best-performing tree, it may use a different subset of variables than other, similarly performing trees.

Application Exercise

  • Open your last application exercise
  • Refit your model as a bagged tree and as a random forest; a starter sketch follows
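A starter sketch (the value 7 is a placeholder; use your own number of predictors for bagging, and the default mtry for a random forest):

# bagging: consider every predictor at every split
bag_spec <- rand_forest(mode = "classification", mtry = 7) |>
  set_engine("ranger")

# random forest: default mtry is floor(sqrt(# predictors))
rf_spec <- rand_forest(mode = "classification") |>
  set_engine("ranger")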