April 12, 2019

Tree-Based Methods

We will show how to grow a tree for classification. Regression trees are developed in a very similar way.

As before, this code is taken from the book An Introduction to Statistical Learning (ISLR).

Carseats data

In these data, Sales is a continuous variable, so we begin by recoding it as a binary variable, High, which is "Yes" when Sales exceeds 8 (thousand units) and "No" otherwise.

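library(tree)
library(ISLR)
# tree() fits a classification tree only when the response is a factor.
# Since R 4.0, data.frame() no longer turns character vectors into factors,
# so convert explicitly (and avoid attach() by using Carseats$Sales).
High=factor(ifelse(Carseats$Sales<=8,"No","Yes"))
Carseats=data.frame(Carseats,High)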

Fit a tree

tree.carseats=tree(High~.-Sales,Carseats)
summary(tree.carseats)
## 
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "Income"      "CompPrice"   "Population" 
## [6] "Advertising" "Age"         "US"         
## Number of terminal nodes:  27 
## Residual mean deviance:  0.4575 = 170.7 / 373 
## Misclassification error rate: 0.09 = 36 / 400

Interpretation

We see that the training error rate is 9%.

For classification trees, the deviance reported in the output of summary() is given by \[ -2 \sum_{m} \sum_k n_{mk} \log(\hat{p}_{mk}) \] where \(n_{mk}\) is the number of observations in the \(m\)th terminal node that belong to the \(k\)th class, and \(\hat{p}_{mk}\) is the proportion of observations in the \(m\)th terminal node that belong to the \(k\)th class.

A small deviance indicates a tree that provides a good fit to the (training) data.

The residual mean deviance reported is simply the deviance divided by \(n - |T_0|\), which in this case is 400 - 27 = 373.
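As a check on the formula, the sketch below computes the deviance of individual nodes from their class counts, which can be read off the printed tree further down (for example, the root has 400 observations, 59% No and 41% Yes, i.e. 236 and 164). The helper node_deviance is hypothetical and not part of the tree package.

```r
# Deviance of one node: -2 * sum_k n_k * log(p_k), with p_k = n_k / n
# (hypothetical helper for illustration, not a tree package function)
node_deviance <- function(counts) {
  p <- counts / sum(counts)
  nz <- counts > 0              # treat 0 * log(0) as 0
  -2 * sum(counts[nz] * log(p[nz]))
}

# Root node: 236 "No" and 164 "Yes" out of 400
node_deviance(c(236, 164))      # about 541.49 (printed as 541.500)

# Node 17: 5 observations, 40% "No" and 60% "Yes"
node_deviance(c(2, 3))          # about 6.73

# A pure node, such as node 16, has deviance 0
node_deviance(c(5, 0))          # 0

# Residual mean deviance = total terminal-node deviance / (n - |T_0|);
# 170.7 is rounded in the summary, hence 0.4575 there
170.7 / (400 - 27)              # about 0.4576
```

Summing node_deviance over the 27 terminal nodes gives the total deviance reported by summary().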

Plotting the tree

  • One of the most attractive properties of trees is that they can be graphically displayed.
  • We use the plot() function to display the tree structure, and the text() function to display the node labels.

Plot tree

plot(tree.carseats)

Plot tree with node labels

plot(tree.carseats)
text(tree.carseats,pretty=0)

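Plot tree with smaller labels

Shelving location (ShelveLoc) appears to be the most important predictor here: the first split separates stores with Good shelving locations from those with Bad or Medium ones.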

plot(tree.carseats)
text(tree.carseats,pretty=1, cex = 0.65)

Tree output

  • If we just type the name of the tree object, R prints output corresponding to each branch of the tree.
  • R displays the split criterion (e.g. Price<92.5), the number of observations in that branch, the deviance, the overall prediction for the branch (Yes or No), and the fraction of observations in that branch that take on values of Yes and No.

Tree output in R

tree.carseats
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 400 541.500 No ( 0.59000 0.41000 )  
##     2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )  
##       4) Price < 92.5 46  56.530 Yes ( 0.30435 0.69565 )  
##         8) Income < 57 10  12.220 No ( 0.70000 0.30000 )  
##          16) CompPrice < 110.5 5   0.000 No ( 1.00000 0.00000 ) *
##          17) CompPrice > 110.5 5   6.730 Yes ( 0.40000 0.60000 ) *
##         9) Income > 57 36  35.470 Yes ( 0.19444 0.80556 )  
##          18) Population < 207.5 16  21.170 Yes ( 0.37500 0.62500 ) *
##          19) Population > 207.5 20   7.941 Yes ( 0.05000 0.95000 ) *
##       5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )  
##        10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )  
##          20) CompPrice < 124.5 96  44.890 No ( 0.93750 0.06250 )  
##            40) Price < 106.5 38  33.150 No ( 0.84211 0.15789 )  
##              80) Population < 177 12  16.300 No ( 0.58333 0.41667 )  
##               160) Income < 60.5 6   0.000 No ( 1.00000 0.00000 ) *
##               161) Income > 60.5 6   5.407 Yes ( 0.16667 0.83333 ) *
##              81) Population > 177 26   8.477 No ( 0.96154 0.03846 ) *
##            41) Price > 106.5 58   0.000 No ( 1.00000 0.00000 ) *
##          21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )  
##            42) Price < 122.5 51  70.680 Yes ( 0.49020 0.50980 )  
##              84) ShelveLoc: Bad 11   6.702 No ( 0.90909 0.09091 ) *
##              85) ShelveLoc: Medium 40  52.930 Yes ( 0.37500 0.62500 )  
##               170) Price < 109.5 16   7.481 Yes ( 0.06250 0.93750 ) *
##               171) Price > 109.5 24  32.600 No ( 0.58333 0.41667 )  
##                 342) Age < 49.5 13  16.050 Yes ( 0.30769 0.69231 ) *
##                 343) Age > 49.5 11   6.702 No ( 0.90909 0.09091 ) *
##            43) Price > 122.5 77  55.540 No ( 0.88312 0.11688 )  
##              86) CompPrice < 147.5 58  17.400 No ( 0.96552 0.03448 ) *
##              87) CompPrice > 147.5 19  25.010 No ( 0.63158 0.36842 )  
##               174) Price < 147 12  16.300 Yes ( 0.41667 0.58333 )  
##                 348) CompPrice < 152.5 7   5.742 Yes ( 0.14286 0.85714 ) *
##                 349) CompPrice > 152.5 5   5.004 No ( 0.80000 0.20000 ) *
##               175) Price > 147 7   0.000 No ( 1.00000 0.00000 ) *
##        11) Advertising > 13.5 45  61.830 Yes ( 0.44444 0.55556 )  
##          22) Age < 54.5 25  25.020 Yes ( 0.20000 0.80000 )  
##            44) CompPrice < 130.5 14  18.250 Yes ( 0.35714 0.64286 )  
##              88) Income < 100 9  12.370 No ( 0.55556 0.44444 ) *
##              89) Income > 100 5   0.000 Yes ( 0.00000 1.00000 ) *
##            45) CompPrice > 130.5 11   0.000 Yes ( 0.00000 1.00000 ) *
##          23) Age > 54.5 20  22.490 No ( 0.75000 0.25000 )  
##            46) CompPrice < 122.5 10   0.000 No ( 1.00000 0.00000 ) *
##            47) CompPrice > 122.5 10  13.860 No ( 0.50000 0.50000 )  
##              94) Price < 125 5   0.000 Yes ( 0.00000 1.00000 ) *
##              95) Price > 125 5   0.000 No ( 1.00000 0.00000 ) *
##     3) ShelveLoc: Good 85  90.330 Yes ( 0.22353 0.77647 )  
##       6) Price < 135 68  49.260 Yes ( 0.11765 0.88235 )  
##        12) US: No 17  22.070 Yes ( 0.35294 0.64706 )  
##          24) Price < 109 8   0.000 Yes ( 0.00000 1.00000 ) *
##          25) Price > 109 9  11.460 No ( 0.66667 0.33333 ) *
##        13) US: Yes 51  16.880 Yes ( 0.03922 0.96078 ) *
##       7) Price > 135 17  22.070 No ( 0.64706 0.35294 )  
##        14) Income < 46 6   0.000 No ( 1.00000 0.00000 ) *
##        15) Income > 46 11  15.160 Yes ( 0.45455 0.54545 ) *

Training and Testing

  • The training error is an optimistic estimate of performance; we need the test error for a proper evaluation.
  • Split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data.
  • The predict() function can be used for this purpose.
  • In the case of a classification tree, the argument type="class" instructs R to return the actual class prediction.

Test error

set.seed(2)
train=sample(1:nrow(Carseats), 200)
Carseats.test=Carseats[-train,]
High.test=High[-train]
tree.carseats=tree(High~.-Sales,Carseats,subset=train)
tree.pred=predict(tree.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  86  27
##       Yes 30  57

This approach leads to correct predictions for (86 + 57)/200 = 71.5% of the test observations.
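The accuracy is the sum of the diagonal of the confusion matrix divided by the total count. A minimal sketch in base R, with the counts copied from the table above; the helper name accuracy is hypothetical.

```r
# Proportion of correct predictions: diagonal counts / total count
# (hypothetical helper for illustration)
accuracy <- function(conf_mat) sum(diag(conf_mat)) / sum(conf_mat)

# Counts from the table above (rows: tree.pred, columns: High.test);
# matrix() fills column by column
conf_full <- matrix(c(86, 30, 27, 57), nrow = 2,
                    dimnames = list(tree.pred = c("No", "Yes"),
                                    High.test = c("No", "Yes")))
accuracy(conf_full)        # 0.715
1 - accuracy(conf_full)    # test error rate 0.285
```

The same function can be applied directly to the output of table(tree.pred, High.test), since a two-way table supports diag() and sum().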

Pruning

  • The tree evaluated above is the full (unpruned) tree.
  • Next, we consider whether pruning the tree might lead to improved results.
  • The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost complexity pruning is used in order to select a sequence of trees for consideration.

Recap (Classification Losses)

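  1. The 0-1 loss, i.e. the number of misclassified training observations: \[ \sum_{m = 1}^{|T|} \sum_{x_i \in R_m} 1(y_i \neq y_{R_m}) \] where \(y_{R_m}\) is the most common class in \(R_m\). Dividing by \(n\) gives the misclassification error rate.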

  2. Cross-entropy (measures the 'purity' of a leaf)

\[ - \sum_{m = 1}^{|T|} q_m \sum_{k=1}^{K} \hat{p}_{mk} \log(\hat{p}_{mk}) \]

where \(\hat{p}_{mk}\) is the proportion of class \(k\) within \(R_m\), and \(q_m\) is the proportion of samples in \(R_m\).

  • A common practice is to use cross-entropy when growing the tree, since it is more sensitive to node purity, and the misclassification rate when pruning it.
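To see why cross-entropy is preferred for growing, consider a hypothetical parent node with 400 observations in each of two classes and two candidate splits. Both splits misclassify 200 observations, but the second creates a pure node, and cross-entropy (weighted by the node proportions \(q_m\), as in the formula above) rewards it. The helper names are hypothetical.

```r
# Weighted cross-entropy of a split: sum_m q_m * ( -sum_k p_mk log p_mk )
cross_entropy <- function(nodes) {
  n <- sum(sapply(nodes, sum))
  sum(sapply(nodes, function(counts) {
    p <- counts[counts > 0] / sum(counts)   # drop empty classes: 0 * log(0) = 0
    (sum(counts) / n) * -sum(p * log(p))
  }))
}

# Number misclassified when each node predicts its majority class
misclass <- function(nodes) {
  sum(sapply(nodes, function(counts) sum(counts) - max(counts)))
}

# Hypothetical splits of a node with 400 observations per class
split_a <- list(c(300, 100), c(100, 300))
split_b <- list(c(200, 400), c(200, 0))

c(misclass(split_a), misclass(split_b))            # 200 200: a tie
c(cross_entropy(split_a), cross_entropy(split_b))  # about 0.562 and 0.477
```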

Pruning (continued)

  • We use the argument FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance.

  • The cv.tree() function reports
  1. the number of terminal nodes of each tree considered (size),
  2. the corresponding error rate, and
  3. the value of the cost-complexity parameter used (k corresponds to our \(\alpha\))

Cost-complexity pruning

  • Cost-complexity pruning is similar to the idea of regularization: we add a penalty on the size of the tree, so larger trees are penalized more. For a given \(\alpha\), we choose the subtree \(T \subseteq T_0\) that solves

\[ \text{minimize} \sum_{m = 1}^{|T|} \sum_{x_i \in R_m} 1(y_i \neq y_{R_m}) + \alpha |T| \]

  • Just like in other regularized methods, \(\alpha\) is a tuning parameter; we can choose a suitable value by cross-validation.
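The effect of \(\alpha\) can be sketched with a hypothetical nested sequence of subtrees, each described by its number of terminal nodes and its training misclassification count. For each \(\alpha\) we pick the subtree that minimizes errors plus \(\alpha |T|\); the sizes, error counts, and helper name below are made up for illustration.

```r
# Cost-complexity criterion: training errors + alpha * (number of terminal nodes)
pick_subtree <- function(size, errors, alpha) {
  criterion <- errors + alpha * size
  size[which.min(criterion)]
}

# Hypothetical subtree sequence
size   <- c(19, 9, 3, 1)
errors <- c(10, 16, 30, 80)

sapply(c(0, 1, 5, 30), function(a) pick_subtree(size, errors, a))
# 19 9 3 1: larger alpha selects smaller trees
```

With \(\alpha = 0\) the largest tree wins; as \(\alpha\) grows, the penalty dominates and the selected tree shrinks to a single node.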

Using cv.tree()

set.seed(3)
cv.carseats=cv.tree(tree.carseats,FUN=prune.misclass)
names(cv.carseats)
## [1] "size"   "dev"    "k"      "method"
cv.carseats
## $size
## [1] 19 17 14 13  9  7  3  2  1
## 
## $dev
## [1] 55 55 53 52 50 56 69 65 80
## 
## $k
## [1]       -Inf  0.0000000  0.6666667  1.0000000  1.7500000  2.0000000
## [7]  4.2500000  5.0000000 23.0000000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

Output explanation

  • Note that, despite the name, dev here is the number of cross-validation errors (misclassifications), not the deviance.
  • The tree with 9 terminal nodes has the fewest cross-validation errors: 50.
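The best size can be picked programmatically rather than by eye. A minimal sketch using the values printed by cv.carseats above, copied into plain vectors so the snippet runs on its own:

```r
# Values copied from the cv.carseats output above
size <- c(19, 17, 14, 13, 9, 7, 3, 2, 1)
dev  <- c(55, 55, 53, 52, 50, 56, 69, 65, 80)   # number of CV misclassifications

best_size <- size[which.min(dev)]
best_size        # 9
min(dev) / 200   # CV error rate on the 200 training observations: 0.25
```

With the fitted object, the equivalent is cv.carseats$size[which.min(cv.carseats$dev)]. Note that which.min() returns the first minimum, i.e. the largest tree among ties, because the sizes are listed in decreasing order.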

Plotting CV errors

par(mfrow=c(1,2))
plot(cv.carseats$size,cv.carseats$dev,type="b")
plot(cv.carseats$k,cv.carseats$dev,type="b")

Prune the tree

We now apply the prune.misclass() function in order to prune the tree to obtain the nine-node tree.

prune.carseats=prune.misclass(tree.carseats,best=9)
plot(prune.carseats)
text(prune.carseats,pretty=0, cex = 0.7)

Pruned tree

  • How does the pruned tree do in terms of prediction?
tree.pred=predict(prune.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  94  24
##       Yes 22  60

The proportion of correct predictions is now (94 + 60)/200 = 0.77.

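Different tree size

  • You can choose a different number of terminal nodes, e.g. best = 15. (Here best sets the tree size; it is not the cost-complexity parameter k reported by cv.tree().)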
prune.carseats=prune.misclass(tree.carseats,best=15)
tree.pred=predict(prune.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  86  22
##       Yes 30  62
  • The proportion of correct predictions is lower: (86 + 62)/200 = 0.74.

  • Compare with the nine-node tree selected by CV: 0.77.

How does this compare to logistic regression?

glm.fits=glm(High~.-Sales,Carseats,family=binomial,subset=train)
summary(glm.fits)
## 
## Call:
## glm(formula = High ~ . - Sales, family = binomial, data = Carseats, 
##     subset = train)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -2.51502  -0.27248  -0.05029   0.21633   2.69809  
## 
## Coefficients:
##                  Estimate Std. Error z value Pr(>|z|)    
## (Intercept)     -0.361952   3.505771  -0.103 0.917769    
## CompPrice        0.148640   0.033357   4.456 8.35e-06 ***
## Income           0.033647   0.010832   3.106 0.001895 ** 
## Advertising      0.321493   0.073583   4.369 1.25e-05 ***
## Population      -0.001044   0.001773  -0.589 0.555845    
## Price           -0.164643   0.028451  -5.787 7.17e-09 ***
## ShelveLocGood    7.802580   1.397498   5.583 2.36e-08 ***
## ShelveLocMedium  3.617614   0.975650   3.708 0.000209 ***
## Age             -0.082738   0.019931  -4.151 3.31e-05 ***
## Education       -0.177913   0.113880  -1.562 0.118223    
## UrbanYes        -0.848113   0.629494  -1.347 0.177886    
## USYes           -0.477470   0.725935  -0.658 0.510712    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 269.205  on 199  degrees of freedom
## Residual deviance:  98.008  on 188  degrees of freedom
## AIC: 122.01
## 
## Number of Fisher Scoring iterations: 7

Prediction

glm.probs=predict(glm.fits,Carseats.test,type="response")
contrasts(as.factor(High))
##     Yes
## No    0
## Yes   1
glm.pred=rep("No",200)
glm.pred[glm.probs>.5]="Yes"
table(glm.pred,High.test)
##         High.test
## glm.pred  No Yes
##      No  107   9
##      Yes   9  75
mean(glm.pred==High.test)
## [1] 0.91

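The predictions are driven mainly by the strongly significant predictors, for example ShelveLoc and Price. The plots below show the predicted probabilities against each of them.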

par(mfrow=c(1,2))
plot(Carseats.test$ShelveLoc,glm.probs); abline(h=0.5)
plot(Carseats.test$Price,glm.probs);abline(h=0.5)

Exercise (not HW)

  1. Plot the ROC curve for both logistic regression and the decision tree (with pruning).
  2. Also apply k-NN with CV and LDA on the same data set and compare accuracies.
  3. Can you create an artificial data set where a decision tree does better than logistic regression or LDA?

Linear Regression vs. Tree

  • Linear Regression: \[ f(X) = \beta_0 + \sum_{j=1}^{p}X_j \beta_j \]

  • Regression Tree Method: \[ f(X) = \sum_{m=1}^{M} c_m 1_{(X \in R_m)} \] where \(R_1, \ldots, R_M\) is a partition of the \(X\)-space and \(c_m\) is the constant predicted in region \(R_m\).

  • Which model is better? It depends on the problem at hand.
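The tree model is a piecewise-constant function. A minimal sketch on one predictor, assuming a fixed, hypothetical partition into three intervals (a real tree would choose the cut points by recursive binary splitting); for squared-error loss, \(c_m\) is the mean response in \(R_m\). The helper names are hypothetical.

```r
# Hypothetical 1-D partition: R_1 = (-Inf, 2), R_2 = [2, 5), R_3 = [5, Inf)
breaks <- c(2, 5)

# c_m = mean response of the training points falling in region R_m
fit_region_means <- function(x, y, breaks) {
  region <- findInterval(x, breaks) + 1
  as.vector(tapply(y, factor(region, levels = seq_len(length(breaks) + 1)), mean))
}

# f(x) = sum_m c_m * 1(x in R_m): look up the constant of x's region
predict_regions <- function(c_m, xnew, breaks) {
  c_m[findInterval(xnew, breaks) + 1]
}

x <- c(0, 1, 3, 4, 6, 7)
y <- c(1, 3, 10, 12, 20, 22)
c_m <- fit_region_means(x, y, breaks)          # 2 11 21
predict_regions(c_m, c(0.5, 2, 100), breaks)   # 2 11 21
```

Compared with the linear model, the fitted function is flat within each region and jumps at the cut points, whatever the underlying trend.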

Which model is better?

If the true decision boundary is linear, classical linear methods such as logistic regression work well; if it is highly non-linear, tree-based methods may work better.

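Regression

Note that Carseats now also contains High, which was computed from Sales. With the formula Sales ~ . the tree can split on this recoded copy of the response (High is among the variables used below), which makes the fit look better than it should. To predict Sales honestly, use Sales ~ . - High; the output shown below comes from the original formula Sales ~ . and would change.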

tree.car=tree(Sales~.,Carseats,subset=train)
summary(tree.car)
## 
## Regression tree:
## tree(formula = Sales ~ ., data = Carseats, subset = train)
## Variables actually used in tree construction:
## [1] "High"      "ShelveLoc" "CompPrice" "Price"    
## Number of terminal nodes:  10 
## Residual mean deviance:  1.798 = 341.6 / 190 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -3.64200 -0.93070 -0.08071  0.00000  1.10800  3.09100

Plot Regression Tree

plot(tree.car, type = "unif")
text(tree.car,pretty=0,cex=0.7)

CV for Regression Tree

cv.car=cv.tree(tree.car)
cv.car
## $size
## [1] 10  9  7  5  4  3  2  1
## 
## $dev
## [1]  569.4074  562.4486  589.3855  589.3579  585.4891  585.6050  603.0508
## [8] 1585.5328
## 
## $k
## [1]       -Inf   17.88206   20.10870   20.40546   35.04030   35.61496
## [7]   55.20120 1004.92622
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

Plot the CV error

plot(cv.car)

Manually plot the deviance

plot(cv.car$size,cv.car$dev,type='b')

Prune the Regression Tree

prune.car=prune.tree(tree.car,best=9)
plot(prune.car,type = "unif")
text(prune.car,pretty=0,cex=0.7)

MSE with full tree

yhat=predict(tree.car,newdata=Carseats[-train,])
car.test=Carseats[-train,"Sales"]
mean((yhat-car.test)^2)
## [1] 2.577326
plot(yhat,car.test); abline(0,1)

MSE with Pruned Tree

yhat=predict(prune.car,newdata=Carseats[-train,])
car.test=Carseats[-train,"Sales"]
mean((yhat-car.test)^2)
## [1] 2.517199
plot(yhat,car.test); abline(0,1)

Next: Bagging

Bagging and Random Forest

Here we apply bagging and random forests to the Boston data, using the randomForest package in R.

Random Forest

  • Ensemble learning: generate many predictors and aggregate their results.
  • Two well-known methods: boosting (Schapire et al., 1998) and bagging (Breiman, 1996).
  • In boosting, successive trees give extra weight to points incorrectly predicted by earlier predictors. In the end, a weighted vote is taken for prediction.
  • In bagging, successive trees do not depend on earlier trees - each is independently constructed using a bootstrap sample of the data set.
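The bagging recipe (draw B bootstrap samples, fit one predictor to each, average their predictions) can be sketched in base R. As a simplifying assumption, the base learner here is a one-split regression "stump" on a single simulated predictor, whereas randomForest() grows full trees on all predictors; the function names are hypothetical.

```r
# One-split regression tree ("stump") on a single predictor x
fit_stump <- function(x, y) {
  xs <- sort(unique(x))
  if (length(xs) < 2) return(list(split = Inf, left = mean(y), right = mean(y)))
  cand <- (head(xs, -1) + tail(xs, -1)) / 2      # midpoints between distinct x values
  sse <- sapply(cand, function(s) {
    l <- y[x < s]; r <- y[x >= s]
    sum((l - mean(l))^2) + sum((r - mean(r))^2)  # RSS of the two child nodes
  })
  s <- cand[which.min(sse)]
  list(split = s, left = mean(y[x < s]), right = mean(y[x >= s]))
}

predict_stump <- function(st, xnew) ifelse(xnew < st$split, st$left, st$right)

# Bagging: one stump per bootstrap sample (n draws with replacement)
bag_stumps <- function(x, y, B = 100) {
  n <- length(y)
  lapply(seq_len(B), function(b) {
    idx <- sample.int(n, n, replace = TRUE)
    fit_stump(x[idx], y[idx])
  })
}

# Regression: aggregate by averaging the B predictions
predict_bag <- function(fits, xnew) {
  Reduce(`+`, lapply(fits, predict_stump, xnew = xnew)) / length(fits)
}

set.seed(1)
x <- runif(200, 0, 10)
y <- ifelse(x < 5, 0, 10) + rnorm(200, sd = 0.5)   # simulated step function
fits <- bag_stumps(x, y, B = 50)
predict_bag(fits, c(1, 9))   # close to 0 and 10
```

Each stump is fitted independently of the others, which is the key difference from boosting described above.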

We use the Boston housing data

library(MASS);data(Boston)
str(Boston)
## 'data.frame':    506 obs. of  14 variables:
##  $ crim   : num  0.00632 0.02731 0.02729 0.03237 0.06905 ...
##  $ zn     : num  18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
##  $ indus  : num  2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
##  $ chas   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ nox    : num  0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
##  $ rm     : num  6.58 6.42 7.18 7 7.15 ...
##  $ age    : num  65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
##  $ dis    : num  4.09 4.97 4.97 6.06 6.06 ...
##  $ rad    : int  1 2 2 3 3 3 5 5 5 5 ...
##  $ tax    : num  296 242 242 222 222 222 311 311 311 311 ...
##  $ ptratio: num  15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
##  $ black  : num  397 397 393 395 397 ...
##  $ lstat  : num  4.98 9.14 4.03 2.94 5.33 ...
##  $ medv   : num  24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...

Bagging

  • Bagging is a special case of a random forest with \(m = p\).
  • randomForest() can perform both random forests and bagging.
library(randomForest)
set.seed(1)   # make the train/test split reproducible
train = sample(1:nrow(Boston), nrow(Boston)/2)
boston.test=Boston[-train,"medv"]
bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13,
                        importance=TRUE)
  • The argument mtry=13 indicates that all 13 predictors should be considered for each split of the tree.
  • In other words, randomForest() performs bagging.
  • importance=TRUE asks randomForest() to assess the importance of each predictor.

Bagging output

bag.boston
## 
## Call:
##  randomForest(formula = medv ~ ., data = Boston, mtry = 13, importance = TRUE,      subset = train) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 13
## 
##           Mean of squared residuals: 12.35287
##                     % Var explained: 85.91

Test error

yhat.bag = predict(bag.boston,newdata=Boston[-train,])
mean((yhat.bag-boston.test)^2)
## [1] 13.67243
plot(yhat.bag, boston.test); abline(0,1)

A Single Tree

tree.boston=tree(medv~.,Boston,subset=train)
yhat=predict(tree.boston,newdata=Boston[-train,])
boston.test=Boston[-train,"medv"]
mean((yhat-boston.test)^2)
## [1] 27.07243
plot(yhat,boston.test);abline(0,1)

Number of Trees

  • By default ntree = 500.
  • We could change the number of trees grown by randomForest() using the ntree argument:
bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13,
                        ntree=25)
yhat.bag = predict(bag.boston,newdata=Boston[-train,])
mean((yhat.bag-boston.test)^2)
## [1] 14.35049
  • This is still bagging, since mtry = 13; with only 25 trees the test MSE is slightly higher (14.35 vs. 13.67 with 500 trees).

Random Forest

  • Growing a random forest proceeds in exactly the same way, except that we use a smaller value of the mtry argument.
  • By default, randomForest() uses \(\lfloor p/3 \rfloor\) variables when building a random forest of regression trees, and \(\lfloor \sqrt{p} \rfloor\) variables when building a random forest of classification trees.
  • Here we use mtry = 6.
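The default rule can be written out for the Boston data, where \(p = 13\). A minimal sketch: the helper default_mtry is hypothetical, and it mirrors the defaults documented for randomForest() (max(floor(p/3), 1) for regression, floor(sqrt(p)) for classification). The last lines illustrate the per-split step that distinguishes a random forest from bagging: a fresh random subset of mtry predictors is drawn, and the best split is searched only among those.

```r
# Default number of candidate predictors per split (hypothetical helper)
default_mtry <- function(p, type = c("regression", "classification")) {
  type <- match.arg(type)
  if (type == "regression") max(floor(p / 3), 1) else floor(sqrt(p))
}

p <- 13                             # predictors in Boston (excluding medv)
default_mtry(p, "regression")       # 4
default_mtry(p, "classification")   # 3

# One split's candidate predictors when mtry = 6 (bagging would use all 13)
set.seed(1)
sample(p, 6)
```

This is why a random forest fitted to Boston without an explicit mtry uses 4 candidate predictors per split.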

Random Forest: test error

set.seed(1)
rf.boston=randomForest(medv~.,data=Boston,subset=train,mtry=6,importance=TRUE)
yhat.rf = predict(rf.boston,newdata=Boston[-train,])
(mse.rf <- mean((yhat.rf-boston.test)^2))
## [1] 12.24709

The test set MSE is about 12.25, compared with 13.67 for bagging, so random forests yielded an improvement over bagging in this case.

Importance

  • Using the importance() function, we can view the importance of each variable.
importance(rf.boston)
##           %IncMSE IncNodePurity
## crim     7.432964     664.71032
## zn       2.136489      97.17134
## indus    9.464938    1102.28543
## chas     2.640715      53.88275
## nox     14.970963    1472.53258
## rm      36.884024    8419.02599
## age      8.538212     463.79099
## dis     10.944432    1068.69866
## rad      5.372191     141.79123
## tax      7.540827     379.67119
## ptratio 12.197230     866.22970
## black    7.824388     318.10503
## lstat   25.795979    6773.80308
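  1. %IncMSE: the mean increase in out-of-bag MSE when the values of a given variable are randomly permuted, i.e. how much prediction accuracy drops when that variable's information is destroyed.
  2. IncNodePurity: the total decrease in node impurity that results from splits on that variable, averaged over all trees (for regression trees, node impurity is measured by the training RSS).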
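The idea behind the first measure can be sketched without randomForest: permute one predictor, recompute the prediction error, and report the percentage increase. As assumptions for illustration, a linear model on simulated data stands in for the forest, and a held-out half of the data stands in for the out-of-bag samples; randomForest itself does this per tree on the OOB observations and normalizes the result differently. The function name is hypothetical.

```r
# Permutation importance: % increase in MSE after shuffling one predictor
permutation_importance <- function(model, data, response, vars) {
  base_mse <- mean((predict(model, newdata = data) - data[[response]])^2)
  sapply(vars, function(v) {
    permuted <- data
    permuted[[v]] <- sample(permuted[[v]])   # break the link between v and the response
    mse <- mean((predict(model, newdata = permuted) - data[[response]])^2)
    100 * (mse - base_mse) / base_mse
  })
}

set.seed(1)
n <- 500
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
sim$y <- 3 * sim$x1 + 0.5 * sim$x2 + rnorm(n)   # x3 is pure noise
fit <- lm(y ~ ., data = sim[1:250, ])
imp <- permutation_importance(fit, sim[251:500, ], "y", c("x1", "x2", "x3"))
round(imp, 1)   # x1 largest, x2 moderate, x3 near zero
```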

Importance Plot

varImpPlot(rf.boston)

Compare Bagging and RF

set.seed(1)
ntreeset = seq(10,300,by = 10)
mse.bag = rep(0,length(ntreeset)); mse.rf = rep(0,length(ntreeset))

for(i in 1:length(ntreeset)){
  nt = ntreeset[i]
  # Bagging: all 13 predictors are candidates at every split
  bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13,ntree=nt)
  yhat.bag = predict(bag.boston,newdata=Boston[-train,])
  mse.bag[i] = mean((yhat.bag-boston.test)^2)

  # Random forest with the default mtry (floor(13/3) = 4 for regression)
  rf.boston=randomForest(medv~.,data=Boston,subset=train,ntree=nt)
  yhat.rf = predict(rf.boston,newdata=Boston[-train,])
  mse.rf[i] = mean((yhat.rf-boston.test)^2)
}

Plot the test error

plot(ntreeset,mse.bag,type="l",col=2,ylim=c(5,20))
lines(ntreeset,mse.rf,col=3)
legend("bottomright",c("Bagging","Random Forest"),col=c(2,3),lty=c(1,1))

Try three different values of mtry

ntreeset = seq(10,300,by = 10)
mse.bag = rep(0,length(ntreeset)); mse.rf1 = rep(0,length(ntreeset))
mse.rf2 = rep(0,length(ntreeset))

for(i in 1:length(ntreeset)){
  nt = ntreeset[i]
  # mtry = 13: bagging
  bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13,ntree=nt)
  yhat.bag = predict(bag.boston,newdata=Boston[-train,])
  mse.bag[i] = mean((yhat.bag-boston.test)^2)

  # mtry = 6
  rf.boston=randomForest(medv~.,data=Boston,mtry = 6, subset=train,ntree=nt)
  yhat.rf = predict(rf.boston,newdata=Boston[-train,])
  mse.rf1[i] = mean((yhat.rf-boston.test)^2)

  # mtry = 4 (the regression default for p = 13)
  rf.boston=randomForest(medv~.,data=Boston,mtry = 4, subset=train,ntree=nt)
  yhat.rf = predict(rf.boston,newdata=Boston[-train,])
  mse.rf2[i] = mean((yhat.rf-boston.test)^2)
}

Then plot the test error

plot(ntreeset,mse.bag,type="l",col=2,ylim=c(5,20))
lines(ntreeset,mse.rf1,col=3); lines(ntreeset,mse.rf2,col=4)
legend("bottomright",c("Bagging","Random Forest (m = 6)",
                       "Random Forest (m = 4)"),col=c(2,3,4), lty=c(1,1,1))