# Tree-based Methods
## Overview

### Advantages
- Easy interpretability & visualization
- Can handle qualitative predictors without having to create dummy variables

### Disadvantages
- Predictive performance of a single tree is not as good as that of other regression or classification methods
- Tree structure is not robust; different training data can result in very different trees (i.e., high variance)
## Decision Trees

- Used for both regression & classification
- Segment the covariate space into a number of simple regions; predict the response (class) for a new data point by taking the average (majority) value of the training data points in the region to which the new point belongs
- The set of splitting rules can be summarized as a tree, ergo *decision tree*
- Basic tree-based methods are simple & easily visualized but don't perform as well as other supervised learning methods, so use *bagging*, *boosting*, *random forests*, etc., which use an ensemble of trees to create a consensus prediction. This usually results in much higher accuracy but less interpretability.
- Regions can be any shape, but are usually boxes for simplicity & ease of interpretation
- Boxes are chosen to minimize RSS, but it is not possible to consider every combination of boxes
- So follow a greedy, top-down approach called *recursive binary splitting*:
  - Start with the covariate and cut point that minimize RSS
  - Split each of the sub-regions as needed using the remaining covariates until a stopping criterion is reached (e.g., each terminal node has at most a certain number of points)
- Trees that are too complex have low bias but overfit
- So grow a large tree and then prune it to get the optimal sub-tree. It is not practical to check test error for every possible sub-tree, since there are too many of them.
- Use *cost complexity pruning* (a.k.a. *weakest link pruning*), which uses a parameter $\alpha$ to control complexity (like the lasso in linear regression). $\alpha$ is selected using cross-validation (see the sketch after this list).
- CV error is a good approximation to test error
- For classification trees, use the *classification error rate* instead of RSS. But this is not sensitive enough to evaluate the quality of a split when growing a tree.
- So use the *Gini index* or *cross-entropy* instead (measures of node purity, since they are low when a node contains mostly data from one class)
- Classification error rate is preferable when prediction accuracy of the final pruned tree is the goal, but all 3 are reasonable for pruning
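A minimal sketch of cost complexity pruning with $\alpha$ chosen by cross-validation, assuming scikit-learn (`DecisionTreeClassifier`, `cost_complexity_pruning_path`) and a synthetic dataset; these names are illustrative choices, not part of the original notes:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Illustrative synthetic data
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Grow a large tree, then get the sequence of effective alphas for pruning
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)

# Pick alpha by 5-fold CV: larger alpha => more pruning => smaller sub-tree
cv_scores = [
    cross_val_score(DecisionTreeClassifier(ccp_alpha=a, random_state=0),
                    X_train, y_train, cv=5).mean()
    for a in path.ccp_alphas
]
best_alpha = path.ccp_alphas[int(np.argmax(cv_scores))]

pruned = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=0).fit(X_train, y_train)
print("chosen alpha:", best_alpha, "test accuracy:", pruned.score(X_test, y_test))
```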
## Bagging

- Decision trees have low bias but *high variance*, i.e., tree structure is very sensitive to the training data
- *Bootstrap aggregation*, or *bagging*, is a method to reduce the variance of a statistical learning method
- To reduce variance / increase predictive accuracy, take many training sets, build a predictive model on each set and average their predictions
- But many training sets are usually not available, so use *bootstrap* samples drawn from the single training set instead & average the predictions across samples
- Bagging can be used for many regression methods and works particularly well for decision trees. For classification, use the majority vote among the bootstrapped trees to determine the class of a new observation (see the sketch after this list).
- With bagging, the individual trees built on bootstrap samples are not pruned, so they have low bias and high variance
- The number of trees / bootstrap samples used is not critical; increasing it does not lead to overfitting
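A minimal sketch of bagging unpruned trees for classification, assuming scikit-learn's `BaggingClassifier` (in versions before 1.2 the base learner argument is `base_estimator` rather than `estimator`); data and parameter values are illustrative:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=0)

# Each of the 200 trees is grown, unpruned, on a bootstrap sample of the
# training data; class predictions are combined by majority vote.
bag = BaggingClassifier(
    estimator=DecisionTreeClassifier(),  # low-bias, high-variance base learner
    n_estimators=200,
    bootstrap=True,
    random_state=0,
)
print("bagged CV accuracy:", cross_val_score(bag, X, y, cv=5).mean())
```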
### Out-of-Bag Error

- An alternative to cross-validation / a validation set for estimating test error
- On average, each bagged tree uses about 2/3 of the training data. The remaining ~1/3 are the *out-of-bag* (OOB) observations for that tree.
- For the $i^{th}$ data point, predict the response using only the trees in which this point was not included (i.e., the trees for which it is OOB). This is the OOB prediction for the point.
- Get OOB predictions for all training data points and compute the OOB MSE / classification error
- OOB error is a valid estimate of test error since each prediction is based only on trees that did not use the point being predicted
- With a sufficiently large number of trees, OOB error is virtually equivalent to leave-one-out cross-validation error
- The OOB estimate of test error is very useful for large datasets where cross-validation is not practical (see the sketch after this list)
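A minimal sketch of the OOB estimate, assuming scikit-learn's `oob_score` option (the data and ensemble size are illustrative):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=0)

# With oob_score=True, each training point is scored using only the trees
# for which it was out-of-bag, giving a test-error estimate without CV.
bag = BaggingClassifier(
    estimator=DecisionTreeClassifier(),
    n_estimators=300,
    oob_score=True,
    random_state=0,
).fit(X, y)

print("OOB accuracy estimate:", bag.oob_score_)
```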
### Variable Importance

- A disadvantage of bagging is that interpretability suffers
- The result of bagging cannot be represented as a single tree, and it is harder to determine which covariates are important
- Variable importance can be estimated using *RSS* (regression) or the *Gini index* (classification)
- Track the average decrease in RSS / Gini index due to splits on each variable over all bagged trees. A large decrease => an important variable (see the sketch after this list).
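A minimal sketch of impurity-based variable importance, shown with a random forest because scikit-learn exposes the averaged Gini decrease as `feature_importances_` on its forest classes (the dataset is illustrative):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=8, n_informative=3, random_state=0)

forest = RandomForestClassifier(n_estimators=300, random_state=0).fit(X, y)

# feature_importances_ is the decrease in Gini impurity from splits on each
# feature, averaged over all trees and normalized to sum to 1.
for idx in np.argsort(forest.feature_importances_)[::-1]:
    print(f"feature {idx}: importance {forest.feature_importances_[idx]:.3f}")
```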
## Random Forests

- Covariates that are strongly associated with the response will dominate tree growth (i.e., will always be picked first for splits), resulting in bagged trees with similar structure
- These trees produce correlated predictions, and averaging over them does not reduce variance as much as averaging over uncorrelated trees
- So the variance of the prediction from these bagged trees will not be much lower than that of a single tree
- *Random forests* fix bagging by *decorrelating* the trees
- Build a number of trees on bootstrapped training samples as in bagging, but for each split in a tree, choose the splitting covariate from only a random subset `m` of all `p` covariates
- Pick a different random subset of `m` covariates for each split of each tree. Typically `m = sqrt(p)` for classification & `m = p/3` for regression; `m = p` is equivalent to bagging.
- This gives all covariates a chance to be used for a split since, on average, a fraction `(p - m)/p` of the splits will not even consider the dominant covariate
- When dominant covariates are present, the average of these decorrelated trees is less variable than the average from bagging
- Similar to bagging, using a large number of trees does not result in overfitting (see the sketch after this list)
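A minimal sketch comparing a random forest with bagging, assuming scikit-learn's `RandomForestClassifier`, where `max_features` plays the role of `m` (data and values are illustrative):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=1000, n_features=20, n_informative=5, random_state=0)

# m = sqrt(p): only a random subset of covariates is considered at each split,
# which decorrelates the trees.
forest = RandomForestClassifier(n_estimators=300, max_features="sqrt", random_state=0)

# m = p: all covariates are considered at each split, i.e., plain bagging.
bagging = RandomForestClassifier(n_estimators=300, max_features=None, random_state=0)

print("random forest CV accuracy:", cross_val_score(forest, X, y, cv=5).mean())
print("bagging CV accuracy:", cross_val_score(bagging, X, y, cv=5).mean())
```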
### Parameters

- Number of trees in the forest (`B`): use enough trees for the error to stabilize; adding more does not cause overfitting
- Depth of trees (`d`), e.g., controlled via min. samples per leaf: tune using CV
- Number of features considered at each split (`m`): tune using CV. A small `m` decorrelates the trees, which reduces the variance of the averaged prediction, but can increase bias. The best value also depends on the ratio of noisy to informative features; with many noisy features, a small `m` makes it less likely that an informative feature is available at a split. (See the tuning sketch after this list.)
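A minimal sketch of tuning `d` (via min. samples per leaf) and `m` (via `max_features`) by cross-validation, assuming scikit-learn's `GridSearchCV`; the grid values are illustrative:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=1000, n_features=20, n_informative=5, random_state=0)

# B (n_estimators) only needs to be "large enough"; tree depth and the number
# of features per split are the parameters worth tuning by CV.
grid = GridSearchCV(
    RandomForestClassifier(n_estimators=300, random_state=0),
    param_grid={
        "min_samples_leaf": [1, 5, 10],
        "max_features": ["sqrt", 0.3, None],  # 0.3 ~ m = p/3
    },
    cv=5,
)
grid.fit(X, y)
print("best parameters:", grid.best_params_)
print("best CV accuracy:", grid.best_score_)
```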
## Boosting

- Another approach to improving the predictions of statistical learning methods for regression & classification
- Unlike bagging, which grows trees independently on bootstrap samples, boosting grows trees *sequentially*, using information from previously grown trees
- Boosting does not use bootstrap samples; each tree is fit on a modified version of the original data
- Boosting works by fitting each tree with the residuals of the current model as the response
- The new tree is added to the fitted function (after shrinking), the residuals are updated, and these are used as the response for the next tree
- Each tree in this sequence can be quite small, with only a few terminal nodes. This improves the fitted function slowly in the areas where it is not doing well.
- A shrinkage parameter slows down the learning process, allowing more and different shaped trees to attack the residuals
- Boosting parameters are (see the sketch after this list):
  - `B` - number of trees to fit. Large `B` can result in overfitting, so `B` is selected with cross-validation.
  - $\lambda$ - shrinkage parameter that controls the learning rate. Usually set to 0.01 or 0.001.
  - `d` - number of splits in each tree; controls the *interaction depth* & complexity of the final model. `d = 1` => each tree is a stump => the final model is additive, since each term/tree involves only a single covariate. A tree with `d` splits can involve at most `d` covariates.
- Because each tree depends on the previously grown trees, smaller trees are usually sufficient, which also improves interpretability (e.g., using stumps => additive model)
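A minimal sketch of boosted regression trees, assuming scikit-learn's `GradientBoostingRegressor`, whose `n_estimators`, `learning_rate`, and `max_depth` correspond to `B`, $\lambda$, and `d` (data and values are illustrative):

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_score

X, y = make_regression(n_samples=1000, n_features=10, noise=10.0, random_state=0)

# Many small trees; each is fit to the residuals of the current model and
# added to the fit after being shrunk by the learning rate.
boost = GradientBoostingRegressor(
    n_estimators=2000,   # B: large here, but should be chosen by CV to avoid overfitting
    learning_rate=0.01,  # lambda: small values slow down learning
    max_depth=1,         # d = 1: stumps => an additive model
    random_state=0,
)
print("boosted CV R^2:", cross_val_score(boost, X, y, cv=5).mean())
```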