library(tidyverse)
library(tidymodels)
library(ISLR2)
library(rpart.plot)
library(partykit)
library(rattle)
#install.packages('ranger')
Chapter 8 Part 2
Decision Trees
Setup
Graphs
Plotting decision trees
There are several R packages that assist with tree plotting
rpart.plot
partykit
rattle
Where to find out more about packages
- Vignettes
- Journal Article (R Journal is great)
- RStudio Community
- StackOverflow
rpart.plot
We’re going to focus on rpart.plot, but feel free to try the others!
install.packages("rpart.plot")
library(rpart.plot)
rpart.plot
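# regression tree: complexity penalty alpha = 0.1, at most 10 levels deep
tree_spec <- decision_tree(
  cost_complexity = 0.1,
  tree_depth = 10,
  mode = "regression") |>
  set_engine("rpart")
# bundle the preprocessing (formula) and the model into one workflow
wf <- workflow() |>
  add_recipe(
    recipe(Salary ~ Hits + Years + PutOuts + RBI + Walks + Runs,
           data = baseball)
  ) |>
  add_model(tree_spec)
model <- fit(wf, baseball)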
rpart.plot
rpart.plot(model$fit$fit$fit,
roundint = FALSE)
Classification Trees
Very similar to regression trees, except that they are used to predict a qualitative response rather than a quantitative one
We predict that each observation belongs to the most commonly occurring class of the training observations in its region
Fitting classification trees
We use recursive binary splitting to grow the tree
Instead of RSS, we can use:
Gini index: \(G = \sum_{k=1}^K \hat{p}_{mk}(1-\hat{p}_{mk})\)
This is a measure of total variance across the \(K\) classes. If all of the \(\hat{p}_{mk}\) values are close to zero or one, this will be small
The Gini index is a measure of node purity: small values indicate that a node contains predominantly observations from a single class
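In R, rpart uses the Gini index by default to choose splits when growing a classification tree. The gain_capture() metric from yardstick, used below to tune \(\alpha\), is a different quantity: it summarizes the gain curve, similar to an AUC, and larger values are better.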
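A small sketch of the Gini index for a single node, written directly from the formula above. The helper gini_index() and the two example nodes are hypothetical, made up for illustration: a pure node gives 0, and a 50/50 two-class node gives 0.5, the largest value possible with two classes.

```r
# Gini index for one node: G = sum over k of p_mk * (1 - p_mk)
# `classes` holds the training labels that fall in region m
gini_index <- function(classes) {
  p_hat <- table(classes) / length(classes)  # class proportions p_mk
  sum(p_hat * (1 - p_hat))
}

pure_node  <- c("Yes", "Yes", "Yes", "Yes")  # a single class
mixed_node <- c("Yes", "No", "Yes", "No")    # an even two-class mix

gini_index(pure_node)   # 0
gini_index(mixed_node)  # 0.5
```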
Classification tree - Heart Disease Example
- Classifying whether 303 patients have heart disease based on 13 predictors (Age, Sex, Chol, etc.)
1. Split the data into cross-validation folds
heart_cv <- vfold_cv(heart, v = 5)
How many folds do I have?
2. Create a model specification that tunes the cost-complexity parameter, \(\alpha\)
3. Fit the model on each cross-validation fold
What \(\alpha\)s am I trying?
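The code for steps 2 and 3, and for the grid of \(\alpha\) values, is not shown on these slides. Below is a minimal sketch of how these steps are commonly written with tidymodels; it is an assumption, not the original code. So that it runs on its own, it uses modeldata's two_class_dat (outcome Class, predictors A and B) as a stand-in for heart, and the names folds, tune_spec, tune_wf, alpha_grid, and tuned are hypothetical. The tuning result passed to select_best() in step 5 plays the role of tuned here.

```r
library(tidymodels)  # parsnip, workflows, tune, dials, rsample, yardstick

# Stand-in data (assumption): binary outcome `Class`, predictors `A` and `B`
data(two_class_dat, package = "modeldata")

# Step 1 analogue: 5-fold cross-validation
set.seed(1)
folds <- vfold_cv(two_class_dat, v = 5)

# Step 2: mark cost_complexity (alpha) as a parameter to tune
tune_spec <- decision_tree(cost_complexity = tune()) |>
  set_engine("rpart") |>
  set_mode("classification")

tune_wf <- workflow() |>
  add_recipe(recipe(Class ~ A + B, data = two_class_dat)) |>
  add_model(tune_spec)

# The alphas to try: 5 values evenly spaced on the log10 scale
# over the default range of dials::cost_complexity()
alpha_grid <- grid_regular(cost_complexity(), levels = 5)
alpha_grid

# Step 3: fit every alpha on every fold, scoring each with gain_capture
tuned <- tune_grid(
  tune_wf,
  resamples = folds,
  grid = alpha_grid,
  metrics = metric_set(gain_capture)
)

collect_metrics(tuned)                       # mean gain_capture per alpha
select_best(tuned, metric = "gain_capture")  # alpha with the largest mean
```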
5. Choose the \(\alpha\) with the best cross-validated gain capture (larger is better)
best <- model %>%
  select_best(metric = "gain_capture")
6. Fit the final model
final_wf <- wf |>
  finalize_workflow(best)
final_model <- fit(final_wf, data = heart)
7. Examine how the final model does on the full sample
final_model %>%
  predict(new_data = heart) %>%
  bind_cols(heart) %>%
  conf_mat(truth = HD, estimate = .pred_class) %>%
  autoplot(type = "heatmap")
Decision trees
Pros
- simple
- easy to interpret
Cons
- not often competitive in terms of predictive accuracy
- Next we will discuss how to combine multiple trees to improve accuracy
Try Classification Trees
- Fit a classification tree to predict species in the penguins data from the palmerpenguins package.
Bagging
Bagging is a general-purpose procedure for reducing the variance of a statistical learning method (not just trees)
It is particularly useful and frequently used in the context of decision trees
Also called bootstrap aggregation
Bagging
Mathematically, why does this work? Let’s go back to intro to stat!
If you have a set of \(n\) independent observations: \(Z_1, \dots, Z_n\), each with a variance of \(\sigma^2\), what would the variance of the mean, \(\bar{Z}\) be?
The variance of \(\bar{Z}\) is \(\sigma^2/n\)
In other words, averaging a set of observations reduces the variance.
This is not practical, because we usually do not have multiple training sets
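A quick simulation of this fact, as a sketch: draw many samples of n independent normal observations (an assumption; any distribution with variance \(\sigma^2\) works) and compare the variance of their means with \(\sigma^2/n\). The values of n, sigma, and n_reps are arbitrary choices.

```r
set.seed(2024)
n      <- 25      # observations per sample
sigma  <- 2       # standard deviation of each observation
n_reps <- 20000   # number of simulated samples

# each replicate: the mean of n independent N(0, sigma^2) draws
z_bars <- replicate(n_reps, mean(rnorm(n, mean = 0, sd = sigma)))

var(z_bars)    # simulated variance of Z-bar, close to 0.16
sigma^2 / n    # theoretical variance: 4 / 25 = 0.16
```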
Bagging
Averaging a set of observations reduces the variance, but this is not practical because we usually do not have multiple training sets.
What can we do?
- Bootstrap! We can take repeated samples from the single training data set.
Bagging process
Generate \(B\) different bootstrapped training sets
Train our method on the \(b\)th bootstrapped training set to get \(\hat{f}^{*b}(x)\), the prediction at point \(x\)
Average all predictions to get: \(\hat{f}_{bag}(x)=\frac{1}{B}\sum_{b=1}^B\hat{f}^{*b}(x)\)
This is bagging!
Bagging regression trees
- Generate \(B\) different bootstrapped training sets
- Fit a regression tree on the \(b\)th bootstrapped training set to get \(\hat{f}^{*b}(x)\), the prediction at point \(x\)
- Average all predictions to get: \(\hat{f}_{bag}(x)=\frac{1}{B}\sum_{b=1}^B\hat{f}^{*b}(x)\)
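The three steps above can be hand-rolled with rpart, as a sketch. Assumptions: the built-in mtcars data stands in for a real training set (predicting mpg), each tree is grown deep (cp = 0) so that averaging has variance to remove, and predict_bagged() is a hypothetical helper, not a package function. In practice a package such as ranger does this for you.

```r
library(rpart)  # ships with R as a recommended package

set.seed(42)
B <- 100            # number of bootstrapped training sets
n <- nrow(mtcars)   # mtcars stands in for the training data

# Fit one deep regression tree to each bootstrapped training set
bagged_trees <- lapply(seq_len(B), function(b) {
  boot_rows <- sample(n, size = n, replace = TRUE)  # the b-th bootstrap sample
  rpart(mpg ~ ., data = mtcars[boot_rows, ],
        control = rpart.control(cp = 0, minsplit = 5))
})

# f_bag(x) = (1/B) * sum over b of f*b(x): average the B trees' predictions
predict_bagged <- function(trees, newdata) {
  preds <- vapply(trees, function(tree) predict(tree, newdata = newdata),
                  FUN.VALUE = numeric(nrow(newdata)))
  # one row per observation, one column per tree
  preds <- matrix(preds, nrow = nrow(newdata))
  rowMeans(preds)
}

f_bag <- predict_bagged(bagged_trees, mtcars[1:5, ])
f_bag
```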
Bagging classification trees
For each test observation, record the class predicted by each of the \(B\) trees
Take a majority vote: the overall prediction is the most commonly occurring class among the \(B\) predictions
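Majority voting can be sketched as follows. The votes matrix holds made-up predictions from \(B = 3\) trees for three test observations, and majority_vote() is a hypothetical helper. With an even number of trees ties are possible; here which.max() breaks them by taking the first class in alphabetical order, which is just one choice.

```r
# rows = test observations, columns = the class predicted by each tree
votes <- matrix(
  c("Yes", "Yes", "No",
    "No",  "No",  "No",
    "Yes", "No",  "Yes"),
  nrow = 3, byrow = TRUE
)

majority_vote <- function(votes) {
  apply(votes, 1, function(row) {
    counts <- table(row)                # how many trees chose each class
    names(counts)[which.max(counts)]    # most commonly occurring class
  })
}

majority_vote(votes)   # returns c("Yes", "No", "Yes")
```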
Out-of-bag Error Estimation
You can estimate the test error of a bagged model without performing cross-validation
The key to bagging is that trees are repeatedly fit to bootstrapped subsets of the observations
On average, each bagged tree makes use of about 2/3 of the observations (you can prove this if you’d like, but it is not required for this course)
The remaining 1/3 of observations not used to fit a given bagged tree are the out-of-bag (OOB) observations
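The 2/3 figure can be checked numerically, as a sketch. The probability that a particular observation appears at least once in a bootstrap sample of size n is \(1 - (1 - 1/n)^n\), which approaches \(1 - 1/e \approx 0.632\) as n grows. in_bag_prob() is a hypothetical helper.

```r
# P(a given observation is in a bootstrap sample of size n)
in_bag_prob <- function(n) 1 - (1 - 1 / n)^n

in_bag_prob(c(10, 100, 10000))   # approaches 0.632 as n grows
1 - exp(-1)                      # the limit, about 0.632

# Check by simulation: fraction of distinct observations in one bootstrap sample
set.seed(1)
n <- 10000
boot_rows <- sample(n, size = n, replace = TRUE)
length(unique(boot_rows)) / n    # close to 0.632
```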
Out-of-bag Error Estimation
You can predict the response for the \(i\)th observation using each of the trees in which that observation was OOB
How many predictions do you think this will yield for the \(i\)th observation?
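This will yield roughly \(B/3\) predictions for the \(i\)th observation (about \(0.37B\) for large \(n\)). We can average these (or take a majority vote for classification)!
Computing the MSE or classification error of these OOB predictions over all \(n\) observations gives the OOB error. As long as \(B\) is large, this estimate is essentially the LOOCV error for bagging 🎉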
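The OOB error can be computed by hand, as a sketch (again assuming mtcars as a stand-in and deep rpart trees; all variable names are hypothetical): each tree predicts only the observations left out of its bootstrap sample, those predictions are averaged per observation, and the MSE of the averages is the OOB error.

```r
library(rpart)

set.seed(7)
B <- 200
n <- nrow(mtcars)

oob_sum   <- numeric(n)   # running sum of OOB predictions per observation
oob_count <- integer(n)   # number of trees for which each observation was OOB

for (b in seq_len(B)) {
  boot_rows <- sample(n, size = n, replace = TRUE)
  oob_rows  <- setdiff(seq_len(n), boot_rows)   # observations not drawn
  tree <- rpart(mpg ~ ., data = mtcars[boot_rows, ],
                control = rpart.control(cp = 0, minsplit = 5))
  if (length(oob_rows) > 0) {
    oob_sum[oob_rows]   <- oob_sum[oob_rows] +
      predict(tree, newdata = mtcars[oob_rows, ])
    oob_count[oob_rows] <- oob_count[oob_rows] + 1L
  }
}

mean(oob_count) / B                  # about 0.36: a bit over 1/3 of the trees
oob_pred <- oob_sum / oob_count      # average OOB prediction per observation
oob_mse  <- mean((mtcars$mpg - oob_pred)^2)
oob_mse                              # OOB estimate of the test MSE
```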
Bagging (vs Boosting) Video
Describing Bagging
See if you can draw a diagram to describe the bagging process to someone who has never heard of this before.
Random Forests
Do you ❤️ all of the tree puns?
If we are using bootstrap samples, how similar do you think the trees will be?
Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees
Decorrelating the trees reduces the variance of their average even further!
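One way to see why decorrelation helps, under an idealized assumption: suppose each tree's prediction at \(x\) has variance \(\sigma^2\) and every pair of trees has correlation \(\rho\). Then the variance of their average is:

```latex
\operatorname{Var}\left(\frac{1}{B}\sum_{b=1}^{B}\hat{f}^{*b}(x)\right)
  = \frac{1}{B^2}\left[B\sigma^2 + B(B-1)\rho\sigma^2\right]
  = \rho\sigma^2 + \frac{1-\rho}{B}\sigma^2
```

When \(\rho = 0\) this reduces to \(\sigma^2/B\), the same form as \(\sigma^2/n\) for independent observations. Adding trees only shrinks the second term; the first term, \(\rho\sigma^2\), stays no matter how large \(B\) is, so lowering \(\rho\) is what decorrelating the trees buys.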
Random Forest process
Like bagging, build a number of decision trees on bootstrapped training samples
Each time the tree is split, instead of considering all predictors (like bagging), a random selection of \(m\) predictors is chosen as split candidates from the full set of \(p\) predictors
The split is allowed to use only one of those \(m\) predictors
A fresh selection of \(m\) predictors is taken at each split
Typically we choose \(m \approx \sqrt{p}\)
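The per-split sampling can be sketched in a few lines using the 13 heart-disease predictor names, with \(m = \lfloor\sqrt{13}\rfloor = 3\). This only illustrates the sampling step (ranger does it internally while growing each tree), and predictors, m, and split_candidates are hypothetical names.

```r
predictors <- c("Age", "Sex", "ChestPain", "RestBP", "Chol", "Fbs",
                "RestECG", "MaxHR", "ExAng", "Oldpeak", "Slope", "Ca", "Thal")
p <- length(predictors)   # 13 predictors
m <- floor(sqrt(p))       # 3 split candidates

set.seed(3)
# three splits, each drawing its own fresh set of m candidate predictors
split_candidates <- replicate(3, sample(predictors, m), simplify = FALSE)
split_candidates
```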
Choosing m for Random Forest
Let’s say you have a dataset with 100 observations and 9 variables. If you were fitting a random forest, what would a good \(m\) be?
The heart disease example
We are predicting whether a patient has heart disease from 13 predictors
1. Randomly divide the data in half: 149 training observations and 148 testing observations
set.seed(77)
heart_split <- initial_split(heart, prop = 0.5)
heart_train <- training(heart_split)
2. Create model specification
model_spec <- rand_forest(
  mode = "classification",
  mtry = ---
) |>
  set_engine("ranger")
Here mtry is \(m\). If we are doing bagging, what do you think we should set it to?
2. Create bagging specification
bagging_spec <- rand_forest(
  mode = "classification",
  mtry = 13  # bagging: all 13 predictors are split candidates
) |>
  set_engine("ranger")
What would we change mtry to if we are doing a random forest?
2. Create Random Forest specification
rf_spec <- rand_forest(
  mode = "classification",
  mtry = 3  # random forest: 3 of the 13 predictors per split
) |>
  set_engine("ranger")
- The default for rand_forest is floor(sqrt(# predictors)) (so 3 in this case)
3. Create the workflow
wf <- workflow() |>
  add_recipe(
    recipe(
      HD ~ Age + Sex + ChestPain + RestBP + Chol + Fbs +
        RestECG + MaxHR + ExAng + Oldpeak + Slope + Ca + Thal,
      data = heart_train
    )
  ) |>
  add_model(rf_spec)
4. Fit the model
model <- fit(wf, data = heart_train)
5. Examine how it looks in the test data
heart_test <- testing(heart_split)
model |>
  predict(new_data = heart_test) |>
  bind_cols(heart_test) |>
  conf_mat(truth = HD, estimate = .pred_class) |>
  autoplot(type = "heatmap")
Trade Off
What is our final tree?
With both bagging and random forests, we have traded interpretability for performance.
These approaches will predict better, but we no longer have a single tree that represents the model.
Even if we wanted to pick the best performing tree, it may have a different subset of variables than other similar trees.
Application Exercise
- Open your last application exercise
- Refit your model as a bagged tree and a random forest