Decision Trees
Cornell College
STA 362 Spring 2024 Block 8
There are several R packages that assist with tree plotting:
rpart.plot
partykit
rattle
We’re going to focus on rpart.plot, but feel free to try the others!
Very similar to regression trees except it is used to predict a qualitative response rather than a quantitative one
We predict that each observation belongs to the most commonly occurring class of the training observations in a given region
We use recursive binary splitting to grow the tree
Instead of RSS, we can use:
Gini index: \(G = \sum_{k=1}^K \hat{p}_{mk}(1-\hat{p}_{mk})\)
This is a measure of total variance across the \(K\) classes. If all of the \(\hat{p}_{mk}\) values are close to zero or one, this will be small
The Gini index is a measure of node purity: small values indicate that a node contains predominantly observations from a single class
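As a quick sketch (the function name gini is my own, not from a package), the Gini index for a single node can be computed directly from the class proportions \(\hat{p}_{mk}\):

```r
# Gini index for one node, given the vector of class proportions p_mk:
# G = sum_k p_mk * (1 - p_mk)
gini <- function(p) {
  stopifnot(all(p >= 0), abs(sum(p) - 1) < 1e-8)  # proportions must sum to 1
  sum(p * (1 - p))
}

gini(c(0.5, 0.5))  # maximally impure two-class node: 0.5
gini(c(1, 0))      # pure node: 0
```

Note that a pure node (all observations in one class) gives \(G = 0\), matching the claim above.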
In R, this can be estimated using the gain_capture() function.
Predictors: Age, Sex, Chol, etc.
How many folds do I have?
What \(\alpha\)s am I trying?
Try Classification Trees
Use the penguins data from the palmerpenguins package.
Bagging is a general-purpose procedure for reducing the variance of a statistical learning method (not just trees)
It is particularly useful and frequently used in the context of decision trees
Also called bootstrap aggregation
Mathematically, why does this work? Let’s go back to intro to stat!
If you have a set of \(n\) independent observations: \(Z_1, \dots, Z_n\), each with a variance of \(\sigma^2\), what would the variance of the mean, \(\bar{Z}\) be?
The variance of \(\bar{Z}\) is \(\sigma^2/n\)
In other words, averaging a set of observations reduces the variance.
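A quick simulation confirms this (a sketch; the sample size, variance, and number of replications are arbitrary choices for illustration):

```r
# Simulate many samples of size n from a distribution with variance sigma^2,
# then compare the empirical variance of the sample means to sigma^2 / n
set.seed(362)
n <- 25
sigma <- 2
zbar <- replicate(20000, mean(rnorm(n, mean = 0, sd = sigma)))

var(zbar)    # empirical variance of the mean
sigma^2 / n  # theoretical value: 0.16
```

The two numbers agree closely, illustrating that \(\mathrm{Var}(\bar{Z}) = \sigma^2/n\).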
This is generally not practical because we usually do not have multiple training sets.
What can we do?
Generate \(B\) different bootstrapped training sets
Train our method on the \(b\)th bootstrapped training set to get \(\hat{f}^{*b}(x)\), the prediction at point \(x\)
Average all predictions to get: \(\hat{f}_{bag}(x)=\frac{1}{B}\sum_{b=1}^B\hat{f}^{*b}(x)\)
This is bagging!
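The three steps above can be sketched in a few lines. This is a minimal illustration for a regression setting, using lm() as a stand-in base learner (in practice you would plug in a tree-fitting procedure such as rpart):

```r
# Minimal bagging sketch: bootstrap, fit, average
set.seed(362)
x <- runif(100)
y <- 3 * x + rnorm(100, sd = 0.5)
B <- 200

bagged_preds <- sapply(1:B, function(b) {
  idx <- sample(length(y), replace = TRUE)              # b-th bootstrapped training set
  fit <- lm(y ~ x, data = data.frame(x = x[idx], y = y[idx]))
  predict(fit, newdata = data.frame(x = 0.5))           # \hat{f}^{*b}(0.5)
})

f_bag <- mean(bagged_preds)  # \hat{f}_bag(0.5): average over all B fits
f_bag
```

Each bootstrap fit is noisy on its own; averaging the \(B\) predictions is what reduces the variance.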
for each test observation, record the class predicted by the \(B\) trees
Take a majority vote - the overall prediction is the most commonly occurring class among the \(B\) predictions
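The majority vote can be sketched like this (the helper name majority_vote is mine):

```r
# For one test observation: take the class predicted by each of the B trees
# and return the most commonly occurring one
majority_vote <- function(preds) {
  names(which.max(table(preds)))
}

tree_preds <- c("Yes", "No", "Yes", "Yes", "No")  # predictions from B = 5 trees
majority_vote(tree_preds)                          # "Yes"
```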
You can estimate the test error of a bagged model
The key to bagging is that trees are repeatedly fit to bootstrapped subsets of the observations
On average, each bagged tree makes use of about 2/3 of the observations (you can prove this if you’d like, though it’s not required for this course!)
The remaining 1/3 of observations not used to fit a given bagged tree are the out-of-bag (OOB) observations
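Where does the 2/3 come from? The chance that a given observation is never drawn in a bootstrap sample of size \(n\) is \((1 - 1/n)^n\), which approaches \(e^{-1} \approx 0.368\) as \(n\) grows, so roughly 2/3 of observations are included. A quick check:

```r
# Probability an observation is NOT drawn in n draws with replacement
n <- 1000
p_oob <- (1 - 1/n)^n
p_oob       # about 0.368, close to exp(-1)
1 - p_oob   # about 0.632: the fraction of observations each tree uses
```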
You can predict the response for the \(i\)th observation using each of the trees in which that observation was OOB
How many predictions do you think this will yield for the \(i\)th observation?
This will yield about \(B/3\) predictions for the \(i\)th observation. We can average these!
This estimate is essentially the LOOCV error for bagging as long as \(B\) is large 🎉
Describing Bagging
See if you can draw a diagram to describe the bagging process to someone who has never heard of this before.
Do you ❤️ all of the tree puns?
If we are using bootstrap samples, how similar do you think the trees will be?
Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees
By decorrelating the trees, this reduces the variance even more when we average the trees!
Like bagging, build a number of decision trees on bootstrapped training samples
Each time the tree is split, instead of considering all predictors (like bagging), a random selection of \(m\) predictors is chosen as split candidates from the full set of \(p\) predictors
The split is allowed to use only one of those \(m\) predictors
A fresh selection of \(m\) predictors is taken at each split
typically we choose \(m \approx \sqrt{p}\)
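The rule of thumb above is a one-liner (the function name choose_m is mine, for illustration):

```r
# Number of split candidates m for a random forest with p predictors
choose_m <- function(p) floor(sqrt(p))

choose_m(9)   # 3  (e.g., a dataset with 9 variables)
choose_m(13)  # 3  (e.g., the heart disease data with 13 predictors)
```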
Choosing m for Random Forest
Let’s say you have a dataset with 100 observations and 9 variables. If you were fitting a random forest, what would a good \(m\) be?
We are predicting whether a patient has heart disease from 13 predictors
mtry here is \(m\). If we are doing bagging, what do you think we set this to?
What would we change mtry to if we are doing a random forest?
The default mtry for rand_forest is floor(sqrt(# predictors)) (so 3 in this case).
What is our final tree?
With both bagging and random forests, we have traded interpretability for performance.
These approaches will predict better, but we no longer have a single representation of the tree.
Even if we wanted to pick the best performing tree, it may have a different subset of variables than other similar trees.
Application Exercise