Tree based Methods - niranjv/ml-notes GitHub Wiki
Advantages:
- Easy interpretability & visualization
- Can handle qualitative predictors without having to create dummy variables

Disadvantages:
- Predictive performance of a single tree is usually not as good as that of other regression or classification methods
- Tree structure is not robust: different training data can result in very different trees (i.e., high variance)
Decision Trees
Used for both regression & classification
Segment the covariate space into a number of simple regions; predict the response (class) of a new data point as the mean (most common class) of the training data points in the region to which it belongs.
Set of splitting rules can be summarized as a tree, ergo decision tree.
Basic tree-based methods are simple & easily visualized but don't perform as well as other supervised learning methods, so need to use methods such as bagging, random forests and boosting, which use an ensemble of trees to create a consensus prediction. This usually results in much higher accuracy but less interpretability.
Regions can be any shape, but are usually boxes for simplicity & ease of interpretation
Boxes are chosen s.t. RSS is minimized, but it is computationally infeasible to consider every possible partition of the covariate space into boxes
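Written out, with $R_1, \dots, R_J$ for the boxes and $\hat{y}_{R_j}$ for the mean response of the training points in $R_j$ (notation assumed here), the goal is to find the boxes that minimize

$$\sum_{j=1}^{J} \sum_{i \in R_j} \left(y_i - \hat{y}_{R_j}\right)^2$$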
So follow a greedy, top-down approach called recursive binary splitting:
- Start with the covariate and cut point that minimize RSS
- Split each of the resulting sub-regions in the same way (any covariate can be reused) until a stopping criterion is reached (e.g., each terminal node has at most a certain number of points)
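A minimal pure-Python sketch of recursive binary splitting for a regression tree, assuming numeric covariates stored as a list of rows. The names `rss`, `best_split`, `grow` and `predict` and the `min_size` stopping rule are illustrative choices, not a reference implementation. At each node it tries every covariate and every observed cut point, keeps the split with the lowest total RSS, and recurses on both halves.

```python
from statistics import mean


def rss(ys):
    """Residual sum of squares of ys around their mean (0 for an empty list)."""
    if not ys:
        return 0.0
    m = mean(ys)
    return sum((y - m) ** 2 for y in ys)


def best_split(X, y):
    """Return (feature, cut, total_rss) of the split minimizing RSS, or None."""
    best = None
    for j in range(len(X[0])):
        for cut in sorted({row[j] for row in X}):
            left = [yi for row, yi in zip(X, y) if row[j] < cut]
            right = [yi for row, yi in zip(X, y) if row[j] >= cut]
            if not left or not right:
                continue  # this cut does not separate the points
            total = rss(left) + rss(right)
            if best is None or total < best[2]:
                best = (j, cut, total)
    return best


def grow(X, y, min_size=5):
    """Greedy top-down recursive binary splitting.

    Stops when a node has at most `min_size` points (or cannot be split);
    each leaf stores the mean response of its training points.
    """
    split = best_split(X, y) if len(y) > min_size else None
    if split is None:
        return {"leaf": mean(y)}
    j, cut, _ = split
    left = [(row, yi) for row, yi in zip(X, y) if row[j] < cut]
    right = [(row, yi) for row, yi in zip(X, y) if row[j] >= cut]
    return {
        "feature": j,
        "cut": cut,
        "left": grow([r for r, _ in left], [v for _, v in left], min_size),
        "right": grow([r for r, _ in right], [v for _, v in right], min_size),
    }


def predict(tree, x):
    """Follow the splits down to a leaf and return that region's mean."""
    while "leaf" not in tree:
        tree = tree["left"] if x[tree["feature"]] < tree["cut"] else tree["right"]
    return tree["leaf"]


X = [[1], [2], [3], [10], [11], [12]]
y = [1.0, 1.0, 1.0, 5.0, 5.0, 5.0]
tree = grow(X, y, min_size=3)
print(best_split(X, y), predict(tree, [2.5]), predict(tree, [20]))
```

On this toy data the best first split is `(0, 10, 0.0)`: covariate 0 cut at 10 separates the two groups perfectly, so each leaf predicts its group's mean.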
Trees that are too complex have low bias but high variance, i.e., they overfit the training data
So grow a large tree and then prune it back to get an optimal sub-tree. It is not practical to estimate the test error of every possible sub-tree since there are far too many of them.
Instead, use cost complexity pruning (a.k.a. weakest link pruning), which uses a tuning parameter $\alpha$ that penalizes the number of terminal nodes to control tree complexity (similar to the penalty in lasso linear regression).
$\alpha$ is selected using cross-validation
CV error is a good approximation to test error
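In symbols (notation assumed here: $T_0$ is the large tree, $|T|$ is the number of terminal nodes of a sub-tree $T$, and $R_j$ is the region of its $j$-th terminal node), cost complexity pruning picks, for each $\alpha \ge 0$, the sub-tree $T \subset T_0$ that minimizes

$$\sum_{j=1}^{|T|} \sum_{i:\, x_i \in R_j} \left(y_i - \hat{y}_{R_j}\right)^2 + \alpha\,|T|$$

$\alpha = 0$ gives $T_0$ itself; as $\alpha$ increases, each extra terminal node costs more, so the selected sub-tree gets smaller.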
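For classification trees, instead of RSS, use the classification error rate (the fraction of training points in a region that do not belong to its most common class). But this is not sensitive enough to evaluate the quality of a split when growing trees.
- So use the Gini index or cross-entropy instead (measures of node purity: both take small values when a node contains mostly data from 1 class)
- The classification error rate is preferable if prediction accuracy of the final pruned tree is the goal, but any of the 3 measures can be used for pruning.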
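The three node measures can be sketched as follows, with $\hat{p}_k$ the proportion of training points in the node from class $k$: classification error $1 - \max_k \hat{p}_k$, Gini index $\sum_k \hat{p}_k(1 - \hat{p}_k)$, and cross-entropy $-\sum_k \hat{p}_k \log \hat{p}_k$. The two candidate splits of a 400/400 parent node are made up for illustration: both have the same weighted classification error (0.25), but split B produces a pure node and gets a lower Gini index, which is the difference the error rate cannot detect.

```python
import math
from collections import Counter


def class_proportions(labels):
    """Proportion of each class present in a node (p_k)."""
    n = len(labels)
    return [count / n for count in Counter(labels).values()]


def classification_error(labels):
    return 1 - max(class_proportions(labels))


def gini(labels):
    return sum(p * (1 - p) for p in class_proportions(labels))


def cross_entropy(labels):
    # Only classes present in the node are counted, so p > 0 and log(p) is defined
    return -sum(p * math.log(p) for p in class_proportions(labels))


def weighted_impurity(children, measure):
    """Impurity of a split: child impurities weighted by child size."""
    n = sum(len(child) for child in children)
    return sum(len(child) / n * measure(child) for child in children)


# Hypothetical parent node: 400 points of class "x" and 400 of class "y"
split_a = [["x"] * 300 + ["y"] * 100, ["x"] * 100 + ["y"] * 300]
split_b = [["x"] * 200 + ["y"] * 400, ["x"] * 200]

for name, split in [("A", split_a), ("B", split_b)]:
    print(name,
          round(weighted_impurity(split, classification_error), 3),
          round(weighted_impurity(split, gini), 3),
          round(weighted_impurity(split, cross_entropy), 3))
```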
Bagging
- Decision trees have low bias but high variance, i.e., tree structure is very sensitive to the training data
- Bootstrap aggregation (bagging) is a general method to reduce the variance of a statistical learning method
- To reduce variance / increase predictive accuracy of a method, take many training sets, build a predictive model for each set and average their predictions
- But many training sets are usually not available, so instead take repeated bootstrap samples (random samples drawn with replacement) from the single training set & average the predictions of the models fit to each sample
- Bagging can be used for many regression methods and works particularly well for decision trees. For classification, use a majority vote among the bootstrapped trees to determine the class of a new observation.
- With bagging, the individual trees grown on bootstrap samples are deep and not pruned, so each has low bias but high variance; averaging them reduces the variance
- The number of trees / bootstrap samples $B$ is not a critical parameter: a very large $B$ does not lead to overfitting, so use a value large enough for the error to settle down
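A minimal bagging sketch in pure Python, assuming a single numeric covariate. To keep it short, the base learner is a hypothetical one-split regression tree (`fit_stump`); in practice bagging uses deep, unpruned trees as described above, and the stump is only a stand-in. `bag` fits one model per bootstrap sample, `bagged_predict` averages the predictions (regression) and `vote_predict` takes a majority vote (classification); all names are illustrative.

```python
import random
from collections import Counter
from statistics import mean


def fit_stump(xs, ys):
    """Stand-in base learner: best single split on one numeric covariate."""
    best = None
    for cut in sorted(set(xs)):
        left = [y for x, y in zip(xs, ys) if x < cut]
        right = [y for x, y in zip(xs, ys) if x >= cut]
        if not left or not right:
            continue
        lm, rm = mean(left), mean(right)
        rss = sum((y - lm) ** 2 for y in left) + sum((y - rm) ** 2 for y in right)
        if best is None or rss < best[0]:
            best = (rss, cut, lm, rm)
    if best is None:  # all xs equal in this sample: predict the overall mean
        overall = mean(ys)
        return lambda x: overall
    _, cut, lm, rm = best
    return lambda x: lm if x < cut else rm


def bag(xs, ys, n_trees=100, seed=0):
    """Fit one model per bootstrap sample (n draws with replacement)."""
    rng = random.Random(seed)
    n = len(xs)
    models = []
    for _ in range(n_trees):
        idx = [rng.randrange(n) for _ in range(n)]
        models.append(fit_stump([xs[i] for i in idx], [ys[i] for i in idx]))
    return models


def bagged_predict(models, x):
    """Regression: average the individual predictions."""
    return mean(m(x) for m in models)


def vote_predict(models, x):
    """Classification: majority vote over the individual predicted classes."""
    return Counter(m(x) for m in models).most_common(1)[0][0]


models = bag([1, 2, 3, 10, 11, 12], [1, 1, 1, 5, 5, 5])
print(bagged_predict(models, 20))
```

The printed value is close to, but usually not exactly, 5: the occasional bootstrap sample contains no point from the right-hand group, and the stump fit to it predicts 1 everywhere.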
Out-of-Bag Error
- An alternative to cross-validation or a validation set for estimating test error
- On average, each bagged tree uses about 2/3 of the training observations. The remaining ~1/3 are out-of-bag (OOB) observations for that tree.
- For the $i$-th data point, predict the response using only the trees for which this point was OOB (average for regression, majority vote for classification). This is the OOB prediction for the point
- Get OOB prediction for all training data points and calculate OOB MSE / classification error
- OOB error is a valid estimate of test error since prediction was based only on trees which did not contain the point being predicted
- With a sufficiently large number of trees, OOB error is virtually equivalent to leave-one-out cross-validation error
- Using OOB estimate of test error is very useful for large datasets where cross-validation is not practical
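Where the "2/3" comes from: a bootstrap sample draws $n$ points with replacement, so a given point is left out with probability $(1 - 1/n)^n$, which approaches $e^{-1} \approx 0.368$ as $n$ grows; each tree therefore sees about 63% of the distinct training points. The sketch below checks that number and computes an OOB MSE estimate. `fit` is a placeholder for whatever tree-growing routine is being bagged (an assumption, not a specific library API).

```python
import math
import random
from statistics import mean


def prob_out_of_bag(n):
    """Probability that one observation is absent from a bootstrap sample of size n."""
    return (1 - 1 / n) ** n


def oob_mse(xs, ys, fit, n_trees=200, seed=0):
    """OOB estimate of test MSE for a bagged regressor.

    `fit(xs, ys)` must return a function x -> prediction; it stands in for
    the tree-growing routine being bagged.
    """
    rng = random.Random(seed)
    n = len(xs)
    oob_preds = [[] for _ in range(n)]  # predictions from trees that never saw point i
    for _ in range(n_trees):
        idx = [rng.randrange(n) for _ in range(n)]
        model = fit([xs[i] for i in idx], [ys[i] for i in idx])
        in_bag = set(idx)
        for i in range(n):
            if i not in in_bag:
                oob_preds[i].append(model(xs[i]))
    # Average each point's OOB predictions, skipping points that were never OOB
    return mean((ys[i] - mean(p)) ** 2 for i, p in enumerate(oob_preds) if p)


print(prob_out_of_bag(1000), math.exp(-1))  # both about 0.368
```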
Variable Importance
- Disadvantage of bagging: interpretability suffers
- The results of bagging cannot be represented as a single tree, so it is difficult to determine which covariates are important
- Can estimate variable importance using RSS (regression trees) & Gini index (classification trees)
- Track the total decrease in RSS/Gini index due to splits on a variable, averaged over all bagged trees. Large decrease => important variable
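A sketch of the importance calculation, assuming each bagged tree has recorded, for every split, which covariate it used and how much that split reduced the training RSS (or Gini index). The list-of-`(feature, decrease)`-pairs-per-tree format and the numbers below are hypothetical, chosen only for the example.

```python
def variable_importance(trees, n_features):
    """Total decrease in RSS (or Gini index) from splits on each covariate,
    averaged over all bagged trees.

    `trees` is a hypothetical format: one list of (feature_index, decrease)
    pairs per tree, recorded while the tree was grown.
    """
    totals = [0.0] * n_features
    for splits in trees:
        for feature, decrease in splits:
            totals[feature] += decrease
    return [total / len(trees) for total in totals]


# Two made-up trees over three covariates
trees = [[(0, 10.0), (1, 2.0)], [(0, 6.0), (2, 0.5)]]
importance = variable_importance(trees, n_features=3)
# Covariates ordered from most to least important
ranking = sorted(range(3), key=lambda j: importance[j], reverse=True)
print(importance, ranking)
```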
Random Forests
Covariates that are strongly associated with the response will dominate tree growth (i.e., most or all bagged trees will use them for the top splits), resulting in bagged trees with similar structure
These trees will have correlated predictions, and averaging correlated quantities does not reduce variance as much as averaging uncorrelated ones
So the variance of the bagged prediction will not be much lower than that of a single tree
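The effect of correlation can be quantified. Under the simplifying assumption that the $B$ bagged predictions $\hat{f}^1(x), \dots, \hat{f}^B(x)$ each have variance $\sigma^2$ and every pair has correlation $\rho$:

$$\operatorname{Var}\left(\frac{1}{B}\sum_{b=1}^{B} \hat{f}^b(x)\right) = \rho\,\sigma^2 + \frac{1-\rho}{B}\,\sigma^2$$

Adding more trees only shrinks the second term; the first term $\rho\sigma^2$ remains however large $B$ is, which is why reducing $\rho$ matters.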
Random forests improve on bagging by decorrelating the trees
- Build a number of trees on bootstrapped training samples as in bagging, but at each split in a tree, choose the split covariate only from a random subset of $m$ of the $p$ covariates
- Pick a new random subset for each split of each tree. Typically $m \approx \sqrt{p}$ for classification & $m \approx p/3$ for regression; $m = p$ => bagging.
- This gives every covariate a chance to be used for a split: on average, a fraction $(p - m)/p$ of the splits will not even consider the dominant covariate
- When there are dominant covariates, the average of the decorrelated trees is less variable than the bagged average
- As with bagging, using a large number of trees will not result in overfitting
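A sketch of the per-split covariate sampling, using the common defaults $m \approx \sqrt{p}$ (classification) and $m \approx p/3$ (regression), rounded down and at least 1. `default_m` and `candidate_features` are illustrative names, and treating covariate 0 as the dominant one is an assumption for the simulation, which estimates how often that covariate is not even a candidate; the result should be close to $(p - m)/p$.

```python
import math
import random


def default_m(p, task):
    """Typical choice of m: about sqrt(p) for classification, p/3 for regression."""
    if task == "classification":
        return max(1, math.isqrt(p))
    return max(1, p // 3)


def candidate_features(p, m, rng):
    """A fresh random subset of m of the p covariates, drawn for one split."""
    return rng.sample(range(p), m)


rng = random.Random(0)
p = 9
m = default_m(p, "classification")
trials = 100_000
# Fraction of splits at which the (assumed) dominant covariate 0 is not a candidate
excluded = sum(0 not in candidate_features(p, m, rng) for _ in range(trials)) / trials
print(m, excluded, (p - m) / p)
```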
Random forest hyperparameters:
- Number of trees in the forest
- Depth of trees (i.e., min samples per leaf): tune using CV
- Number of features considered at each split: tune using CV. A small $m$ decorrelates the trees, reducing the variance of the forest's prediction, but increases the bias of each tree. The best $m$ also depends on the number of noisy vs. informative features; with many noisy features, a small $m$ makes it less likely that an informative feature is among the candidates at a split.
Boosting
Another approach to improving predictions for statistical learning methods for regression & classification
Unlike bagging, where trees are built independently on bootstrap samples, boosting grows trees sequentially, using information from previously grown trees
Boosting does not use bootstrap samples; each tree is fit on a modified version of the original data
Boosting works by fitting each new tree to the current residuals (rather than the outcome $Y$) as the response
The new tree, scaled by the shrinkage parameter $\lambda$, is added to the fitted function, and the residuals are updated to serve as the response for the next tree
Each tree in this sequence can be quite small, with only a few terminal nodes. Fitting small trees to the residuals slowly improves the fitted function in areas where it does not perform well
The shrinkage parameter $\lambda$ slows the learning process down even further, allowing more and differently shaped trees to attack the residuals
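A minimal sketch of least-squares boosting for one numeric covariate, assuming stumps ($d = 1$) as the weak learner; `fit_stump` and `boost` are illustrative names and the defaults `B=1000`, `lam=0.01` are only examples. The model starts at $\hat{f}(x) = 0$ with residuals $r_i = y_i$; each round fits a stump to the current residuals, adds $\lambda$ times it to the model, and updates $r_i \leftarrow r_i - \lambda \hat{f}^b(x_i)$.

```python
from statistics import mean


def fit_stump(xs, r):
    """Stand-in weak learner: best single split (d = 1) fit to residuals r."""
    best = None
    for cut in sorted(set(xs)):
        left = [ri for x, ri in zip(xs, r) if x < cut]
        right = [ri for x, ri in zip(xs, r) if x >= cut]
        if not left or not right:
            continue
        lm, rm = mean(left), mean(right)
        rss = sum((v - lm) ** 2 for v in left) + sum((v - rm) ** 2 for v in right)
        if best is None or rss < best[0]:
            best = (rss, cut, lm, rm)
    if best is None:
        raise ValueError("need at least two distinct covariate values")
    _, cut, lm, rm = best
    return lambda x: lm if x < cut else rm


def boost(xs, ys, B=1000, lam=0.01):
    """Fit B stumps sequentially, each to the residuals left by the previous ones."""
    r = list(ys)  # residuals start as the response itself
    trees = []
    for _ in range(B):
        tree = fit_stump(xs, r)
        trees.append(tree)
        # Only a shrunken copy of the new tree is removed from the residuals
        r = [ri - lam * tree(x) for x, ri in zip(xs, r)]
    return lambda x: sum(lam * t(x) for t in trees)


f = boost([1, 2, 3, 10, 11, 12], [1, 1, 1, 5, 5, 5])
print(f(2), f(11))
```

On this toy data every stump separates the two groups exactly, so each round removes a fraction $\lambda$ of the remaining residual and the fit at a training point is $y_i\left(1 - (1 - \lambda)^B\right)$, very close to the targets for these settings.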
Boosting parameters are:
- $B$ - number of trees to fit. Unlike bagging and random forests, a large $B$ can result in overfitting (although it tends to occur slowly). $B$ is selected with cross-validation
- $\lambda$ - shrinkage parameter that controls the learning rate. Usually set to 0.01 or 0.001; a very small $\lambda$ needs a very large $B$ to perform well
- $d$ - number of splits in each tree; controls the interaction depth & complexity of the final model. $d = 1$ => each tree is a stump => the final model is an additive model since each term/tree involves only a single covariate. A tree with $d$ splits can involve at most $d$ covariates.
Because each tree depends on previous trees, smaller trees are usually sufficient and this leads to better interpretability (e.g., using stumps => additive model)
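In symbols (notation assumed here: $\hat{f}^b$ is the $b$-th tree and $g_j$ is a function of covariate $x_j$ alone, collecting $\lambda \hat{f}^b$ over all stumps that split on $x_j$), the boosted model is a shrunken sum of trees, and with stumps its terms can be grouped by covariate:

$$\hat{f}(x) = \sum_{b=1}^{B} \lambda\, \hat{f}^b(x) = \sum_{j=1}^{p} g_j(x_j) \quad \text{when } d = 1$$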