April 12, 2019
We will show how to grow a tree for classification. Regression trees are developed in a very similar way.
As before, the code is taken from the Introduction to Statistical Learning book.
In these data, Sales is a continuous variable, and so we begin by recoding it as a binary variable.
library(tree)
library(ISLR)
attach(Carseats)
High=ifelse(Sales<=8,"No","Yes")
Carseats=data.frame(Carseats,High)
tree.carseats=tree(High~.-Sales,Carseats)
summary(tree.carseats)
## 
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "Income"      "CompPrice"   "Population" 
## [6] "Advertising" "Age"         "US"         
## Number of terminal nodes:  27 
## Residual mean deviance:  0.4575 = 170.7 / 373 
## Misclassification error rate: 0.09 = 36 / 400
We see that the training error rate is 9%.
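As a quick check, a minimal sketch (using the tree.carseats fit and the High variable defined above) recomputes this training error directly:

# Sketch: recompute the training error rate reported by summary().
train.pred <- predict(tree.carseats, Carseats, type = "class")
table(train.pred, High)        # training confusion matrix
mean(train.pred != High)       # about 0.09 (36 / 400)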
For classification trees, the deviance reported in the output of summary() is given by \[ -2 \sum_{m} \sum_k n_{mk} \log(\hat{p}_{mk}) \] where \(n_{mk}\) is the number of observations in the \(m\)th terminal node that belong to the \(k\)th class.
A small deviance indicates a tree that provides a good fit to the (training) data.
The residual mean deviance reported is simply the deviance divided by \(n - |T_0|\), which in this case is 400 - 27 = 373.
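To connect the formula to the output above, here is a minimal sketch (assuming the tree.carseats fit from above, with the 27 terminal nodes reported by summary()) that recomputes the deviance from the fitted class probabilities:

# The double sum over nodes and classes equals -2 times the sum over
# observations of the log fitted probability of each observation's own class.
probs <- predict(tree.carseats, Carseats)   # n x K matrix of p-hat_mk
p.obs <- probs[cbind(1:nrow(Carseats), match(High, colnames(probs)))]
dev.hand <- -2 * sum(log(p.obs))            # should be about 170.7
dev.hand / (nrow(Carseats) - 27)            # residual mean deviance, about 0.4575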
We use the plot() function to display the tree structure, and the text() function to display the node labels.

plot(tree.carseats)
text(tree.carseats,pretty=0)
Shelving location is the most important predictor here.
plot(tree.carseats)
text(tree.carseats,pretty=1, cex = 0.65)
If we just type the name of the tree object, R prints output corresponding to each branch of the tree. R displays the split criterion (e.g. Price<92.5), the number of observations in that branch, the deviance, the overall prediction for the branch (Yes or No), and the fraction of observations in that branch that take on values of Yes and No.

tree.carseats
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 400 541.500 No ( 0.59000 0.41000 )
##    2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
##      4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
##        8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
##          16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
##          17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
##        9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
##          18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
##          19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
##      5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
##        10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
##          20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
##            40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
##              80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
##                160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
##                161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
##              81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
##            41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
##          21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
##            42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
##              84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
##              85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
##                170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
##                171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
##                  342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
##                  343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
##            43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
##              86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
##              87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
##                174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
##                  348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
##                  349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
##                175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
##        11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
##          22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
##            44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
##              88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
##              89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
##            45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
##          23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
##            46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
##            47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
##              94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
##              95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
##    3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
##      6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
##        12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
##          24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
##          25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
##        13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
##      7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
##        14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
##        15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
To estimate the test error, we split the observations into a training set and a test set, grow the tree on the training set, and evaluate its performance on the test data. The predict() function can be used for this purpose; the argument type="class" instructs R to return the actual class prediction.

set.seed(2)
train=sample(1:nrow(Carseats), 200)
Carseats.test=Carseats[-train,]
High.test=High[-train]
tree.carseats=tree(High~.-Sales,Carseats,subset=train)
tree.pred=predict(tree.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  86  27
##       Yes 30  57
This approach leads to correct predictions for (86 + 57)/200 = 71.5% of the test observations.
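The same number can be obtained directly from the predictions, without reading it off the table:

# Fraction of correct test-set predictions (matches the table above).
mean(tree.pred == High.test)   # about 0.715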
The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost complexity pruning is used in order to select a sequence of trees for consideration. Two loss functions can guide the pruning:

The 0-1 loss, or misclassification error rate: \[ \sum_{m = 1}^{|T|} \sum_{x_i \in R_m} 1(y_i \neq y_{R_m}) \]
Cross-entropy (measures the 'purity' of a leaf)
\[ - \sum_{m = 1}^{|T|} q_m \sum_{k=1}^{K} \hat{p}_{mk} \log(\hat{p}_{mk}) \]
where \(\hat{p}_{mk}\) is the proportion of class \(k\) within \(R_m\), and \(q_m\) is the proportion of samples in \(R_m\).
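As a small illustration (hypothetical class proportions, not taken from the Carseats fit), the per-node contributions of the two measures can be computed directly:

# p-hat_mk for a single two-class node with 70% / 30% membership
p <- c(0.7, 0.3)
miss <- 1 - max(p)          # 0-1 loss: fraction misclassified when predicting the majority class
xent <- -sum(p * log(p))    # cross-entropy; small when the node is nearly pure
c(misclass = miss, cross.entropy = xent)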
We use the argument FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance.

The cv.tree() function reports the size of each tree considered, the corresponding cross-validation error, and the value of the cost-complexity parameter k (which corresponds to our \(\alpha\)) in \[ \text{minimize} \sum_{m = 1}^{|T|} \sum_{x_i \in R_m} 1(y_i \neq y_{R_m}) + \alpha |T| \]

set.seed(3)
cv.carseats=cv.tree(tree.carseats,FUN=prune.misclass)
names(cv.carseats)
## [1] "size" "dev" "k" "method"
cv.carseats
## $size
## [1] 19 17 14 13  9  7  3  2  1
## 
## $dev
## [1] 55 55 53 52 50 56 69 65 80
## 
## $k
## [1]      -Inf  0.0000000  0.6666667  1.0000000  1.7500000  2.0000000
## [7]  4.2500000  5.0000000 23.0000000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
Note that dev corresponds to the cross-validation error rate in this instance. The tree with 9 terminal nodes has the lowest cross-validation error, with 50 errors.

par(mfrow=c(1,2))
plot(cv.carseats$size,cv.carseats$dev,type="b")
plot(cv.carseats$k,cv.carseats$dev,type="b")
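Rather than reading the best size off the plot, we can extract it programmatically (a small sketch using the cv.carseats object above):

# Size with the smallest cross-validation error (9 terminal nodes for this run).
best.size <- cv.carseats$size[which.min(cv.carseats$dev)]
best.size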
We now apply the prune.misclass()
function in order to prune the tree to obtain the nine-node tree.
prune.carseats=prune.misclass(tree.carseats,best=9)
plot(prune.carseats)
text(prune.carseats,pretty=0, cex = 0.7)
tree.pred=predict(prune.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  94  24
##       Yes 22  60
Now the percentage of correct predictions is (94 + 60)/200 = 0.77.
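Instead of asking for a particular tree size with best, prune.misclass() also accepts a cost-complexity penalty k (the \(\alpha\) in the criterion above). A minimal sketch:

# Prune the training tree at a given penalty; larger k gives a smaller tree.
prune.alpha <- prune.misclass(tree.carseats, k = 2)
sum(prune.alpha$frame$var == "<leaf>")   # number of terminal nodes at this penalty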
We can also prune to a different number of terminal nodes \(k\), e.g. \(k = 15\):

prune.carseats=prune.misclass(tree.carseats,best=15)
tree.pred=predict(prune.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  86  22
##       Yes 30  62
The percentage of correct predictions is now lower: (86 + 62)/200 = 0.74, compared with 0.77 for the nine-node tree chosen by cross-validation.
For comparison, we fit a logistic regression model to the same training data.

glm.fits=glm(High~.-Sales,Carseats,family=binomial,subset=train)
summary(glm.fits)
## 
## Call:
## glm(formula = High ~ . - Sales, family = binomial, data = Carseats, 
##     subset = train)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -2.51502  -0.27248  -0.05029   0.21633   2.69809  
## 
## Coefficients:
##                  Estimate Std. Error z value Pr(>|z|)    
## (Intercept)     -0.361952   3.505771  -0.103 0.917769    
## CompPrice        0.148640   0.033357   4.456 8.35e-06 ***
## Income           0.033647   0.010832   3.106 0.001895 ** 
## Advertising      0.321493   0.073583   4.369 1.25e-05 ***
## Population      -0.001044   0.001773  -0.589 0.555845    
## Price           -0.164643   0.028451  -5.787 7.17e-09 ***
## ShelveLocGood    7.802580   1.397498   5.583 2.36e-08 ***
## ShelveLocMedium  3.617614   0.975650   3.708 0.000209 ***
## Age             -0.082738   0.019931  -4.151 3.31e-05 ***
## Education       -0.177913   0.113880  -1.562 0.118223    
## UrbanYes        -0.848113   0.629494  -1.347 0.177886    
## USYes           -0.477470   0.725935  -0.658 0.510712    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 269.205  on 199  degrees of freedom
## Residual deviance:  98.008  on 188  degrees of freedom
## AIC: 122.01
## 
## Number of Fisher Scoring iterations: 7
glm.probs=predict(glm.fits,Carseats.test,type="response")
contrasts(as.factor(High))
##     Yes
## No    0
## Yes   1
glm.pred=rep("No",200)
glm.pred[glm.probs>.5]="Yes"
table(glm.pred,High.test)
##         High.test
## glm.pred  No Yes
##      No  107   9
##      Yes   9  75
mean(glm.pred==High.test)
## [1] 0.91
par(mfrow=c(1,2))
plot(Carseats.test$ShelveLoc,glm.probs); abline(h=0.5)
plot(Carseats.test$Price,glm.probs); abline(h=0.5)
Linear Regression: \[ f(X) = \beta_0 + \sum_{j=1}^{p}X_j \beta_j \]
Regression Tree Method: \[ f(X) = \sum_{m=1}^{M} c_m 1_{(X \in R_m)} \] where \(R_1, \ldots, R_M\) is a partition of the \(X\)-space.
Which model is better? It depends on the problem at hand.
If the true decision boundary is linear, classical methods work well; if it is non-linear, tree-based methods may work better.
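As a toy illustration (simulated data, not part of the Carseats analysis), a step-shaped relationship is fit poorly by a line but easily by a tree:

set.seed(10)
x <- runif(200)
y <- ifelse(x < 0.5, 1, 3) + rnorm(200, sd = 0.3)   # step function plus noise
dat <- data.frame(x = x, y = y)
fit.lm   <- lm(y ~ x, data = dat)
fit.tree <- tree(y ~ x, data = dat)
# Training MSE of each fit; the tree recovers the step almost exactly.
c(lm   = mean((y - predict(fit.lm,   newdata = dat))^2),
  tree = mean((y - predict(fit.tree, newdata = dat))^2))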
We now grow a regression tree for Sales on the training observations. Note that the formula Sales ~ . includes the derived variable High, which was constructed from Sales itself (see the sketch after the summary output below).

tree.car=tree(Sales~.,Carseats,subset=train)
summary(tree.car)
## 
## Regression tree:
## tree(formula = Sales ~ ., data = Carseats, subset = train)
## Variables actually used in tree construction:
## [1] "High"      "ShelveLoc" "CompPrice" "Price"    
## Number of terminal nodes:  10 
## Residual mean deviance:  1.798 = 341.6 / 190 
## Distribution of residuals:
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## -3.64200  -0.93070  -0.08071   0.00000   1.10800   3.09100
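Because High was constructed from Sales, letting the tree split on High amounts to using the response to predict itself. A sketch of the refit one would probably want instead (this refit is not part of the original lab, so its output is not shown):

# Drop the derived indicator so the tree must use the genuine predictors.
tree.car.noHigh <- tree(Sales ~ . - High, Carseats, subset = train)
summary(tree.car.noHigh)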
plot(tree.car, type = "unif")
text(tree.car,pretty=0,cex=0.7)
cv.car=cv.tree(tree.car)
cv.car
## $size
## [1] 10  9  7  5  4  3  2  1
## 
## $dev
## [1]  569.4074  562.4486  589.3855  589.3579  585.4891  585.6050  603.0508
## [8] 1585.5328
## 
## $k
## [1]       -Inf   17.88206   20.10870   20.40546   35.04030   35.61496
## [7]   55.20120 1004.92622
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"
plot(cv.car)
plot(cv.car$size,cv.car$dev,type='b')
Cross-validation favours the tree with 9 terminal nodes (the smallest dev above), so we prune to that size.

prune.car=prune.tree(tree.car,best=9)
plot(prune.car,type = "unif")
text(prune.car,pretty=0,cex=0.7)
yhat=predict(tree.car,newdata=Carseats[-train,])
car.test=Carseats[-train,"Sales"]
mean((yhat-car.test)^2)
## [1] 2.577326
plot(yhat,car.test); abline(0,1)
yhat=predict(prune.car,newdata=Carseats[-train,])
car.test=Carseats[-train,"Sales"]
mean((yhat-car.test)^2)
## [1] 2.517199
plot(yhat,car.test); abline(0,1)
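To see how the test error varies with the amount of pruning, here is a small sketch that sweeps over candidate tree sizes (using the tree.car fit and the car.test responses defined above):

# Test MSE of the pruned Carseats regression tree for a range of sizes.
sizes <- 2:10
test.mse <- sapply(sizes, function(b) {
  fit <- prune.tree(tree.car, best = b)   # pruned subtree with roughly b leaves
  mean((predict(fit, newdata = Carseats[-train, ]) - car.test)^2)
})
cbind(size = sizes, test.mse = round(test.mse, 3))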
Here we apply bagging and random forests to the Boston data, using the randomForest
package in R.
library(MASS); data(Boston)
str(Boston)
## 'data.frame':    506 obs. of  14 variables:
##  $ crim   : num  0.00632 0.02731 0.02729 0.03237 0.06905 ...
##  $ zn     : num  18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...
##  $ indus  : num  2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...
##  $ chas   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ nox    : num  0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...
##  $ rm     : num  6.58 6.42 7.18 7 7.15 ...
##  $ age    : num  65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...
##  $ dis    : num  4.09 4.97 4.97 6.06 6.06 ...
##  $ rad    : int  1 2 2 3 3 3 5 5 5 5 ...
##  $ tax    : num  296 242 242 222 222 222 311 311 311 311 ...
##  $ ptratio: num  15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...
##  $ black  : num  397 397 393 395 397 ...
##  $ lstat  : num  4.98 9.14 4.03 2.94 5.33 ...
##  $ medv   : num  24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...
The function randomForest() can perform both random forests and bagging.

library(randomForest)
train = sample(1:nrow(Boston), nrow(Boston)/2)
boston.test=Boston[-train,"medv"]
bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13, importance=TRUE)

The argument mtry=13 indicates that all 13 predictors should be considered for each split of the tree, which is exactly bagging. importance=TRUE requests that the importance of the predictors be assessed.

bag.boston
## 
## Call:
##  randomForest(formula = medv ~ ., data = Boston, mtry = 13, importance = TRUE, subset = train) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 13
## 
##           Mean of squared residuals: 12.35287
##                     % Var explained: 85.91
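The "Mean of squared residuals" printed above is the out-of-bag (OOB) estimate of the MSE. A small sketch (assuming the bag.boston fit above) recovers it from the OOB predictions:

# With no newdata, predict() on a randomForest object returns OOB predictions.
oob.pred <- predict(bag.boston)
mean((oob.pred - Boston[train, "medv"])^2)   # should be close to the printed value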
How well does the bagged model perform on the test set?

yhat.bag = predict(bag.boston,newdata=Boston[-train,])
mean((yhat.bag-boston.test)^2)
## [1] 13.67243
plot(yhat.bag, boston.test); abline(0,1)
For comparison, a single regression tree fit to the same training data gives a much larger test MSE.

tree.boston=tree(medv~.,Boston,subset=train)
yhat=predict(tree.boston,newdata=Boston[-train,])
boston.test=Boston[-train,"medv"]
mean((yhat-boston.test)^2)
## [1] 27.07243
plot(yhat,boston.test);abline(0,1)
By default, randomForest() grows ntree = 500 trees. We could change the number of trees grown by randomForest() using the ntree argument:

bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13, ntree=25)
yhat.bag = predict(bag.boston,newdata=Boston[-train,])
mean((yhat.bag-boston.test)^2)
## [1] 14.35049
Here we still use mtry = 13, so this is bagging with only 25 trees; the test error (14.35) is still far lower than that of a single tree (27.07), though slightly higher than with the default 500 trees (13.67).

Growing a random forest proceeds in exactly the same way, except that we use a smaller value of the mtry argument. By default, randomForest() uses \(p/3\) variables when building a random forest of regression trees, and \(\lfloor\sqrt{p}\rfloor\) variables when building a random forest of classification trees. Here we use mtry = 6.

set.seed(1)
rf.boston=randomForest(medv~.,data=Boston,subset=train,mtry=6,importance=TRUE)
yhat.rf = predict(rf.boston,newdata=Boston[-train,])
(mse.rf <- mean((yhat.rf-boston.test)^2))
## [1] 12.24709
The test set MSE is 12.25; this indicates that random forests yielded an improvement over bagging in this case.
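The value of mtry can also be tuned directly from the OOB error. A sketch using tuneRF() from the randomForest package (the starting value and step size here are arbitrary choices, not from the original lab):

set.seed(1)
# x = predictors, y = response, on the training rows; medv is column 14 of Boston.
tune.out <- tuneRF(Boston[train, -14], Boston[train, "medv"],
                   mtryStart = 4, ntreeTry = 200, stepFactor = 2, improve = 0.01)
tune.out   # OOB error for each value of mtry tried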
Using the importance() function, we can view the importance of each variable.

importance(rf.boston)
##           %IncMSE IncNodePurity
## crim     7.432964     664.71032
## zn       2.136489      97.17134
## indus    9.464938    1102.28543
## chas     2.640715      53.88275
## nox     14.970963    1472.53258
## rm      36.884024    8419.02599
## age      8.538212     463.79099
## dis     10.944432    1068.69866
## rad      5.372191     141.79123
## tax      7.540827     379.67119
## ptratio 12.197230     866.22970
## black    7.824388     318.10503
## lstat   25.795979    6773.80308
Two measures are reported: %IncMSE is based on the mean increase in prediction error on the out-of-bag samples when the given variable is permuted, and IncNodePurity is the total decrease in node impurity (training RSS for regression trees) from splits on that variable, averaged over all trees. varImpPlot() plots both measures.

varImpPlot(rf.boston)
Finally, we compare bagging and random forests (with the default mtry) over a range of values for ntree.

ntreeset = seq(10,300,by = 10)
mse.bag = rep(0,length(ntreeset)); mse.rf = rep(0,length(ntreeset))
for(i in 1:length(ntreeset)){
  nt = ntreeset[i]
  bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13,ntree=nt)
  yhat.bag = predict(bag.boston,newdata=Boston[-train,])
  mse.bag[i] = mean((yhat.bag-boston.test)^2)
  rf.boston=randomForest(medv~.,data=Boston,subset=train,ntree=nt)
  yhat.bag = predict(rf.boston,newdata=Boston[-train,])
  mse.rf[i] = mean((yhat.bag-boston.test)^2)
}
plot(ntreeset,mse.bag,type="l",col=2,ylim=c(5,20))
lines(ntreeset,mse.rf,col=3)
legend("bottomright",c("Bagging","Random Forest"),col=c(2,3),lty=c(1,1))
We repeat the comparison with two explicit choices of mtry (6 and 4) for the random forest.

ntreeset = seq(10,300,by = 10)
mse.bag = rep(0,length(ntreeset)); mse.rf1 = rep(0,length(ntreeset)); mse.rf2 = rep(0,length(ntreeset))
for(i in 1:length(ntreeset)){
  nt = ntreeset[i]
  bag.boston=randomForest(medv~.,data=Boston,subset=train,mtry=13,ntree=nt)
  yhat.bag = predict(bag.boston,newdata=Boston[-train,])
  mse.bag[i] = mean((yhat.bag-boston.test)^2)
  rf.boston=randomForest(medv~.,data=Boston,mtry = 6, subset=train,ntree=nt)
  yhat.bag = predict(rf.boston,newdata=Boston[-train,])
  mse.rf1[i] = mean((yhat.bag-boston.test)^2)
  rf.boston=randomForest(medv~.,data=Boston,mtry = 4, subset=train,ntree=nt)
  yhat.bag = predict(rf.boston,newdata=Boston[-train,])
  mse.rf2[i] = mean((yhat.bag-boston.test)^2)
}
plot(ntreeset,mse.bag,type="l",col=2,ylim=c(5,20))
lines(ntreeset,mse.rf1,col=3); lines(ntreeset,mse.rf2,col=4)
legend("bottomright",c("Bagging","Random Forest (m = 6)", "Random Forest (m = 4)"),col=c(2,3,4), lty=c(1,1,1))