14  SHAP Results

Objectives

This chapter presents and visualizes the SHAP values estimated in Chapter 13.

library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ ggplot2   3.5.1     ✔ tibble    3.2.1
✔ lubridate 1.9.3     ✔ tidyr     1.3.1
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
Loading required package: lattice

Attaching package: 'caret'

The following object is masked from 'package:purrr':

    lift
library(ranger)
library(treeshap)

Let us define a theme for graphs.

Definition of theme_paper()
#' Theme for ggplot2
#'
#' @param ... arguments passed to the theme function
#' @export
#' @importFrom ggplot2 element_rect element_text element_blank element_line unit
#'   rel
theme_paper <- function(...) {
  ggthemes::theme_base() +
    theme(
      legend.background = element_rect(
        fill = "transparent", linetype = "solid", colour = "black"),
      legend.position = "bottom",
      legend.direction = "horizontal",
      legend.box = "horizontal",
      legend.key = element_blank(),
      # Pass any additional theme elements supplied by the user
      ...
    )
}
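
The theme can then be appended to any ggplot object; for instance (a usage sketch on a built-in dataset, not used later in the chapter):

# Hypothetical example: apply the paper theme to a simple scatter plot
ggplot(mtcars, aes(x = wt, y = mpg)) +
  geom_point() +
  theme_paper()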

14.1 Load data and estimated models

14.1.1 Data

Let us load the train and test data:

load("../data/out/df_train_mhi3.rda")
load("../data/out/df_test_mhi3.rda")

For random forests, we need to convert categorical variables to numerical variables:

df_train_num <-
  df_train |>
  mutate(across(where(is.factor), as.numeric))

df_test_num <- 
  df_test |>
  mutate(across(where(is.factor), as.numeric))

We keep track of the levels of factor data:

corresp_factors <- df_train |> 
  select_if(is.factor) |> 
  map(levels)

14.1.2 Classifiers

The estimated models (see the chapter on model estimation):

# The estimated random forests (and the grid)
load("../data/out/estim/v3/grid_search_rf_mhi-3.rda")
# The estimated xgb (and the grid)
load("../data/out/estim/v3/grid_search_xgb_mhi-3.rda")

Let us retrieve the best classifiers obtained for both the random forest and XGBoost, beginning with the random forest:

ind_remove <- which(
  colnames(df_train) %in% c("status", "id", "PERSONNE_statut")
)
formula <- as.formula(
  paste("status ~", paste(colnames(df_train[-ind_remove]), collapse = "+"))
)

final_model_rf <- ranger(
  formula,
  data = df_train_num, 
  mtry = grid_search_rf$bestTune$mtry,
  splitrule = "gini",
  min.node.size = grid_search_rf$bestTune$min.node.size,
  classification = T
)

And for the extreme gradient boosting algorithm:

final_model_xgb <- grid_search_xgb$finalModel

14.1.3 SHAP Values

We load the estimated SHAP values (see Chapter 13):

load("../data/out/treeSHAP/v3/treeshap_rf-mhi3.rda")
load("../data/out/treeSHAP/v3/treeshap_xgb-mhi3.rda")

The explainer used only the values from the train set to estimate the SHAP values:

reference_data_rf <- df_train_num
reference_data_xgb <- df_train

And the SHAP values were estimated on both the train set and the test set:

df_all_rf <- bind_rows(df_train_num, df_test_num)
df_all_xgb <- bind_rows(df_train, df_test)

14.2 Predicted Values (classifier)

Let us get the predicted values for both classifiers.

14.2.1 Random Forest

pred_val_ref_rf <- predict(final_model_rf, reference_data_rf)$predictions
head(pred_val_ref_rf)
[1] 1 1 1 1 1 1

The issue is that this gives the majority vote, not the associated scores. Let us compute the scores by aggregating the votes across the individual trees.

# Get the scores instead:
predict_score_ranger <- function(object, data) {
  pred_all <- predict(object, data, predict.all = TRUE)
  pred_all_df <- as.data.frame(pred_all$predictions)
  ntrees <- ncol(pred_all_df)
  pred_all_df$votes <- rowSums( pred_all_df == 1 ) / ntrees
  pred_all_df$votes
}
pred_val_rf <- predict_score_ranger(final_model_rf, reference_data_rf)
head(pred_val_rf)
[1] 0.882 0.870 0.956 0.970 0.952 0.928

14.2.2 XGBoost

And now, for XGBoost:

pred_val_xgb <- predict(
  grid_search_xgb, newdata = reference_data_xgb, type = "prob"
)
pred_val_xgb <- pred_val_xgb[, "Not_D_and_inf_Q1"]
head(pred_val_xgb)
[1] 0.4375817 0.4256971 0.3940019 0.4918582 0.4555430 0.4870948

Let us get the predicted values on both the train set and the test set as well:

pred_val_all_xgb <- predict(
  grid_search_xgb, newdata = df_all_xgb, type = "prob"
)
pred_val_all_xgb <- pred_val_all_xgb[, "Not_D_and_inf_Q1"]

14.3 Understanding the Output

To understand how to read the estimated SHAP values, let us have a look at the first individual.

id_indiv <- 1
df_all_rf[id_indiv, ]
  PERSONNE_pb_asthm PERSONNE_pb_bronchit PERSONNE_pb_infarctus
1                 1                    1                     1
  PERSONNE_pb_coronair PERSONNE_pb_hypertens PERSONNE_pb_avc
1                    1                     1               1
  PERSONNE_pb_arthros PERSONNE_pb_lombalgi PERSONNE_pb_cervical
1                   1                    1                    1
  PERSONNE_pb_diabet PERSONNE_pb_allergi PERSONNE_pb_cirrhos
1                  1                   1                   1
  PERSONNE_pb_urinair PERSONNE_age PERSONNE_sexe PERSONNE_couple
1                   1           32             2               3
  PERSONNE_statut PERSONNE_ss PERSONNE_regime PERSONNE_rap_pcs8 PERSONNE_ald
1               2           1               1                 4            2
  SOINS_ald_am MENAGE_revucinsee MENAGE_tu MENAGE_nbpers ensol_2011
1            2            1561.9         2             4   1812.192
  MUTUELLE_assu MUTUELLE_typcc SOINS_remomn SOINS_remspe SOINS_rempha
1             2              5       178.28       114.08         9.38
  SOINS_remkin SOINS_reminf SOINS_remden SOINS_remmat SOINS_remtra SOINS_remopt
1       243.59            0            0            0            0            0
  SOINS_rempro SOINS_remurg SOINS_tmomn SOINS_tmspe SOINS_tmpha SOINS_tmkin
1         44.2            0       81.12        50.6        17.6      171.34
  SOINS_tminf SOINS_tmden SOINS_tmmat SOINS_tmtra SOINS_tmopt SOINS_tmpro
1           0           0           0           0           0       29.46
  SOINS_tmurg SOINS_dpaomn SOINS_dpaspe SOINS_dpapha SOINS_dpakin SOINS_dpainf
1           0         11.2            0            0            0            0
  SOINS_dpaden SOINS_dpamat SOINS_dpatra SOINS_dpaopt SOINS_dpapro SOINS_dpaurg
1            0            0            0            0        87.14            0
  SOINS_pf_fromn SOINS_pf_frspe SOINS_pf_frpha SOINS_pf_frkin SOINS_pf_frinf
1             11              4              4           13.5              0
  SOINS_pf_frden SOINS_pf_frtra SOINS_pf_frurg SOINS_seac_omn SOINS_seac_spe
1              0              0              0              9              4
  OPINION1_renonc_cons OPINION1_renonc_dent OPINION1_renonc_fin
1                    2                    1                   2
  OPINION1_renonc_loin OPINION1_renonc_long QST_ct_depech QST_ct_liberte
1                    2                    2             2              4
  QST_ct_apprend QST_ct_aidecol QST_ct_travnuit QST_ct_repet QST_ct_lourd
1              2              2               4            4            4
  QST_ct_posture QST_ct_produit QES_association QES_tpsami QES_tpsasso
1              3              3               1          3           2
  QES_tpscolleg QES_tpsfamil QES_mere_etude QES_pere_etude id status
1             4            3              3              3  3      2

The estimated SHAP values for this individual (with the random forest):

treeshap_rf$shaps[id_indiv, ]
  PERSONNE_pb_asthm PERSONNE_pb_bronchit PERSONNE_pb_infarctus
1       0.002412881          0.004311861          0.0002731262
  PERSONNE_pb_coronair PERSONNE_pb_hypertens PERSONNE_pb_avc
1         0.0003429562          0.0005622856     8.38537e-05
  PERSONNE_pb_arthros PERSONNE_pb_lombalgi PERSONNE_pb_cervical
1         0.001050751          0.009802696          0.002056049
  PERSONNE_pb_diabet PERSONNE_pb_allergi PERSONNE_pb_cirrhos
1        0.003028471         0.000660099        3.355765e-05
  PERSONNE_pb_urinair PERSONNE_age PERSONNE_sexe PERSONNE_couple  PERSONNE_ss
1         0.002677253  0.002259672   0.004020243     0.004624774 0.0009792908
  PERSONNE_regime PERSONNE_rap_pcs8 PERSONNE_ald SOINS_ald_am MENAGE_revucinsee
1    5.635322e-05       0.002088681  0.003886941    0.0010831         0.0208107
    MENAGE_tu MENAGE_nbpers  ensol_2011 MUTUELLE_assu MUTUELLE_typcc
1 0.006169268   0.002302185 0.004195434   0.004940774    0.003754732
  SOINS_remomn SOINS_remspe SOINS_rempha SOINS_remkin SOINS_reminf SOINS_remden
1  -0.02211637   0.00461902   0.01143639    -0.016758  0.006993455  0.003180838
  SOINS_remmat SOINS_remtra SOINS_remopt SOINS_rempro SOINS_remurg SOINS_tmomn
1  0.003657292  0.004075546 0.0008374523 -0.004895966  0.002066372 -0.01051562
  SOINS_tmspe SOINS_tmpha SOINS_tmkin  SOINS_tminf SOINS_tmden SOINS_tmmat
1 0.003176497 0.007055068 -0.01337438 0.0002958302 0.002348536 0.001884856
  SOINS_tmtra  SOINS_tmopt  SOINS_tmpro SOINS_tmurg SOINS_dpaomn SOINS_dpaspe
1 0.001135434 0.0008754183 -0.004519039 0.001186415  0.001649548 0.0005853113
  SOINS_dpapha SOINS_dpakin SOINS_dpainf SOINS_dpaden SOINS_dpamat SOINS_dpatra
1 0.0001917441 0.0006158043 4.022302e-05  0.001534009 0.0006455588 9.426563e-06
  SOINS_dpaopt SOINS_dpapro SOINS_dpaurg SOINS_pf_fromn SOINS_pf_frspe
1  0.001744615  -0.02345778 5.796717e-06    -0.02156193    0.003214504
  SOINS_pf_frpha SOINS_pf_frkin SOINS_pf_frinf SOINS_pf_frden SOINS_pf_frtra
1    0.008500961    -0.00991723    0.001394056   3.662451e-05   0.0008812861
  SOINS_pf_frurg SOINS_seac_omn SOINS_seac_spe OPINION1_renonc_cons
1   0.0001329767   1.514114e-05    0.002764504          0.001693441
  OPINION1_renonc_dent OPINION1_renonc_fin OPINION1_renonc_loin
1          -0.01075482         0.001033907          0.001490988
  OPINION1_renonc_long QST_ct_depech QST_ct_liberte QST_ct_apprend
1          0.005928268  -0.008285421    0.006188549    0.007860287
  QST_ct_aidecol QST_ct_travnuit QST_ct_repet QST_ct_lourd QST_ct_posture
1    0.003860199    0.0009169207 0.0005924579 0.0004268576    0.002936785
  QST_ct_produit QES_association    QES_tpsami QES_tpsasso QES_tpscolleg
1     0.00030088     0.001929519 -0.0002505383 0.003191918   0.001927209
  QES_tpsfamil QES_mere_etude QES_pere_etude
1  0.004670225    0.001868683    0.003063726

Each column gives the SHAP value of a variable, for this individual. Each row of treeshap_rf$shaps gives the SHAP values for a single individual.
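
As a quick sanity check (a sketch, output not shown), the dimensions of the SHAP matrix can be compared with those of the data used to compute the explanations:

dim(treeshap_rf$shaps)  # one row per individual of df_all_rf, one column per predictor
dim(df_all_rf)          # also contains columns (e.g., id, status) that are not predictors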

In the reference data (we used the train set), the average predicted score is:

(mean_pred_rf_ref <- mean(pred_val_rf)) # random forest
[1] 0.8112718
(mean_pred_xgb_ref <- mean(pred_val_xgb)) # XGBoost
[1] 0.4644458

For the first individual, the predicted score, for the random forest, is:

pred_indiv_rf <- 
  predict_score_ranger(final_model_rf, df_all_rf[id_indiv, ])
pred_indiv_rf
[1] 0.882

And with XGB:

pred_indiv_xgb <- predict(grid_search_xgb, df_all_xgb[id_indiv, ], type = "prob")
pred_indiv_xgb <- pred_indiv_xgb[, "Not_D_and_inf_Q1"]
pred_indiv_xgb
[1] 0.4375817

For the record, the actual class is:

df_all_rf[id_indiv, "status"]
[1] 2
df_all_xgb[id_indiv, "status"]
[1] Not_D_and_sup_Q1
Levels: Not_D_and_inf_Q1 Not_D_and_sup_Q1

The sum of the SHAP values corresponds to the deviation of the individual prediction from the average of the scores estimated on the reference data. Note that this additivity holds on the scale of the model output decomposed by the explainer: it matches the probability scale for the random forest below, whereas for XGBoost the SHAP values are computed on the model's raw output (not the probability scale), so their sum only approximates the difference in predicted probabilities.
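
Formally, this is the local accuracy property of SHAP values: for individual $i$, with SHAP values $\phi_j^{(i)}$ over the $p$ predictors,

$$\sum_{j=1}^{p} \phi_j^{(i)} = \hat{f}\left(\boldsymbol{x}^{(i)}\right) - \frac{1}{n_{\text{ref}}} \sum_{k=1}^{n_{\text{ref}}} \hat{f}\left(\boldsymbol{x}^{(k)}\right),$$

where $\hat{f}$ denotes the model output used by the explainer and the second term is its average over the $n_{\text{ref}}$ observations of the reference (train) set.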

shap_indiv_rf <- treeshap_rf$shaps[id_indiv, ] # random forest
shap_indiv_xgb <- treeshap_xgb$shaps[id_indiv, ] # XGBoost

For the random forest:

c(
  "shap" = sum(shap_indiv_rf), 
  "pred" = pred_indiv_rf, 
  "mean_pred" = mean_pred_rf_ref,
  "diff" = pred_indiv_rf - mean_pred_rf_ref
)
      shap       pred  mean_pred       diff 
0.07072822 0.88200000 0.81127178 0.07072822 

and for XGBoost:

c(
  "shap" = sum(shap_indiv_xgb), 
  "pred" = pred_indiv_xgb, 
  "mean_pred" = mean_pred_xgb_ref,
  "diff" = pred_indiv_xgb - mean_pred_xgb_ref
)
       shap        pred   mean_pred        diff 
-0.26443689  0.43758169  0.46444577 -0.02686408 
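
As a sanity check (a sketch), we can verify that this additivity holds, up to numerical tolerance, for all individuals in the case of the random forest. Note that pred_val_all_rf is a hypothetical object computed here only for this check and not reused afterwards:

# Scores of the random forest on both the train and test sets
pred_val_all_rf <- predict_score_ranger(final_model_rf, df_all_rf)
# Should return TRUE: sum of SHAP values = score - average score on the reference data
all.equal(
  unname(rowSums(treeshap_rf$shaps)),
  pred_val_all_rf - mean_pred_rf_ref
)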

Let us now visualize the SHAP values for a specific individual.

load("../data/out/variable_names.rda")

We build a table with variable names.

variable_names_categ <- 
  enframe(corresp_factors) |> 
  left_join(variable_names, by = c("name" = "variable")) |> 
  unnest(value) |> 
  mutate(
    variable = str_c(name, value),
    label_categ = str_c(label, " = ", value)) |> 
  select(variable, label_categ, variable_raw = name, variable_label_val = value)
variable_names_categ
# A tibble: 176 × 4
   variable                 label_categ          variable_raw variable_label_val
   <chr>                    <chr>                <chr>        <chr>             
 1 PERSONNE_pb_asthmNo      Asthma = No          PERSONNE_pb… No                
 2 PERSONNE_pb_asthmYes     Asthma = Yes         PERSONNE_pb… Yes               
 3 PERSONNE_pb_bronchitNo   Bronchitis = No      PERSONNE_pb… No                
 4 PERSONNE_pb_bronchitYes  Bronchitis = Yes     PERSONNE_pb… Yes               
 5 PERSONNE_pb_infarctusNo  Heart Attack = No    PERSONNE_pb… No                
 6 PERSONNE_pb_infarctusYes Heart Attack = Yes   PERSONNE_pb… Yes               
 7 PERSONNE_pb_coronairNo   Artery Disease = No  PERSONNE_pb… No                
 8 PERSONNE_pb_coronairYes  Artery Disease = Yes PERSONNE_pb… Yes               
 9 PERSONNE_pb_hypertensNo  Hypertension = No    PERSONNE_pb… No                
10 PERSONNE_pb_hypertensYes Hypertension = Yes   PERSONNE_pb… Yes               
# ℹ 166 more rows

We define a function, plot_individual_variable_effect(), to understand why, according to the estimated SHAP values, the predicted value for a single individual deviates from the average prediction in the reference sample.

#' Plot the decomposition of the deviation from the average prediction
#' for an individual, using SHAP values
#' 
#' @param i row number to designate individual
#' @param treeshap_res result obtained with treeshap
#' @param predicted_val_i predicted value for the individual
#' @param mean_pred_ref average predicted value in the reference data
#' @param n top n variables (default to 10)
#' @param min_max if not `NULL`, limits for the x axis (Shap values), in a 
#'  vector of length 2 containing c(min, max)
plot_individual_variable_effect <- function(i, 
                                            treeshap_res, 
                                            predicted_val_i,
                                            mean_pred_ref,
                                            n = 10,
                                            min_max = NULL) {
  
  df_plot <- 
    treeshap_res$shaps |> 
    dplyr::slice(i) |> 
    unlist() |> 
    enframe(name = "variable", value = "shap") |> 
    left_join(variable_names, by = "variable") |> 
    left_join(variable_names_categ, by = "variable") |> 
    mutate(label = ifelse(is.na(label), label_categ, label)) |> 
    arrange(desc(abs(shap))) |> 
    mutate(in_top_n = row_number() <= !!n) |> 
    mutate(
      label_2 = ifelse(in_top_n, label, "Other Variables")
    ) |> 
    group_by(label_2) |> 
    summarise(
      shap = sum(shap),
      .groups = "drop"
    ) |> 
    mutate(
      label_2 = fct_reorder(label_2, abs(shap)),
      label_2 = fct_relevel(label_2, "Other Variables")
    ) |> 
    mutate(
      mean_pred_ref = !!mean_pred_ref,
      shap_rounded = str_c("  ", round(shap, 3)),
      sign = sign(shap),
      sign = factor(sign, levels = c(-1, 1), labels = c("-", "+"))
    )
  
  p <- ggplot(
    data = df_plot, 
    mapping = aes(
      y =  mean_pred_ref + pmax(shap, 0),
      x = label_2,
      ymin = mean_pred_ref,
      ymax = mean_pred_ref + shap,
      colour = sign)
  ) +
    geom_linerange(linewidth = 8) +
    geom_hline(yintercept = mean_pred_ref) +
    geom_text(aes(label = shap_rounded, hjust = 0)) +
    coord_flip() +
    scale_colour_manual(
      "Variable impact", 
      values = c("-" = "#EE324E", "+" = "#009D57"), 
      labels = c("-" = "Negative", "+" = "Positive")
    ) +
    labs(
      x = NULL, 
      y = "Shap values", 
      title=str_c("Predicted value: ", round(predicted_val_i, 3)),
      subtitle = str_c("Base value: ", round(mean_pred_ref, 3))
    ) +
    theme(
      panel.grid.major.y = element_blank(), axis.ticks.y = element_blank(),
      legend.position = "bottom",
      plot.title.position = "plot",
      plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold")
    )
  
  if (!is.null(min_max)) {
    p <- p + scale_y_continuous(limits = min_max)
  }
  p
}

Let us look at the first individual. For the random forest, we flip the sign of the SHAP values (and use one minus the score) so that the figures refer to the probability of being classified as an imaginary healthy patient:

i <- 1
treeshap_rf_opposite <- treeshap_rf
treeshap_rf_opposite$shaps <- -treeshap_rf_opposite$shaps
plot_individual_variable_effect(
  i = i, 
  treeshap_res = treeshap_rf_opposite, 
  predicted_val_i = 1-pred_val_rf[i], 
  mean_pred_ref = 1-mean_pred_rf_ref, n = 20
  # min_max = c(0.75, 0.95)
)
Figure 14.1: SHAP values for the first individual, when the score is estimated with a Random Forest. Positive class: Imaginary healthy patient.
plot_individual_variable_effect(
  i = i, 
  treeshap_res = treeshap_xgb, 
  predicted_val_i = pred_val_xgb[i], 
  mean_pred_ref = mean_pred_xgb_ref, n = 20
  # min_max = c(0.1, 1)
)
Figure 14.2: SHAP values for the first individual, when the score is estimated with XGBoost. Positive class: Imaginary healthy patient.

Now, let us pick two individuals: one with a low predicted score, and another one with a high predicted score. We define two functions: get_top_n_indiv(), to get the top n variables explaining the deviation of the score of an individual from the average predicted score in the reference set, and get_table_charact(), to extract the characteristics of specific individuals for the union of their top variables.

#' Extracts the top n variables according to absolute SHAP val
#' 
#' @param i row number of individual
#' @param treeshap_res result obtained with treeshap
#' @param n top n variables (default to 10)
get_top_n_indiv <- function(i, treeshap_res, n = 10) {
  treeshap_res$shaps |> 
    dplyr::slice(i) |> 
    unlist() |> 
    enframe(name = "variable", value = "shap") |> 
    arrange(desc(abs(shap))) |> 
    filter(row_number() <= !!n) |> 
    pull("variable")
}



#' Get the characteristics of the union of top n variables among individuals
#' (using absolute SHAP values)
#' For XGBoost only (uses `reference_data_xgb` and `df_all_xgb` from the
#' calling environment)
#' 
#' @param i row numbers of individuals
#' @param treeshap_res result obtained with treeshap
#' @param n top n variables (default to 10)
get_table_charact <- function(i, 
                              treeshap_res, 
                              n = 10) {
  # Top n variables for individuals
  top_n_indiv <- map(
    i, 
    ~get_top_n_indiv(i = .x, treeshap_res = treeshap_res, n = n)
  )
  top_n_indiv_variables <- list_c(top_n_indiv) |> unique()
  
  # Average values for those variables in the reference sample
  values_reference <- 
    model.matrix(formula, reference_data_xgb) |> 
    as_tibble() |> 
    select(!!top_n_indiv_variables) |> 
    summarise(across(everything(), ~mean(.x))) |> 
    pivot_longer(cols = everything(), names_to = "variable", values_to = "val_ref")
  
  # Characteristics for these individuals
  top_n_indiv_characteristics <- map2(
    .x = i, .y = top_n_indiv,
    ~tibble(variable = top_n_indiv_variables, id_indiv = .x) |> 
      left_join(
        model.matrix(formula, df_all_xgb) |>
          as_tibble() |> 
          dplyr::slice(.x) |> 
          select(!!top_n_indiv_variables) |> 
          pivot_longer(
            cols = everything(), 
            names_to = "variable", values_to = "value_indiv"
          )
      ) |> 
      mutate(in_top_n = ifelse(variable %in% .y, yes = "$\\checkmark$", no = ""))
  ) |> 
    list_rbind()
  top_n_indiv_characteristics |> 
    left_join(variable_names, by = "variable") |> 
    left_join(variable_names_categ, by = "variable") |> 
    left_join(values_reference) |> 
    mutate(label = ifelse(is.na(label), label_categ, label)) |> 
    select(label, id_indiv, value_indiv, in_top_n, val_ref) |> 
    pivot_wider(names_from = id_indiv, values_from = c(value_indiv, in_top_n))
}

Let us identify two individuals:

i_low <- 192
i_high <- 55
round(c(pred_val_xgb[i_low], pred_val_xgb[i_high]), 3)
[1] 0.166 0.818
plot_indiv_low_pred <- 
  plot_individual_variable_effect(
  i = i_low, 
  treeshap_res = treeshap_xgb, 
  predicted_val_i = pred_val_xgb[i_low], 
  mean_pred_ref = mean_pred_xgb_ref, n = 10
  # min_max = c(0, .6)
)
plot_indiv_low_pred
Figure 14.3: SHAP values for an individual with a low predicted score, when the latter is estimated with XGBoost. Positive class: Imaginary healthy patient.
plot_indiv_high_pred <- 
  plot_individual_variable_effect(
  i = i_high, 
  treeshap_res = treeshap_xgb, 
  predicted_val_i = pred_val_xgb[i_high], 
  mean_pred_ref = mean_pred_xgb_ref, n = 10
  # min_max = c(0.4, .95)
)
plot_indiv_high_pred
Figure 14.4: SHAP values for an individual with a high predicted score, when the latter is estimated with XGBoost. Positive class: Imaginary healthy patient.

We retrieve the characteristics of the identified most important variables for both individuals, as well as the average values in the reference sample (train set):

tb_two_indivs_example <- get_table_charact(
    i = c(i_low, i_high), 
    treeshap_res = treeshap_xgb, 
    n = 10
)
Joining with `by = join_by(variable)`
Joining with `by = join_by(variable)`
Joining with `by = join_by(variable)`

And then, we display them in a nice table:

tb_two_indivs_example |> 
  select(label, value_indiv_192, in_top_n_192, 
         value_indiv_55, in_top_n_55, val_ref) |> 
  mutate(val_ref = round(val_ref, 2)) |> 
  kableExtra::kable(booktabs = TRUE) |>
  kableExtra::kable_classic(full_width = F) |> 
  kableExtra::kable_styling() |>
  unclass() |> cat()
Table 14.1: Characteristics for the Most Important Variables for Two Individuals According to their SHAP Values.
label value_indiv_192 in_top_n_192 value_indiv_55 in_top_n_55 val_ref
Very Little Freedom to Do Job = Never 1.00 \(\checkmark\) 0.00 0.19
Have to Hurry to Do Job = Sometimes 1.00 \(\checkmark\) 0.00 0.18
Reimbursement General Practitioner 30.20 \(\checkmark\) 332.40 \(\checkmark\) 88.84
Net Income per Cons. Unit 1700.00 \(\checkmark\) 520.00 \(\checkmark\) 1604.76
Reimbursement Pharmacy 108.38 \(\checkmark\) 1533.62 \(\checkmark\) 372.20
Gender = Male 1.00 \(\checkmark\) 1.00 \(\checkmark\) 0.47
Age 28.00 \(\checkmark\) 74.00 48.81
Low Back Pain = Yes 0.00 \(\checkmark\) 1.00 \(\checkmark\) 0.21
No. Medical Sessions General Pract. 2.00 \(\checkmark\) 11.00 \(\checkmark\) 4.68
Waiver Appointment Delay Too Long = No 1.00 \(\checkmark\) 1.00 0.68
Couple = Yes 0.00 0.00 \(\checkmark\) 0.65
Frequency Meeting with Family Living Outside Household = Less than once a month 0.00 1.00 \(\checkmark\) 0.15
Long-term condition (Self-declared) = No 1.00 0.00 \(\checkmark\) 0.80
Insurance = No answer 0.00 1.00 \(\checkmark\) 0.13

14.4 Variable importance

We compute the average absolute SHAP value for each variable and order the results by descending values. To focus on the probability of being classified as an imaginary healthy patient, we use the sign-flipped random forest SHAP values (treeshap_rf_opposite) defined earlier.
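
In other words, the importance of variable $j$ is measured as

$$\text{imp}(j) = \frac{1}{n} \sum_{i=1}^{n} \left| \phi_j^{(i)} \right|,$$

where the average is taken over the $n$ individuals of the train and test sets.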

# Mean(| SHAP |)
var_imp_rf_mean_abs_shap <- treeshap_rf_opposite$shaps |> 
  summarise(across(everything(), ~mean(abs(.x))))

# Mean(SHAP)
var_imp_rf_mean_shap <- treeshap_rf_opposite$shaps |> 
  summarise(across(everything(), ~mean(.x)))

We order the variables by descending values of the mean absolute SHAP values.

order_variables_shap_rf <- 
  sort(unlist(var_imp_rf_mean_abs_shap), decreasing = TRUE) |> 
  enframe() |> 
  left_join(
    sort(unlist(var_imp_rf_mean_shap), decreasing = TRUE) |> 
      enframe(value = "mean"),
    by = "name"
  )

The top 10 variables:

top_n_variables_rf <- order_variables_shap_rf$name[1:10]
top_n_variables_rf
 [1] "MENAGE_revucinsee"    "SOINS_remomn"         "SOINS_rempha"        
 [4] "SOINS_seac_omn"       "SOINS_pf_frpha"       "PERSONNE_pb_lombalgi"
 [7] "QES_tpsfamil"         "QST_ct_depech"        "SOINS_remkin"        
[10] "PERSONNE_pb_bronchit"
We proceed in the same way for XGBoost:

# Mean(| SHAP |)
var_imp_xgb_mean_abs_shap <- treeshap_xgb$shaps |> 
  summarise(across(everything(), ~mean(abs(.x))))

# Mean(SHAP)
var_imp_xgb_mean_shap <- treeshap_xgb$shaps |> 
  summarise(across(everything(), ~mean(.x)))

We order the variables by descending values of the mean absolute SHAP values.

order_variables_shap_xgb <- 
  sort(unlist(var_imp_xgb_mean_abs_shap), decreasing = TRUE) |> 
  enframe() |> 
  left_join(
    sort(unlist(var_imp_xgb_mean_shap), decreasing = TRUE) |> 
      enframe(value = "mean"),
    by = "name"
  )
save(order_variables_shap_xgb, file = "../data/out/order_variables_shap_xgb_mhi3.rda")

The top 10 variables:

top_n_variables_xgb <- order_variables_shap_xgb$name[1:10]
top_n_variables_xgb
 [1] "MENAGE_revucinsee"       "SOINS_remomn"           
 [3] "SOINS_rempha"            "QST_ct_liberteNever"    
 [5] "PERSONNE_pb_lombalgiYes" "PERSONNE_sexeMale"      
 [7] "PERSONNE_age"            "QST_ct_depechSometimes" 
 [9] "OPINION1_renonc_longNo"  "SOINS_seac_omn"         

Let us plot the top 10 variables by importance according to the SHAP values.

p_variable_importance_rf <- 
  ggplot(
    data = order_variables_shap_rf |> 
      left_join(variable_names, by = c("name" = "variable")) |> 
      mutate(
        label = ifelse(name %in% !!top_n_variables_rf, 
                       label, "Other Variables")
      ) |> 
      group_by(label) |> 
      summarise(
        value = sum(value),
        .groups = "drop"
      ) |> 
      mutate(
        label = fct_reorder(label, value),
        label = fct_relevel(label, "Other Variables")
      ),
    mapping = aes(
      x = label, 
      y = value
    )
  ) +
  geom_bar(stat = "identity") +
  coord_flip() +
  labs(y = "Mean(|SHAP Value|)", x = NULL)

p_variable_importance_rf
Figure 14.5: Variable importance according to SHAP values (with the random forest classifier)
p_variable_importance_xgb <- 
  ggplot(
    data = order_variables_shap_xgb |> 
      left_join(variable_names, by = c("name" = "variable")) |> 
      left_join(variable_names_categ, by = c("name" = "variable")) |> 
      mutate(
        label = ifelse(is.na(label), label_categ, label),
        label = ifelse(name %in% !!top_n_variables_xgb, 
                       label, "Other Variables")
      ) |> 
      group_by(label) |> 
      summarise(
        value = sum(value),
        .groups = "drop"
      ) |> 
      mutate(
        label = fct_reorder(label, value),
        label = fct_relevel(label, "Other Variables")
      ),
    mapping = aes(
      x = label, 
      y = value
    )
  ) +
  geom_bar(stat = "identity") +
  coord_flip() +
  labs(y = "Mean(|SHAP Value|)", x = NULL)

p_variable_importance_xgb
Figure 14.6: Variable importance according to SHAP values (with the XGBoost classifier)

The complete ranking, for both models:

order_variables_shap_xgb |> 
  mutate(rank_xgb = row_number()) |> 
  rename(
    mean_abs_shap_xgb = value,
    mean_shap_xgb = mean
  ) |> 
  full_join(
    order_variables_shap_rf |> 
      mutate(rank_rf = row_number()) |> 
      rename(
        mean_abs_shap_rf = value,
        mean_shap_rf = mean
      )
  ) |> 
  DT::datatable()
Joining with `by = join_by(name)`

In Python, the summary_plot() function from the shap library offers a quick view of the effect of each variable on the prediction. We can reproduce such a plot. For quantitative variables, this type of graph seems advantageous. For qualitative variables, on the other hand, we prefer a different type of graph (see below).

Let us define a function to rescale values (min-max rescaling).

#' std1
#' a function to standardize feature values into same range
#' Source: https://github.com/pablo14/shap-values/blob/master/shap.R
std1 <- function(x) {
  ((x - min(x, na.rm = T)) / (max(x, na.rm = T) - min(x, na.rm = T)))
}
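
For instance (a quick arithmetic check):

std1(c(2, 4, 10))  # returns 0.00, 0.25, 1.00: the values are mapped onto [0, 1]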

Then, we create a table that contains, for each individual and each variable, the SHAP value and the standardized value of the characteristic.

values_min_max_rf <- 
  df_all_rf |> 
  select(!!top_n_variables_rf) |> 
  pivot_longer(cols = everything()) |> 
  group_by(name) |> 
  summarise(
    min_value = min(value), max_value = max(value),
    med_value = median(value),
    q1_value = quantile(value, probs = .25),
    q3_value = quantile(value, probs = .75)
  )
values_min_max_xgb <- 
  as_tibble(model.matrix(formula, df_all_xgb)) |> 
  select(!!top_n_variables_xgb) |> 
  pivot_longer(cols = everything()) |> 
  group_by(name) |> 
  summarise(
    min_value = min(value), max_value = max(value),
    med_value = median(value),
    q1_value = quantile(value, probs = .25),
    q3_value = quantile(value, probs = .75)
  )

We format the data to be able to create the plot.

df_plot_rf <- 
  treeshap_rf_opposite$shaps |> 
  select(!!top_n_variables_rf) |> 
  mutate(id_row = row_number()) |> 
  pivot_longer(cols = -id_row, values_to = "shap") |> 
  left_join(
    df_all_rf |> 
      select(!!top_n_variables_rf) |> 
      mutate(id_row = row_number()) |> 
      pivot_longer(cols = -id_row, values_to = "value"),
    by = c("id_row", "name")
  ) |> 
  group_by(name) |> 
  mutate(
    std_shap_value = std1(shap),
    mean_shap_value = mean(shap),
    mean_abs_shap_value = mean(abs(shap))
  ) |> 
  ungroup() |> 
  left_join(
    values_min_max_rf,
    by = "name"
  ) |> 
  mutate(
    value_variable_std = (value - min_value) / (max_value - min_value)
  ) |> 
  left_join(variable_names, by = c("name" = "variable")) |> 
  mutate(
    label = fct_reorder(label, mean_abs_shap_value)
  )
df_plot_xgb <- 
  treeshap_xgb$shaps |> 
  select(!!top_n_variables_xgb) |> 
  mutate(id_row = row_number()) |> 
  pivot_longer(cols = -id_row, values_to = "shap") |> 
  left_join(
    as_tibble(model.matrix(formula, df_all_xgb)) |> 
      select(!!top_n_variables_xgb) |> 
      mutate(id_row = row_number()) |> 
      pivot_longer(cols = -id_row, values_to = "value"),
    by = c("id_row", "name")
  ) |> 
  group_by(name) |> 
  mutate(
    std_shap_value = std1(shap),
    mean_shap_value = mean(shap),
    mean_abs_shap_value = mean(abs(shap))
  ) |> 
  ungroup() |> 
  left_join(
    values_min_max_xgb,
    by = "name"
  ) |> 
  mutate(
    value_variable_std = (value - min_value) / (max_value - min_value)
  ) |> 
  left_join(variable_names, by = c("name" = "variable")) |> 
  left_join(variable_names_categ, by = c("name" = "variable")) |> 
  mutate(
    label = ifelse(is.na(label), label_categ, label),
    label = fct_reorder(label, mean_abs_shap_value)
  )

Then, using the geom_sina() function from {ggforce}, we can create the summary plot. Each dot, for each variable given in the y-axis represent an individual. The x-axis gives the estimated Shap value of the variable for each individual. The colours state whether the value of the variable for a specific individual is low (blue) or high (red), using the standardized value as the reference. The variables on the y-axis appear according to their relative importance in explaining the prediction. For example, the graph shows that pharmacy deductible amounts are relatively important in explaining the probability of being classified as an imaginary healthy patient. When the value of this franchise is low (blue dots), this variable influences the prediction downwards (since the blue dots are mostly associated with negative Shap values); conversely, when the pharmacy franchise is high (red dots), the probability of being classified as an imaginary healthy patient increases.

library(ggforce)
p_rf <- ggplot(
  data = df_plot_rf,
  mapping = aes(
    y = shap, 
    x = label, 
    colour = value_variable_std
  )
) +
  geom_sina(alpha = .5) +
  scale_color_gradient(
    "Feature value",
    low="#0081BC", high = "#EE324E",
    breaks = c(0,1), labels = c("Low", "High"),
    guide = guide_colourbar(
      direction = "horizontal",
      barwidth = 40,
      title.position = "bottom",
      title.hjust = .5
    )
  ) +
  labs(y = "SHAP value (impact on model output)", x = NULL) +
  theme(
    legend.position = "bottom",
    legend.key.height   = unit(1, "line"),
    legend.key.width    = unit(1.5, "line")
  ) +
  coord_flip() +
  geom_hline(yintercept = 0, linetype = "dashed")
p_rf
Figure 14.7: Estimated effect of variables on the probability of being predicted as an imaginary healthy patient, for each individual (random forest)
p_xgb <- ggplot(
  data = df_plot_xgb,
  mapping = aes(
    y = shap, 
    x = label, 
    colour = value_variable_std
  )
) +
  geom_sina(alpha = .5) +
  scale_color_gradient(
    "Feature value",
    low="#0081BC", high = "#EE324E",
    breaks = c(0,1), labels = c("Low", "High"),
    guide = guide_colourbar(
      direction = "horizontal",
      barwidth = 40,
      title.position = "bottom",
      title.hjust = .5
    )
  ) +
  labs(y = "SHAP value (impact on model output)", x = NULL) +
  theme(
    legend.position = "bottom",
    legend.key.height   = unit(1, "line"),
    legend.key.width    = unit(1.5, "line")
  ) +
  coord_flip() +
  geom_hline(yintercept = 0, linetype = "dashed")
p_xgb
Figure 14.8: Estimated effect of variables on the probability of being predicted as an imaginary healthy patient, for each individual (XGBoost)

Now, instead of a summary plot, let us create a plot for each of the variables in the top n. To that end, we create a function which plots the estimated SHAP values of a single variable, for all individuals. When the variable is categorical, we want the graph to be split depending on the categories.

For continuous variables, we may want to make the colours of the dots depend on the quantiles. This is the case for net income, because of the shape of its distribution:

hist(df_all_rf$MENAGE_revucinsee, breaks = 100)

summary(df_all_rf$MENAGE_revucinsee)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
     74     992    1400    1610    2000   25300 
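
To illustrate the idea (a minimal sketch; deciles_income and income_decile are hypothetical objects not reused below), the income variable can be cut into deciles before being mapped to the colour scale:

# Empirical decile breakpoints of net income (assumed to be all distinct)
deciles_income <- quantile(df_all_rf$MENAGE_revucinsee, probs = seq(.1, .9, .1))
# Assign each individual to a decile band
income_decile <- cut(
  df_all_rf$MENAGE_revucinsee,
  breaks = c(0, deciles_income, Inf),
  labels = paste0("D", 1:10)
)
table(income_decile)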

The code of the plot function is a bit long.

#' Plots the SHAP values for each individual, for a single variable
#'
#' @param variable_name name of the variable
#' @param treeshap_res treeshap results
#' @param df_all data frame with observations from train and test sets
#' @param quantile if TRUE, uses the empirical quantiles as the colour code (default to FALSE)
#' @param size_dots size of the dots
#' @param bar_width legend bar width for qualitative variables
plot_shap_all_indiv_variable_rf <- function(variable_name, 
                                            treeshap_res,
                                            df_all,
                                            quantile = FALSE, 
                                            size_dots = .6, 
                                            bar_width = 30) {
  df_plot <- 
    treeshap_res$shaps |> 
    select(!!variable_name) |> 
    rename(SHAP_value = !!variable_name) |> 
    mutate(variable_name = variable_name) |> 
    as_tibble() |> 
    # Add value of the variable from the dataset
    mutate(
      variable_value = df_all |> pull(!!variable_name)
    )
  
  is_factor <- variable_name %in% names(corresp_factors)
  if (is_factor == TRUE) {
    # Get the levels for the variable
    niveaux <- corresp_factors[[variable_name]]
    df_plot <- 
      df_plot |> 
      left_join(
        tibble(
          variable_value = 1:length(niveaux),
          variable_value_level = niveaux
        ),
        by = "variable_value"
      )
    
    if (any(niveaux == "No answer")) {
      nouveaux_niveaux <- niveaux
      nouveaux_niveaux[nouveaux_niveaux == "No answer"] <- "Did not answer"
      
      if ("At least once a month" %in% niveaux) {
        ordre_niveaux <- c(
          "Every day or almost every day",
          "At least once a week",
          "At least once a month",
          "Less than once a month", "Never", "Did not answer"
        )
        ordre_niveaux_tbl <- tibble(niveaux, nouveaux_niveaux) |> 
          mutate(
            nouveaux_niveaux = factor(nouveaux_niveaux, levels = ordre_niveaux)
          ) |> 
          arrange(desc(nouveaux_niveaux))
        
        niveaux <- ordre_niveaux_tbl$niveaux
        nouveaux_niveaux <- as.character(ordre_niveaux_tbl$nouveaux_niveaux)
        
      } else if ("Sometimes" %in% niveaux) {
        ordre_niveaux <- c("Always", "Often", "Sometimes", "Never", "Did not answer")
        ordre_niveaux_tbl <- 
          tibble(niveaux, nouveaux_niveaux) |> 
          mutate(
            nouveaux_niveaux = factor(nouveaux_niveaux, levels = ordre_niveaux)
          ) |> 
          arrange(desc(nouveaux_niveaux))
        
        niveaux <- ordre_niveaux_tbl$niveaux
        nouveaux_niveaux <- as.character(ordre_niveaux_tbl$nouveaux_niveaux)
      }
      df_plot <- 
        df_plot |> left_join(
          tibble(
            variable_value_level = niveaux, 
            variable_value_level_new = nouveaux_niveaux
          ),
          by = "variable_value_level"
        ) |> 
        select(-variable_value_level) |> 
        rename(variable_value_level = variable_value_level_new)
      
      niveaux <- nouveaux_niveaux
    }
    
    df_plot <- df_plot |> 
      mutate(
        variable_value_level = factor(variable_value_level, levels = niveaux)
      )
  }
  
  if (is_factor == TRUE) {
    p <- 
      ggplot(
        data = df_plot, 
        mapping = aes(
          y = SHAP_value, 
          x = variable_value_level
        )
      ) +
      geom_sina(size = size_dots) +
      geom_violin(alpha=0) +
      coord_flip() +
      labs(y = "SHAP value (impact on model output)", x = NULL)
  } else {
    
    if (quantile == TRUE) {
      
      df_plot$quantile_variable <- 
        cut(
          pull(df_plot, "variable_value"),
          labels = c(
            "$\\leq$ D1", "D1 to D2", "D2 to D3", "D3 to D4",
            "D4 to D5", "D5 to D6", "D6 to D7", "D7 to D8", 
            "D8 to D9", "> D9"
          ),
          breaks = c(
            0,
            quantile(
              magrittr::extract2(df_all, variable_name),
              probs = seq(.1, .9, .1)
            ),
            Inf
          )
        )
      
      df_plot <- 
        df_plot |> 
        mutate(quantile_revenu_num = as.numeric(quantile_variable))
      
      p <- 
        ggplot(
          data = df_plot |> mutate(x = 1),
          mapping = aes(
            y = SHAP_value,
            x = x,
            colour = quantile_revenu_num
          )
        ) +
        geom_sina(alpha = .5, size = size_dots) +
        geom_violin(alpha = 0) +
        scale_color_gradient(
          "Variable value",
          breaks = c(1, 4, 7, 10),
          labels = c("$\\leq$ D1", "D3 to D4", "D6 to D7", "$>$ D9"),
          low = "#0081BC", high = "#EE324E",
          guide = guide_colourbar(
            direction = "horizontal",
            barwidth = bar_width,
            title.position = "bottom",
            title.hjust = .5
          )
        ) +
        labs(y = "SHAP value (impact on model output)", x = NULL) +
        theme(
          legend.position = "bottom",
          legend.key.height = unit(1, "line"),
          legend.key.width  = unit(1.5, "line"),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank()
        ) +
        coord_flip()
      
    } else {
      # Not using quantile
      p <- 
        ggplot(
          data = df_plot |> mutate(x = 1),
          mapping = aes(
            x = x,
            y = SHAP_value,
            colour = variable_value
          )
        ) +
        geom_sina(alpha = .5, size = size_dots) +
        geom_violin(alpha = 0) +
        scale_color_gradient(
          "Variable value",
          low = "#0081BC", high = "#EE324E",
          guide = guide_colourbar(
            direction = "horizontal",
            barwidth = bar_width,
            title.position = "bottom",
            title.hjust = .5)
        ) +
        labs(y = "SHAP value (impact on model output)", x = NULL) +
        theme(
          legend.position = "bottom",
          legend.key.height = unit(1, "line"),
          legend.key.width  = unit(1.5, "line"),
          axis.text.y = element_blank(), 
          axis.ticks.y = element_blank()
        ) +
        coord_flip()
    }
  }
  p + theme(
    plot.title.position = "plot",
    plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold"),
    legend.text = element_text(size = rel(.8)),
    legend.title = element_text(size = rel(.8))
  )
}

The SHAP values are then plotted, one panel per variable, for the top 10 variables of the random forest: Net Income per Cons. Unit, Reimbursement General Practitioner, Reimbursement Pharmacy, No. Medical Sessions General Pract., Deduct. Pharmacy, Low Back Pain, Frequency Meeting with Family Living Outside Household, Have to Hurry to Do Job, Reimbursement Physiotherapist, and Bronchitis.

#' Plots the SHAP values for each individual, for a single variable
#'
#' @param variable_name name of the variable
#' @param treeshap_res treeshap results
#' @param df_all data frame with observations from train and test sets
#' @param quantile if TRUE, uses the empirical quantiles as the colour code (default to FALSE)
#' @param size_dots size of the dots
#' @param bar_width legend bar width for qualitative variables
plot_shap_all_indiv_variable_xgb <- function(variable_name, 
                                             treeshap_res,
                                             df_all,
                                             quantile = FALSE, 
                                             size_dots = .6, 
                                             bar_width = 30) {
  df_plot <- 
    treeshap_res$shaps |> 
    select(!!variable_name) |> 
    rename(SHAP_value = !!variable_name) |> 
    mutate(variable_name = variable_name) |> 
    as_tibble() |> 
    # Add value of the variable from the dataset
    mutate(
      variable_value = df_all |> pull(!!variable_name)
    )
  
  is_factor <- variable_name %in% variable_names_categ$variable
  if (is_factor == TRUE) {
    # Not using quantile
    p <-
      ggplot(
        data = df_plot |> mutate(x = 1),
        mapping = aes(
          x = x,
          y = SHAP_value,
          colour = variable_value
        )
      ) +
      geom_sina(alpha = .5, size = size_dots) +
      geom_violin(alpha = 0) +
      scale_color_gradient(
        "Variable value",
        low = "#0081BC", high = "#EE324E", n.breaks = 2
      ) +
      labs(y = "SHAP value (impact on model output)", x = NULL) +
      theme(
        legend.position = "bottom",
        legend.key.height = unit(1, "line"),
        legend.key.width  = unit(1.5, "line"),
        axis.text.y = element_blank(), 
        axis.ticks.y = element_blank()
      ) +
      coord_flip() +
      guides(color = guide_legend(override.aes = list(
        color = c("#0081BC", "#EE324E"),
        label = c("0", "1")
      )))
  } else {
    
    if (quantile == TRUE) {
      
      breaks_q <- c(
        0,
        quantile(
          magrittr::extract2(df_all, variable_name),
          probs = seq(.1, .9, .1)
        ),
        Inf
      )
      labels_q <- c(
        "$\\leq$ D1", "D1 to D2", "D2 to D3", "D3 to D4",
        "D4 to D5", "D5 to D6", "D6 to D7", "D7 to D8", 
        "D8 to D9", "> D9"
      )
      ind_remove <- which(duplicated(breaks_q))
      if (length(ind_remove) > 0) {
        labels_q <- labels_q[-ind_remove]
        breaks_q <- breaks_q[-ind_remove]
      }
      df_plot$quantile_variable <- 
        cut(
          pull(df_plot, "variable_value"),
          labels = labels_q,
          breaks = breaks_q
        )
      
      df_plot <- 
        df_plot |> 
        mutate(quantile_revenu_num = as.numeric(quantile_variable))
      
      p <- 
        ggplot(
          data = df_plot |> mutate(x = 1),
          mapping = aes(
            y = SHAP_value,
            x = x,
            colour = quantile_revenu_num
          )
        ) +
        geom_sina(alpha = .5, size = size_dots) +
        geom_violin(alpha = 0) +
        scale_color_gradient(
          "Variable value",
          breaks = c(1, 4, 7, 10),
          labels = c("$\\leq$ D1", "D3 to D4", "D6 to D7", "$>$ D9"),
          low = "#0081BC", high = "#EE324E",
          guide = guide_colourbar(
            direction = "horizontal",
            barwidth = bar_width,
            title.position = "bottom",
            title.hjust = .5
          )
        ) +
        labs(y = "SHAP value (impact on model output)", x = NULL) +
        theme(
          legend.position = "bottom",
          legend.key.height = unit(1, "line"),
          legend.key.width  = unit(1.5, "line"),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank()
        ) +
        coord_flip()
      
    } else {
      # Not using quantile
      p <- 
        ggplot(
          data = df_plot |> mutate(x = 1),
          mapping = aes(
            x = x,
            y = SHAP_value,
            colour = variable_value
          )
        ) +
        geom_sina(alpha = .5, size = size_dots) +
        geom_violin(alpha = 0) +
        scale_color_gradient(
          "Variable value",
          low = "#0081BC", high = "#EE324E",
          guide = guide_colourbar(
            direction = "horizontal",
            barwidth = bar_width,
            title.position = "bottom",
            title.hjust = .5)
        ) +
        labs(y = "SHAP value (impact on model output)", x = NULL) +
        theme(
          legend.position = "bottom",
          legend.key.height = unit(1, "line"),
          legend.key.width  = unit(1.5, "line"),
          axis.text.y = element_blank(), 
          axis.ticks.y = element_blank()
        ) +
        coord_flip()
    }
  }
  p + theme(
    plot.title.position = "plot",
    plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold"),
    legend.text = element_text(size = rel(.8)),
    legend.title = element_text(size = rel(.8))
  )
}

The same plots are produced, one panel per variable, for the top 10 variables of the XGBoost classifier: Net Income per Cons. Unit, Reimbursement General Practitioner, Reimbursement Pharmacy, Very Little Freedom to Do Job = Never, Low Back Pain = Yes, Gender = Male, Age, Have to Hurry to Do Job = Sometimes, Waiver Appointment Delay Too Long = No, and No. Medical Sessions General Pract.

14.5 Clustering

It may be possible to group the individuals in the dataset depending on the profile of their SHAP values. Let us perform a hierarchical clustering to group the individuals depending on their Shapley values. We will only keep a few variables to represent each individual, and will focus only on people classified as imaginary healthy patients. First, let us compute the average Shapley value for each variable, among the predicted imaginary healthy patients.

14.5.1 Note

We only focus on the SHAP values computed based on the predictions made with XGBoost in this part.

Recall that we stored in a tibble (order_variables_shap_xgb) the average absolute SHAP values (column value) and the average SHAP values (column mean):

order_variables_shap_xgb
# A tibble: 175 × 3
   name                     value     mean
   <chr>                    <dbl>    <dbl>
 1 MENAGE_revucinsee       0.202  -0.0269 
 2 SOINS_remomn            0.147  -0.0222 
 3 SOINS_rempha            0.140  -0.0232 
 4 QST_ct_liberteNever     0.0981 -0.0125 
 5 PERSONNE_pb_lombalgiYes 0.0900 -0.00804
 6 PERSONNE_sexeMale       0.0831 -0.00492
 7 PERSONNE_age            0.0810 -0.00335
 8 QST_ct_depechSometimes  0.0777 -0.0107 
 9 OPINION1_renonc_longNo  0.0578 -0.00706
10 SOINS_seac_omn          0.0506 -0.00855
# ℹ 165 more rows

We can display the average absolute SHAP values on a plot. A vertical red line is added at the mean, across variables, of these average absolute Shapley values. Red dots indicate variables with a negative average Shapley value, green dots indicate a positive average, and gray dots indicate a zero average.

ggplot(
  data = order_variables_shap_xgb |>
    left_join(variable_names, by = c("name" = "variable")) |>
    left_join(variable_names_categ, by = c("name" = "variable")) |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    rename(mean_abs_shap_value = value) |>
    mutate(
      sign = sign(mean),
      sign = factor(sign, levels = c(-1, 0, 1), labels = c("-", "0", "+"))
    ),
  mapping = aes(
    y = fct_reorder(label, desc(mean_abs_shap_value)),
    x = mean_abs_shap_value
  )
) +
  geom_point(mapping = aes(colour = sign)) +
  geom_vline(
    xintercept = mean(order_variables_shap_xgb$value), colour = "red"
  ) +
  scale_colour_manual(
    "Feature impact",
    values = c("-" = "#EE324E", "+" = "#009D57", "0" = "gray"),
    labels = c("-" = "Negative", "+" = "Positive")
  ) +
  labs(x = "Mean(|Shap value|)", y = "Variable")
Figure 14.9: Mean absolute SHAP values for each variable

Let us create a table in which we only keep the individuals predicted as imaginary healthy patients:

df_clustering_Not_D_and_inf_Q1 <-
  treeshap_xgb$shaps |>
  mutate(id_row = row_number()) |>
  mutate(predicted = pred_val_all_xgb) |>
  filter(predicted > .5) |>
  as_tibble() |>
  select(-predicted)

Let us only keep the 8 variables with the highest mean absolute Shapley values:

nb_top <- 8
variables_to_keep_cluster <-
  order_variables_shap_xgb |>
  arrange(desc(value)) |>
  dplyr::slice_head(n = nb_top) |>
  pull("name")
variables_to_keep_cluster
[1] "MENAGE_revucinsee"       "SOINS_remomn"           
[3] "SOINS_rempha"            "QST_ct_liberteNever"    
[5] "PERSONNE_pb_lombalgiYes" "PERSONNE_sexeMale"      
[7] "PERSONNE_age"            "QST_ct_depechSometimes" 

We can visualize which variables were kept:

ggplot(
  data = order_variables_shap_xgb |>
    left_join(variable_names, by = c("name" = "variable")) |>
    left_join(variable_names_categ, by = c("name" = "variable")) |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    rename(mean_abs_shap_value = value) |>
    mutate(
      in_top_kept = name %in% variables_to_keep_cluster
    ),
  mapping = aes(
    y = fct_reorder(label, desc(mean_abs_shap_value)),
    x = mean_abs_shap_value
  )
) +
  geom_point(mapping = aes(colour = in_top_kept)) +
  geom_vline(
    xintercept = mean(order_variables_shap_xgb$value), colour = "red"
  ) +
  scale_colour_manual(
    "Variable kept for clustering",
    values = c("FALSE" = "gray", "TRUE" = "#009D57"),
    labels = c("FALSE" = "Discarded", "TRUE" = "Kept")
  ) +
  labs(x = "Mean(|Shap value|)", y = "Variable")
Figure 14.10: Variables kept for the clustering and their average absolute Shapley values.

We need to select a number of clusters K. Instead of arbitrarily picking a value, we vary the number of clusters and look at the silhouette information for each value of K.

sil_val <- NULL
dist_m <- dist(
  as.data.frame(
    df_clustering_Not_D_and_inf_Q1 |>
      select(!!!syms(variables_to_keep_cluster))
  )
)
hierarchical_clust <- hclust(dist_m, method = "ward.D")
for (K in 2:15) {
  clusters_k <- cutree(hierarchical_clust, K)
  sil <- cluster::silhouette(clusters_k, dist_m)
  sil_val <- c(sil_val, mean(sil[, 3]))
}
p_silhouette <-
  ggplot(data = tibble(K = 2:15, y = sil_val), aes(x = K, y = y)) +
  geom_line() +
  geom_point() +
  labs(x = "K", y = "Silhouette score")
p_silhouette
Figure 14.11: Silhouette scores

We hence pick 2 as the number of desired clusters, and then assign each observation to its cluster.

K <- 2
clusters_km_Not_D_and_inf_Q <- hierarchical_clust
df_clustering_Not_D_and_inf_Q1$cluster <- cutree(clusters_km_Not_D_and_inf_Q, K)

We create a table to plot the SHAP values of these individuals, and order the observations according to the ordering returned by the hierarchical clustering.

df_plot <- df_clustering_Not_D_and_inf_Q1[clusters_km_Not_D_and_inf_Q$order, ]

For variables that are not among the top 8, we create an “Other” category.

df_plot <-
  df_plot |> select(-id_row) |>
  mutate(x = row_number()) |>
  pivot_longer(
    cols = -c(cluster, x), names_to = "variable", values_to = "value"
  ) |>
  mutate(
    variable = ifelse(
      variable %in% variables_to_keep_cluster,
      yes = variable,
      no = "Other"
    )
  ) |>
  group_by(cluster, x, variable) |>
  summarise(value = mean(value), .groups = "drop") |>
  left_join(
    variable_names |>
      bind_rows(variable_names_categ |> select(variable, label = label_categ)) |>
      bind_rows(tibble(variable = "Other", label = "Other", type = NA)),
    by = c("variable" = "variable")
  ) |>
  arrange(x)

The order of the variables:

labels_order <-
  tibble(variable = variables_to_keep_cluster) |>
  left_join(variable_names, by = c("variable")) |>
  left_join(variable_names_categ, by = c("variable")) |>
  mutate(label = ifelse(is.na(label), label_categ, label)) |>
  pull(label)

labels_order <- c(labels_order, "Other")

We apply this order to the variables in the data plot.

df_plot <-
  df_plot |>
  mutate(label = factor(label, levels = labels_order))

Then, we can plot the graph to visualize the contribution of each of the top variables, for each individual, within each cluster.

nb_vars <- df_plot$variable |> unique() |> length()

x_lines <-
  df_plot |>
  group_by(cluster) |>
  summarise(x_max = max(x)) |>
  pull(x_max) |>
  sort()
x_lines <- c(1, x_lines)

plot_clusters <-
  ggplot(data = df_plot) +
  geom_vline(xintercept = x_lines) +
  geom_col(aes(x = x, y = value, fill = label), width =1, alpha = 0.9) +
  geom_hline(yintercept = 0, col = "black") +
  scale_fill_manual("", values = c(RColorBrewer::brewer.pal(nb_vars-1, 'Paired'), "gray")) +
  labs(x = "Individuals", y = "Output value") +
  # theme_paper() +
  theme(
    legend.position = "bottom",
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank()
  )
plot_clusters + guides(fill = guide_legend(ncol = 3))
Figure 14.12: Decomposition of the Effect of the Most Important Variables on the Probability of Being Predicted as an Imaginary Healthy Patient for Individuals Predicted as Such. All individuals.

With aggregated mean SHAP values at the cluster level:

df_plot_cluster <- 
  df_plot |> 
  group_by(cluster, variable, label) |> 
  summarise(value = mean(value)) |> 
  ungroup() |> 
  mutate(cluster = factor(cluster, levels = unique(df_plot$cluster)))
`summarise()` has grouped output by 'cluster', 'variable'. You can override
using the `.groups` argument.
plot_clusters_2 <-
  ggplot(
    data = df_plot_cluster |>
      mutate(
        cluster = factor(
          cluster,
          levels = levels(df_plot_cluster$cluster),
          labels = seq_len(length(levels(df_plot_cluster$cluster)))
        )
      ),
    mapping = aes(x = cluster, fill = label, y = value)
  ) +
  geom_bar(position = "stack", stat = "identity") +
  labs(x = "Cluster",y = "Output value") +
  scale_fill_manual(
    "", 
    values = c(RColorBrewer::brewer.pal(nb_vars-1, 'Paired'), "gray")
  )+
  theme(
    legend.position = "bottom"
  )
plot_clusters_2 + guides(fill = guide_legend(ncol = 3)) +
  geom_hline(yintercept = 0, linetype = "dashed")
Figure 14.13: Decomposition of the Effect of the Most Important Variables on the Probability of Being Predicted as an Imaginary Healthy Patient for Individuals Predicted as Such. Average per Cluster.

14.5.2 Descriptive Statistics of the Clusters

Let us now get some descriptive statistics at the cluster level, for each of these variables. We need to get the values of the variables for the individuals predicted as imaginary healthy patients.

For categorical variables, we would like to get the most frequent value, in each cluster. For numerical variables, we would like to get the average and standard deviation. Let us isolate these two types of variables.

The categorical variables:

top_name_categ <-
  variable_names_categ |>
  filter(variable %in% top_n_variables_xgb) |>
  pull("variable_raw")

The numerical variables:

top_name_num <- top_n_variables_xgb[!top_n_variables_xgb %in% variable_names_categ$variable]

Let us create a dataset with the characteristics and the cluster in which the individuals belong:

df_stat_des <-
  df_all_xgb |>
  tibble() |>
  mutate(id_row = row_number()) |>
  # Add cluster
  left_join(
    df_clustering_Not_D_and_inf_Q1 |> select(id_row, cluster),
    by = "id_row"
  )

We define a function, get_stats(), which computes the desired statistics:

get_stats <- function(df) {
  # Descriptive statistics for numerical variables: mean and standard deviation,
  # formatted as "mean (sd)"
  df_num <-
    df |>
    select(!!top_name_num) |>
    pivot_longer(cols = everything()) |>
    group_by(name) |>
    summarise(mean = mean(value), sd = sd(value)) |>
    mutate(value = str_c(round(mean, 2), " (", round(sd, 2), ")")) |>
    select(name, value)

  # Descriptive statistics for categorical variables: modal value and its share,
  # formatted as "mode (pct%)"
  df_categ <-
    df |>
    select(!!top_name_categ) |>
    pivot_longer(cols = everything()) |>
    group_by(name) |>
    count(value) |>
    mutate(pct = round(100 * n / sum(n), 1)) |>
    slice_max(order_by = pct, n = 1) |>
    mutate(value = str_c(value, ' (', pct, "%)")) |>
    select(name, value)

  # Stack the numerical and categorical summaries
  df_num |> bind_rows(df_categ)
}
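
To make the formatting explicit, here is a minimal sketch of the two summaries computed by get_stats(), on a small hypothetical tibble (the age and gender columns are purely illustrative):

demo <- tibble(
  age    = c(25, 40, 61, 33),
  gender = c("Female", "Female", "Male", "Female")
)
# Numerical variable: "mean (sd)"
str_c(round(mean(demo$age), 2), " (", round(sd(demo$age), 2), ")")
# e.g. "39.75 (15.44)"
# Categorical variable: "modal value (share in %)"
tab <- sort(table(demo$gender), decreasing = TRUE)
str_c(names(tab)[1], " (", round(100 * tab[1] / sum(tab), 1), "%)")
# e.g. "Female (75%)"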

We will compute the statistics on the following datasets:

  • the whole dataset;
  • the sample of individuals predicted as imaginary healthy patients;
  • each of the clusters (i.e., sub-samples of the individuals predicted as imaginary healthy patients).

On the whole dataset:

stat_des_whole <- get_stats(df_stat_des) |> mutate(type = "Whole sample")

On the dataset of individuals predicted as imaginary healthy patients by the classifier:

stat_des_imaginary <- df_stat_des |>
  filter(id_row %in% df_clustering_Not_D_and_inf_Q1$id_row) |>
  get_stats() |> mutate(type = "Imaginary healthy")

And on each of the clusters:

df_stat_des_clusters <- NULL
for (i_clust in 1:K) {
  df_stat_des_clusters <-
    df_stat_des_clusters |> bind_rows(
      df_stat_des |>
        filter(cluster == !!i_clust) |>
        get_stats() |> mutate(type = as.character(i_clust))
    )
}
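
Equivalently, the loop above can be written in a functional style with purrr; this is only an alternative sketch that produces the same object:

# Same computation as the loop, written with purrr::map_dfr()
df_stat_des_clusters <- purrr::map_dfr(
  seq_len(K),
  \(i_clust) df_stat_des |>
    filter(cluster == i_clust) |>
    get_stats() |>
    mutate(type = as.character(i_clust))
)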

We also compare the clusters by computing the p-values of an ANOVA for quantitative variables, and of a chi-squared test for qualitative variables.
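
Before running these tests on the actual data, here is a minimal self-contained illustration, on simulated data, of how each p-value is obtained:

set.seed(123)
grp   <- factor(sample(1:2, 200, replace = TRUE))     # hypothetical cluster labels
x_num <- rnorm(200, mean = as.numeric(grp))           # a numerical variable
x_cat <- sample(c("Yes", "No"), 200, replace = TRUE)  # a categorical variable
# Chi-squared test of independence between cluster and the categorical variable
chisq.test(table(grp, x_cat))$p.value
# One-way ANOVA of the numerical variable across clusters (p-value of the F test)
anova(aov(x_num ~ grp))$`Pr(>F)`[1]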

For qualitative variables:

p_values_categ <- NULL
for (name in top_name_categ) {
  values <- df_stat_des |> pull(!!name)
  # contingency table of clusters vs. the categorical variable
  tab <- table(df_stat_des$cluster, values)
  # chi-squared test of independence
  chi2 <- chisq.test(tab, correct = TRUE)
  # p-value
  pvalue <- chi2$p.value
  p_values_categ <- bind_rows(
    p_values_categ,
    tibble(name = name, pvalue = pvalue)
  )
}

For numerical variables:

p_values_num <- NULL
for (name in top_name_num) {
  values <- df_stat_des |> pull(!!name)
  # one-way ANOVA of the numerical variable across clusters
  anova_res <- anova(aov(as.numeric(values) ~ factor(df_stat_des$cluster)))
  # p-value of the F test
  pvalue <- anova_res$`Pr(>F)`[1]
  p_values_num <- bind_rows(
    p_values_num,
    tibble(name = name, pvalue = pvalue)
  )
}

We merge the p-values into a single object.

tb_p_values <- p_values_categ |> bind_rows(p_values_num)

To produce a prettier table, let us create a tibble with the variable labels.

table_ref <-
  tibble(variable = top_n_variables_xgb) |>
  left_join(variable_names_categ, by = c("variable")) |>
  mutate(
    variable = ifelse(is.na(variable_raw), variable, variable_raw)
  ) |>
  left_join(
    variable_names, by = "variable"
  )

Let us merge the three tables of descriptive statistics and add the p-values of the cluster comparisons.

tb_desc_stats_clusters <-
  df_stat_des_clusters |>
  bind_rows(stat_des_whole) |>
  bind_rows(stat_des_imaginary) |>
  mutate(
    type = factor(type, levels = c(1:K, "Imaginary healthy", "Whole sample"))
  ) |>
  pivot_wider(names_from = type, values_from = value) |>
  left_join(
    table_ref |> select(variable, label),
    by = c("name" = "variable")
  ) |>
  left_join(unique(tb_p_values), by = "name") |>
  relocate(label, .before = name) |>
  select(-name) |>
  unique() |>
  mutate(label = factor(label, levels = unique(table_ref$label))) |>
  arrange(label)
tb_desc_stats_clusters |>
  select(-pvalue) |>
  add_column(p_value = format.pval(tb_desc_stats_clusters$pvalue, 2)) |>
  kableExtra::kable() |>
  kableExtra::kable_classic(full_width = F, html_font = "Cambria") |>
  kableExtra::kable_styling(
    bootstrap_options = c("striped", "hover", "condensed", "responsive")
  )
Table 14.2: Average individual in each cluster and in the sample
| label                                | Cluster 1          | Cluster 2          | Whole sample       | Imaginary healthy  | p_value |
|--------------------------------------|--------------------|--------------------|--------------------|--------------------|---------|
| Net Income per Cons. Unit            | 737.3 (294.62)     | 1666.93 (746.55)   | 1609.51 (1008.13)  | 1205.24 (734.6)    | < 2e-16 |
| Reimbursement General Practitioner   | 131.68 (149.24)    | 181.55 (141.08)    | 89.44 (112.53)     | 156.78 (147.28)    | 6.9e-15 |
| Reimbursement Pharmacy               | 488.28 (1354.34)   | 982.78 (2498.11)   | 361.5 (1431.61)    | 737.19 (2027.66)   | 2.3e-08 |
| Very Little Freedom to Do Job        | No answer (78.5%)  | No answer (61.5%)  | No answer (52.1%)  | No answer (69.9%)  | < 2e-16 |
| Low Back Pain                        | No (76%)           | No (57%)           | No (79.6%)         | No (66.4%)         | < 2e-16 |
| Gender                               | Female (62.6%)     | Female (63.3%)     | Female (52.2%)     | Female (63%)       | 0.78    |
| Age                                  | 48.29 (19.32)      | 58.26 (17.51)      | 48.7 (18.58)       | 53.31 (19.09)      | < 2e-16 |
| Have to Hurry to Do Job              | No answer (78.1%)  | No answer (61.4%)  | No answer (51.8%)  | No answer (69.7%)  | < 2e-16 |
| Waiver Appointment Delay Too Long    | No (56.8%)         | No (71.2%)         | No (67.6%)         | No (64%)           | < 2e-16 |
| No. Medical Sessions General Pract.  | 6.67 (6.57)        | 8.74 (5.96)        | 4.73 (5.02)        | 7.71 (6.35)        | 6.6e-14 |

Number of observations per cluster:

table(df_clustering_Not_D_and_inf_Q1$cluster)

   1    2 
1036 1050 

Number of observations in the whole sample:

nrow(df_stat_des)
[1] 5293

Number of individuals predicted as imaginary healthy:

nrow(df_clustering_Not_D_and_inf_Q1)
[1] 2086

Remaining observations (sanity check):

df_stat_des |>
  filter(!id_row %in% df_clustering_Not_D_and_inf_Q1$id_row) |> 
  nrow()
[1] 3207

Predicted classes from the XGBoost classifier:

# Predicted class (label) for each observation
pred_class_xgb <- predict(
  grid_search_xgb, newdata = df_all_xgb
)
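
By default, caret's predict() method on a train object returns hard class labels, as above. If the model was trained with classProbs = TRUE in trainControl(), the class membership probabilities (the kind of scores stored in pred_val_all_xgb) can also be retrieved; a minimal sketch, assuming that option was set:

# Class probabilities (one column per level of status);
# requires classProbs = TRUE in the trainControl() used for training
pred_prob_xgb <- predict(grid_search_xgb, newdata = df_all_xgb, type = "prob")
head(pred_prob_xgb)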

Overview of the predictions in each cluster:

df_stat_des |> 
  mutate(predicted_val = pred_val_all_xgb) |> 
  mutate(
    pred_class = pred_class_xgb,
    correctly_pred = pred_class == status
  ) |> 
  # Keep only predicted as imaginary healthy
  filter(!is.na(cluster)) |> 
  group_by(cluster) |> 
  summarise(
    n = n(),
    accuracy = round(100*mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100*no_im_healthy / n, 1)
  )
# A tibble: 2 × 5
  cluster     n accuracy no_im_healthy pct_im_healthy
    <int> <int>    <dbl>         <int>          <dbl>
1       1  1036     40.7           422           40.7
2       2  1050     38.3           402           38.3

Overall accuracy:

df_stat_des |> 
  mutate(predicted_val = pred_val_all_xgb) |> 
  mutate(
    pred_class = ifelse(
      predicted_val > .5, 
      yes = "Not_D_and_inf_Q1", "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |> 
  summarise(
    n = n(),
    accuracy = round(100*mean(correctly_pred),1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100*no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  5293     67.9          1262           23.8

Percentage of individuals correctly predicted as Not_D_and_inf_Q1, i.e., as imaginary healthy:

df_stat_des |> 
  mutate(predicted_val = pred_val_all_xgb) |> 
  mutate(
    pred_class = ifelse(
      predicted_val > .5, 
      yes = "Not_D_and_inf_Q1", "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |> 
  filter(pred_class == "Not_D_and_inf_Q1") |> 
  summarise(
    n = n(),
    accuracy = round(100*mean(correctly_pred),1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100*no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  2086     39.5           824           39.5

Percentage of individuals correctly predicted as Not_D_and_sup_Q1, i.e., as healthy:

# Individuals predicted as healthy
df_stat_des |> 
  mutate(predicted_val = pred_val_all_xgb) |> 
  mutate(
    pred_class = ifelse(
      predicted_val > .5, 
      yes = "Not_D_and_inf_Q1", "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |> 
  filter(pred_class == "Not_D_and_sup_Q1") |> 
  summarise(
    n = n(),
    accuracy = round(100*mean(correctly_pred),1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100*no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  3207     86.3           438           13.7