Here we use the HR churn data (https://www.kaggle.com/) to demonstrate the breakDown package for randomForest models.
The data is included in the breakDown package.
set.seed(1313)
library(breakDown)
head(HR_data, 3)
#> satisfaction_level last_evaluation number_project average_montly_hours
#> 1 0.38 0.53 2 157
#> 2 0.80 0.86 5 262
#> 3 0.11 0.88 7 272
#> time_spend_company Work_accident left promotion_last_5years sales salary
#> 1 3 0 1 0 sales low
#> 2 6 0 1 0 sales medium
#> 3 4 0 1 0 sales medium
Now let's create a random forest classification model for churn, the left variable.
library("randomForest")
#> randomForest 4.7-1.2
#> Type rfNews() to see new features/changes/bug fixes.
#>
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:ggplot2':
#>
#> margin
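# factor(left) makes randomForest fit a classification forest;
# randomForest() has no family argument, so it is omitted here
model <- randomForest(factor(left) ~ ., data = HR_data, maxnodes = 5)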
But how can we understand which factors drive the prediction for a single observation?
With the breakDown package!
Explanations for the predicted probability.
library(ggplot2)
predict.function <- function(model, new_observation) predict(model, new_observation, type="prob")[,2]
predict.function(model, HR_data[11,-7])
#> [1] 0.888
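The [,2] in predict.function relies on the probability columns following the factor levels of left ("0", then "1"), so the second column is the probability of leaving. As a minimal sketch, the mock matrix below stands in for the output of predict(..., type = "prob"); the 0.112 entry is an assumption (1 - 0.888), not real output. Selecting the column by name gives the same value and may be more robust if the level order ever changes.

```r
# Mock stand-in for predict(model, HR_data[11, -7], type = "prob");
# the 0.112 entry is assumed (1 - 0.888), not copied from real output.
probs <- matrix(c(0.112, 0.888), nrow = 1,
                dimnames = list("11", c("0", "1")))

by_position <- probs[, 2]    # what predict.function does
by_name     <- probs[, "1"]  # picks the class labelled "1" explicitly

unname(by_position) == unname(by_name)
```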
explain_1 <- broken(model, HR_data[11, -7], data = HR_data[, -7],
                    predict.function = predict.function,
                    direction = "down")
explain_1
#> contribution
#> (Intercept) 0.148
#> - satisfaction_level = 0.45 0.133
#> - number_project = 2 0.201
#> - last_evaluation = 0.54 0.182
#> - average_montly_hours = 135 0.141
#> - time_spend_company = 3 0.068
#> - Work_accident = 0 0.010
#> - salary = low 0.005
#> - sales = sales 0.000
#> - promotion_last_5years = 0 0.000
#> final_prognosis 0.888
#> baseline: 0
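The contributions are additive: the intercept plus the per-variable contributions should add up to final_prognosis. A quick check, using the values copied by hand from the printed explain_1 output above (the vector name contributions is only illustrative):

```r
# Values copied from the printed explain_1 table above.
contributions <- c(
  "(Intercept)"         = 0.148,
  satisfaction_level    = 0.133,
  number_project        = 0.201,
  last_evaluation       = 0.182,
  average_montly_hours  = 0.141,
  time_spend_company    = 0.068,
  Work_accident         = 0.010,
  salary                = 0.005,
  sales                 = 0.000,
  promotion_last_5years = 0.000
)

# The total matches final_prognosis, the prediction for observation 11.
sum(contributions)
#> [1] 0.888
```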
plot(explain_1) + ggtitle("breakDown plot (direction=down) for randomForest model")
explain_2 <- broken(model, HR_data[11, -7], data = HR_data[, -7],
                    predict.function = predict.function,
                    direction = "up")
explain_2
#> contribution
#> (Intercept) 0.148
#> + satisfaction_level = 0.45 0.133
#> + number_project = 2 0.201
#> + last_evaluation = 0.54 0.182
#> + average_montly_hours = 135 0.141
#> + time_spend_company = 3 0.068
#> + Work_accident = 0 0.010
#> + salary = low 0.005
#> + promotion_last_5years = 0 0.000
#> + sales = sales 0.000
#> final_prognosis 0.888
#> baseline: 0
plot(explain_2) + ggtitle("breakDown plot (direction=up) for randomForest model")
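To give some intuition for what a model-agnostic, step-up decomposition does, here is a simplified sketch in base R. It is an illustration of the general idea only and is not claimed to be the implementation used by broken(). The toy data, the hand-written toy_predict scoring function, and the step_up helper are all hypothetical. Starting from the mean prediction over the data, the sketch repeatedly fixes one variable to the value from the observation of interest, choosing at each step the variable that shifts the mean prediction the most.

```r
# Toy data and a toy scoring function; both are illustrative assumptions.
set.seed(1)
toy <- data.frame(x1 = runif(200), x2 = runif(200), x3 = runif(200))
toy_predict <- function(newdata) {
  # Stands in for a fitted model that returns a probability.
  plogis(-1 + 3 * newdata$x1 - 2 * newdata$x2)
}

# Hypothetical helper: greedy "step up" attribution.
step_up <- function(predict_fun, data, new_obs) {
  previous <- mean(predict_fun(data))        # mean prediction = intercept
  contributions <- c("(Intercept)" = previous)
  remaining <- names(data)
  current <- data
  while (length(remaining) > 0) {
    # Mean prediction after fixing each remaining variable in turn.
    candidate_means <- sapply(remaining, function(v) {
      tmp <- current
      tmp[[v]] <- new_obs[[v]]
      mean(predict_fun(tmp))
    })
    # Keep the variable that moves the mean prediction the most.
    best <- remaining[which.max(abs(candidate_means - previous))]
    current[[best]] <- new_obs[[best]]
    contributions[best] <- candidate_means[[best]] - previous
    previous <- candidate_means[[best]]
    remaining <- setdiff(remaining, best)
  }
  contributions
}

new_obs <- data.frame(x1 = 0.9, x2 = 0.1, x3 = 0.5)
contrib <- step_up(toy_predict, toy, new_obs)
contrib

# The contributions telescope, so their sum is the prediction for new_obs.
all.equal(sum(contrib), toy_predict(new_obs))
```

Because x3 does not enter toy_predict, its contribution is zero, much like the zero contributions of sales and promotion_last_5years in the output above.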