20  SHAP Results

Objectives

This chapter presents the results from Chapter 19.

library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ ggplot2   3.5.1     ✔ tibble    3.2.1
✔ lubridate 1.9.3     ✔ tidyr     1.3.1
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
Loading required package: lattice

Attaching package: 'caret'

The following object is masked from 'package:purrr':

    lift
library(ranger)
library(treeshap)

Let us define a theme for graphs.

Definition of theme_paper()
#' Theme for ggplot2
#'
#' Builds on `ggthemes::theme_base()` and draws a boxed, transparent,
#' horizontal legend at the bottom of the plot.
#'
#' @param ... additional theme elements forwarded to `ggplot2::theme()`; they
#'   are added last, so they override the settings made here
#' @return a ggplot2 theme that can be added to a plot with `+`
#' @export
#' @importFrom ggplot2 element_rect element_text element_blank element_line unit
#'   rel
theme_paper <- function(...) {
  ggthemes::theme_base() +
    theme(
      legend.background = element_rect(
        fill = "transparent", linetype = "solid", colour = "black"
      ),
      legend.position = "bottom",
      legend.direction = "horizontal",
      legend.box = "horizontal",
      legend.key = element_blank()
    ) +
    # Forwarded in a separate call: passing `...` to the call above would fail
    # whenever the caller overrides one of the elements already set there.
    theme(...)
}

20.1 Load data and estimated models

20.1.1 Data

Let us load the train and test data:

# Restore the objects saved in the two .rda files into the current environment
# (the code below expects them to provide `df_train` and `df_test`).
load("../data/out/df_train_sah.rda")
load("../data/out/df_test_sah.rda")

For random forests, we need to convert categorical variables to numerical variables:

# Numeric versions of the train and test sets: every factor column is replaced
# by the numeric code of its level; all other columns are left untouched.
df_train_num <- mutate(df_train, across(where(is.factor), as.numeric))
df_test_num <- mutate(df_test, across(where(is.factor), as.numeric))

We keep track of the levels of factor data:

# Named list with, for each factor column of the train set, its levels.
# `select(where())` replaces the superseded `select_if()`.
corresp_factors <- df_train |>
  select(where(is.factor)) |>
  map(levels)

20.1.2 Classifiers

The estimated models (see ?sec-estimations-mhi):

# Grid-search result for the random forest; the code below reads the tuned
# hyperparameters from `grid_search_rf$bestTune`.
load("../data/out/estim/v3/grid_search_rf_sah.rda")
# Grid-search result for XGBoost; the code below uses `grid_search_xgb`
# (its `finalModel` element and its `predict()` method).
load("../data/out/estim/v3/grid_search_xgb_sah.rda")

Let us get the best classifiers obtained both for the random forest and XGBoost. Let us begin with random forests:

# Positions of the columns that must not be used as predictors
ind_remove <- which(
  colnames(df_train) %in% c("status", "id", "PERSONNE_statut")
)
# Model formula: `status` explained by all remaining columns.
# setdiff() is used instead of negative indexing because
# `df_train[-ind_remove]` selects no column at all when `ind_remove` is empty.
formula <- as.formula(
  paste(
    "status ~",
    paste(
      setdiff(colnames(df_train), c("status", "id", "PERSONNE_statut")),
      collapse = "+"
    )
  )
)

# Random forest refitted with the hyperparameters selected by the grid search.
# The response is numeric in `df_train_num`, hence the explicit
# `classification = TRUE`.
final_model_rf <- ranger(
  formula,
  data = df_train_num,
  mtry = grid_search_rf$bestTune$mtry,
  splitrule = "gini",
  min.node.size = grid_search_rf$bestTune$min.node.size,
  classification = TRUE
)

And for the extreme gradient boosting algorithm:

# Final model stored in the XGBoost grid-search result
final_model_xgb <- grid_search_xgb$finalModel

20.1.3 SHAP Values

We load the estimated SHAP values (see Chapter 19):

# Pre-computed treeshap results; the code below uses the objects `treeshap_rf`
# and `treeshap_xgb` (and their `$shaps` tables).
load("../data/out/treeSHAP/v3/treeshap_rf_sah.rda")
load("../data/out/treeSHAP/v3/treeshap_xgb_sah.rda")

The explainer used only the values from the train set to estimate the SHAP values:

# Reference samples: the numeric-coded train set for the random forest (which
# was fitted on it) and the original train set for XGBoost.
reference_data_rf <- df_train_num
reference_data_xgb <- df_train

And the SHAP values were estimated on both the train set and the test set:

# Train rows come first, followed by the test rows
df_all_rf <- bind_rows(df_train_num, df_test_num)
df_all_xgb <- bind_rows(df_train, df_test)

20.2 Predicted Values (classifier)

Let us get the predicted values for both classifiers.

20.2.1 Random Forest

# Predicted classes of the random forest on the reference data
pred_val_ref_rf <- predict(final_model_rf, reference_data_rf)$predictions
head(pred_val_ref_rf)
[1] 0 1 1 1 1 1

The issue is that this gives the majority vote, and not the associated scores. Let us compute the scores by aggregating the votes on each tree.

#' Share of trees predicting class 1, for each observation
#'
#' @param object fitted model whose `predict()` method accepts
#'   `predict.all = TRUE` and returns the per-tree predictions (one row per
#'   observation, one column per tree) in `$predictions`
#' @param data observations to score
#' @return numeric vector with one value per observation: the proportion of
#'   trees whose prediction equals 1
predict_score_ranger <- function(object, data) {
  tree_votes <- as.data.frame(
    predict(object, data, predict.all = TRUE)$predictions
  )
  rowSums(tree_votes == 1) / ncol(tree_votes)
}
# Scores (share of trees voting for class 1) on the reference data
pred_val_rf <- predict_score_ranger(final_model_rf, reference_data_rf)
head(pred_val_rf)
[1] 0.406 0.788 0.786 0.854 0.826 0.866

20.2.2 XGBoost

And now, for XGBoost:

# Class probabilities on the reference data; only the column of class
# "Not_D_and_inf_Q1" is kept.
pred_val_xgb <- predict(
  grid_search_xgb,
  newdata = reference_data_xgb,
  type = "prob"
)[, "Not_D_and_inf_Q1"]
head(pred_val_xgb)
[1] 0.4140283 0.4672670 0.3959899 0.4698538 0.6405708 0.4444454

Let us get the predicted values on both the train set and the test set as well:

# Same probabilities, for the train and test observations together
pred_val_all_xgb <- predict(
  grid_search_xgb,
  newdata = df_all_xgb,
  type = "prob"
)[, "Not_D_and_inf_Q1"]

20.3 Understanding the Output

Let us have a look at the first individual to explain the results obtained in the table.

# Row number of the individual used as a running example
id_indiv <- 1
df_all_rf[id_indiv, ]
  PERSONNE_pb_asthm PERSONNE_pb_bronchit PERSONNE_pb_infarctus
1                 3                    3                     3
  PERSONNE_pb_coronair PERSONNE_pb_hypertens PERSONNE_pb_avc
1                    3                     3               3
  PERSONNE_pb_arthros PERSONNE_pb_lombalgi PERSONNE_pb_cervical
1                   3                    3                    3
  PERSONNE_pb_diabet PERSONNE_pb_allergi PERSONNE_pb_cirrhos
1                  3                   3                   3
  PERSONNE_pb_urinair PERSONNE_age PERSONNE_sexe PERSONNE_couple
1                   3           65             1               3
  PERSONNE_statut PERSONNE_ss PERSONNE_regime PERSONNE_rap_pcs8 PERSONNE_ald
1               2           1               1                 3            2
  SOINS_ald_am MENAGE_revucinsee MENAGE_tu MENAGE_nbpers ensol_2011
1            2              1400         1             2   2280.443
  MUTUELLE_assu MUTUELLE_typcc SOINS_remomn SOINS_remspe SOINS_rempha
1             2              2         75.5       265.14        84.05
  SOINS_remkin SOINS_reminf SOINS_remden SOINS_remmat SOINS_remtra SOINS_remopt
1       346.02         5.18       127.89            0            0            0
  SOINS_rempro SOINS_remurg SOINS_tmomn SOINS_tmspe SOINS_tmpha SOINS_tmkin
1            0            0        34.5      117.92       90.77      243.63
  SOINS_tminf SOINS_tmden SOINS_tmmat SOINS_tmtra SOINS_tmopt SOINS_tmpro
1        3.78       54.82           0           0           0           0
  SOINS_tmurg SOINS_dpaomn SOINS_dpaspe SOINS_dpapha SOINS_dpakin SOINS_dpainf
1           0            0           34            0            0            0
  SOINS_dpaden SOINS_dpamat SOINS_dpatra SOINS_dpaopt SOINS_dpapro SOINS_dpaurg
1        282.5            0            0            0            0            0
  SOINS_pf_fromn SOINS_pf_frspe SOINS_pf_frpha SOINS_pf_frkin SOINS_pf_frinf
1              5             10             17           19.5            0.5
  SOINS_pf_frden SOINS_pf_frtra SOINS_pf_frurg SOINS_seac_omn SOINS_seac_spe
1              0              0              0              5             10
  OPINION1_renonc_cons OPINION1_renonc_dent OPINION1_renonc_fin
1                    2                    2                   2
  OPINION1_renonc_loin OPINION1_renonc_long QST_ct_depech QST_ct_liberte
1                    2                    2             5              5
  QST_ct_apprend QST_ct_aidecol QST_ct_travnuit QST_ct_repet QST_ct_lourd
1              5              6               5            5            5
  QST_ct_posture QST_ct_produit QES_association QES_tpsami QES_tpsasso
1              5              5               1          1           3
  QES_tpscolleg QES_tpsfamil QES_mere_etude QES_pere_etude id status
1             6            3              5              5  1      1

The estimated SHAP values for this individual (with the random forest):

# Row of the SHAP table for this individual (one column per variable)
treeshap_rf$shaps[id_indiv, ]
  PERSONNE_pb_asthm PERSONNE_pb_bronchit PERSONNE_pb_infarctus
1      -0.007246341          -0.01336903          -0.004644532
  PERSONNE_pb_coronair PERSONNE_pb_hypertens PERSONNE_pb_avc
1         -0.006694261          -0.001778938    -0.003898941
  PERSONNE_pb_arthros PERSONNE_pb_lombalgi PERSONNE_pb_cervical
1        -0.001554188          -0.02055935          -0.05266636
  PERSONNE_pb_diabet PERSONNE_pb_allergi PERSONNE_pb_cirrhos
1       -0.008534858          -0.0259766        -0.005046734
  PERSONNE_pb_urinair PERSONNE_age PERSONNE_sexe PERSONNE_couple  PERSONNE_ss
1         -0.01245911  0.009956432   -0.02699848     0.003119755 0.0003633826
  PERSONNE_regime PERSONNE_rap_pcs8 PERSONNE_ald  SOINS_ald_am
1    -0.003175054       0.006178235 0.0003112731 -0.0002916731
  MENAGE_revucinsee   MENAGE_tu MENAGE_nbpers   ensol_2011 MUTUELLE_assu
1       0.002390928 0.001722672  0.0003568199 -0.002993251   0.000580098
  MUTUELLE_typcc SOINS_remomn SOINS_remspe SOINS_rempha SOINS_remkin
1   0.0005672605 -0.001176905  -0.01859792  0.001755285  -0.02594093
  SOINS_reminf SOINS_remden SOINS_remmat SOINS_remtra SOINS_remopt SOINS_rempro
1 -5.64542e-05 -0.003214886 2.276637e-05   0.00114703 0.0001741754 -0.001419339
  SOINS_remurg SOINS_tmomn SOINS_tmspe SOINS_tmpha SOINS_tmkin  SOINS_tminf
1  0.000990654 -0.01598007 -0.02313433 -0.01579783 -0.03061514 0.0006350854
   SOINS_tmden SOINS_tmmat  SOINS_tmtra   SOINS_tmopt   SOINS_tmpro
1 -0.004696977 0.000677272 0.0005885813 -0.0008278269 -0.0002139524
   SOINS_tmurg SOINS_dpaomn SOINS_dpaspe SOINS_dpapha SOINS_dpakin SOINS_dpainf
1 0.0008722014 0.0008331505  0.001751509 4.664878e-05 7.918558e-05  3.15962e-05
   SOINS_dpaden SOINS_dpamat SOINS_dpatra  SOINS_dpaopt SOINS_dpapro
1 -0.0008823053 0.0004741555 2.530444e-05 -0.0002633437  0.001130082
  SOINS_dpaurg SOINS_pf_fromn SOINS_pf_frspe SOINS_pf_frpha SOINS_pf_frkin
1 9.816743e-06    0.001152235   -0.007856489    -0.01523178    -0.02256475
  SOINS_pf_frinf SOINS_pf_frden SOINS_pf_frtra SOINS_pf_frurg SOINS_seac_omn
1   0.0004703884   4.160579e-05    0.000287363   0.0004107725  -0.0005260417
  SOINS_seac_spe OPINION1_renonc_cons OPINION1_renonc_dent OPINION1_renonc_fin
1    -0.02025151          0.002553635          0.005560955         0.001479719
  OPINION1_renonc_loin OPINION1_renonc_long QST_ct_depech QST_ct_liberte
1          0.001318324          0.002934469   0.005777593    0.002055691
  QST_ct_apprend QST_ct_aidecol QST_ct_travnuit QST_ct_repet QST_ct_lourd
1   -0.001787547   -0.001737267    -0.000171069  0.001085351   0.00145683
  QST_ct_posture QST_ct_produit QES_association QES_tpsami QES_tpsasso
1  -0.0001175104   0.0005289548      0.00180956 0.01116717 0.002723404
  QES_tpscolleg QES_tpsfamil QES_mere_etude QES_pere_etude
1   0.001195995 0.0005406768   -0.004747627   -0.006676365

Each column gives the SHAP value for a variable, for this individual. Each row of treeshap_rf$shaps gives the SHAP values for a single individual.

In the reference data (we used the train set), the average predicted score is:

# The outer parentheses print the value that is being assigned
(mean_pred_rf_ref <- mean(pred_val_rf)) # random forest
[1] 0.7470318
# The outer parentheses print the value that is being assigned
(mean_pred_xgb_ref <- mean(pred_val_xgb)) # XGBoost
[1] 0.4750502

For the first individual, the predicted score, for the random forest, is:

# Random-forest score of the running-example individual
# (a one-row data frame is passed to the scoring function)
pred_indiv_rf <- 
  predict_score_ranger(final_model_rf, df_all_rf[id_indiv, ])
pred_indiv_rf
[1] 0.406

And with XGB:

# Predicted probability of class "Not_D_and_inf_Q1" for the same individual
pred_indiv_xgb <- predict(
  grid_search_xgb,
  df_all_xgb[id_indiv, ],
  type = "prob"
)[, "Not_D_and_inf_Q1"]
pred_indiv_xgb
[1] 0.4140283

For the record, the actual class is:

# Class in the numeric coding used with the random forest
df_all_rf[id_indiv, "status"]
[1] 1
# Class as stored in the data passed to XGBoost
df_all_xgb[id_indiv, "status"]
[1] Not_D_and_inf_Q1
Levels: Not_D_and_inf_Q1 Not_D_and_sup_Q1

The sum of the SHAP values corresponds to the deviation from the average of the scores estimated on the reference data:

# SHAP values of the running-example individual, for each model
shap_indiv_rf <- treeshap_rf$shaps[id_indiv, ] # random forest
shap_indiv_xgb <- treeshap_xgb$shaps[id_indiv, ] # XGBoost

For the random forest:

# Sum of the SHAP values next to the deviation of the individual's score from
# the mean score on the reference data (random forest)
c(
  shap = sum(shap_indiv_rf),
  pred = pred_indiv_rf,
  mean_pred = mean_pred_rf_ref,
  diff = pred_indiv_rf - mean_pred_rf_ref
)
      shap       pred  mean_pred       diff 
-0.3410318  0.4060000  0.7470318 -0.3410318 

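and for XGBoost (here the two quantities differ: the SHAP values of the XGBoost model are most likely expressed on the log-odds (margin) scale of the boosted trees, whereas the predicted score is a probability):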

# Sum of the SHAP values next to the deviation of the individual's score from
# the mean score on the reference data (XGBoost)
c(
  shap = sum(shap_indiv_xgb),
  pred = pred_indiv_xgb,
  mean_pred = mean_pred_xgb_ref,
  diff = pred_indiv_xgb - mean_pred_xgb_ref
)
       shap        pred   mean_pred        diff 
-0.34732477  0.41402832  0.47505017 -0.06102186 

Let us now visualize the SHAP values for a specific individual.

# The code below expects this file to provide `variable_names`, a table with
# (at least) the columns `variable` and `label`.
load("../data/out/variable_names.rda")

We build a table with variable names.

# One row per (factor variable, level) pair. `variable` is the variable name
# pasted with the level, and `label_categ` is the readable label
# "<label> = <level>".
variable_names_categ <-
  left_join(
    enframe(corresp_factors),
    variable_names,
    by = c("name" = "variable")
  ) |>
  unnest(value) |>
  mutate(
    variable = str_c(name, value),
    label_categ = str_c(label, " = ", value)
  ) |>
  select(
    variable,
    label_categ,
    variable_raw = name,
    variable_label_val = value
  )
variable_names_categ
# A tibble: 189 × 4
   variable                       label_categ    variable_raw variable_label_val
   <chr>                          <chr>          <chr>        <chr>             
 1 PERSONNE_pb_asthmNo            Asthma = No    PERSONNE_pb… No                
 2 PERSONNE_pb_asthmYes           Asthma = Yes   PERSONNE_pb… Yes               
 3 PERSONNE_pb_asthmNo answer     Asthma = No a… PERSONNE_pb… No answer         
 4 PERSONNE_pb_bronchitNo         Bronchitis = … PERSONNE_pb… No                
 5 PERSONNE_pb_bronchitYes        Bronchitis = … PERSONNE_pb… Yes               
 6 PERSONNE_pb_bronchitNo answer  Bronchitis = … PERSONNE_pb… No answer         
 7 PERSONNE_pb_infarctusNo        Heart Attack … PERSONNE_pb… No                
 8 PERSONNE_pb_infarctusYes       Heart Attack … PERSONNE_pb… Yes               
 9 PERSONNE_pb_infarctusNo answer Heart Attack … PERSONNE_pb… No answer         
10 PERSONNE_pb_coronairNo         Artery Diseas… PERSONNE_pb… No                
# ℹ 179 more rows

We define a function, plot_individual_variable_effect(), to understand why, according to the estimated SHAP values, the predicted value for a single individual deviates from the average prediction in the reference sample.

#' Plot the decomposition of the deviation from the average prediction
#' for an individual, using SHAP values
#'
#' The `n` variables with the largest absolute SHAP values are shown
#' individually; the contributions of all remaining variables are summed in
#' an "Other Variables" bar. Each bar starts at `mean_pred_ref` and has a
#' length equal to the SHAP value.
#'
#' The function reads two objects from the calling environment:
#' `variable_names` (columns `variable`, `label`) and `variable_names_categ`
#' (columns `variable`, `label_categ`), which provide the labels.
#'
#' @param i row number to designate individual
#' @param treeshap_res result obtained with treeshap
#' @param predicted_val_i predicted value for the individual (only used in
#'  the plot title)
#' @param mean_pred_ref average predicted value in the reference data
#' @param n top n variables (default to 10)
#' @param min_max if not `NULL`, limits for the x axis (Shap values), in a
#'  vector of length 2 containing c(min, max)
#' @return a ggplot object
plot_individual_variable_effect <- function(i,
                                            treeshap_res,
                                            predicted_val_i,
                                            mean_pred_ref,
                                            n = 10,
                                            min_max = NULL) {

  # SHAP values of individual `i`, in long format, with readable labels
  df_plot <-
    treeshap_res$shaps |>
    dplyr::slice(i) |>
    unlist() |>
    enframe(name = "variable", value = "shap") |>
    left_join(variable_names, by = "variable") |>
    left_join(variable_names_categ, by = "variable") |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    # Keep the top n variables, pool the others
    arrange(desc(abs(shap))) |>
    mutate(
      in_top_n = row_number() <= !!n,
      label_2 = ifelse(in_top_n, label, "Other Variables")
    ) |>
    group_by(label_2) |>
    summarise(
      shap = sum(shap),
      .groups = "drop"
    ) |>
    mutate(label_2 = fct_reorder(label_2, abs(shap)))

  # "Other Variables" is moved to the first level (bottom of the flipped
  # axis). The level does not exist when `n` covers every variable;
  # releveling is skipped then, to avoid an "unknown levels" warning.
  if ("Other Variables" %in% levels(df_plot$label_2)) {
    df_plot <- df_plot |>
      mutate(label_2 = fct_relevel(label_2, "Other Variables"))
  }

  df_plot <- df_plot |>
    mutate(
      mean_pred_ref = !!mean_pred_ref,
      shap_rounded = str_c("  ", round(shap, 3)),
      # A SHAP value of exactly 0 gets an `NA` sign
      sign = sign(shap),
      sign = factor(sign, levels = c(-1, 1), labels = c("-", "+"))
    )

  p <- ggplot(
    data = df_plot,
    mapping = aes(
      # Text anchor: right end of positive bars, base value otherwise
      y = mean_pred_ref + pmax(shap, 0),
      x = label_2,
      ymin = mean_pred_ref,
      ymax = mean_pred_ref + shap,
      colour = sign
    )
  ) +
    geom_linerange(linewidth = 8) +
    geom_hline(yintercept = mean_pred_ref) +
    # `hjust` is a constant, so it is set outside `aes()`
    geom_text(aes(label = shap_rounded), hjust = 0) +
    coord_flip() +
    scale_colour_manual(
      "Variable impact",
      values = c("-" = "#EE324E", "+" = "#009D57"),
      labels = c("-" = "Negative", "+" = "Positive")
    ) +
    labs(
      x = NULL,
      y = "Shap values",
      title = str_c("Predicted value: ", round(predicted_val_i, 3)),
      subtitle = str_c("Base value: ", round(mean_pred_ref, 3))
    ) +
    theme(
      panel.grid.major.y = element_blank(), axis.ticks.y = element_blank(),
      legend.position = "bottom",
      plot.title.position = "plot",
      plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold")
    )

  if (!is.null(min_max)) {
    p <- p + scale_y_continuous(limits = min_max)
  }
  p
}

Let us look at the first individual:

i <- 1
# SHAP values with the opposite sign, i.e. the contributions to 1 - score
treeshap_rf_opposite <- treeshap_rf
treeshap_rf_opposite$shaps <- -treeshap_rf_opposite$shaps
# The predicted value and the base value are given as 1 - score, so the
# sign-flipped SHAP values have to be plotted: only then does the base value
# plus the sum of the bars equal the displayed predicted value.
plot_individual_variable_effect(
  i = i,
  treeshap_res = treeshap_rf_opposite,
  predicted_val_i = 1 - pred_val_rf[i],
  mean_pred_ref = 1 - mean_pred_rf_ref, n = 20
)
Figure 20.1: SHAP values for the first individual, when the score is estimated with a Random Forest. Positive class: Imaginary healthy patient.
# Same plot with the XGBoost SHAP values, the individual's predicted score
# and the mean predicted score on the reference data; top 20 variables
plot_individual_variable_effect(
  i = i, 
  treeshap_res = treeshap_xgb, 
  predicted_val_i = pred_val_xgb[i], 
  mean_pred_ref = mean_pred_xgb_ref, n = 20
)
Figure 20.2: SHAP values for the first individual, when the score is estimated with XGBoost. Positive class: Imaginary healthy patient.

Now, let us pick two individuals: one with a low predicted score, and another one with a high predicted score. We define two functions: get_top_n_indiv(), to get the top n variables explaining the deviation of the score of an individual from the average predicted score in the reference set, and get_table_charact() to extract the characteristics of specific individuals given a set of characteristics.

#' Names of the n variables with the largest absolute SHAP values
#' for one individual
#'
#' @param i row number of individual
#' @param treeshap_res result obtained with treeshap (its `shaps` element
#'  is used)
#' @param n top n variables (default to 10)
#' @return vector of variable names, by decreasing absolute SHAP value
get_top_n_indiv <- function(i, treeshap_res, n = 10) {
  shap_row <- unlist(dplyr::slice(treeshap_res$shaps, i))
  ranked <- arrange(
    enframe(shap_row, name = "variable", value = "shap"),
    desc(abs(shap))
  )
  ranked[["variable"]][which(seq_len(nrow(ranked)) <= n)]
}



#' Get the characteristics of the union of top n variables among individuals
#' (using absolute SHAP values)
#' for xgboost only
#'
#' Besides its arguments, the function reads `formula`, `df_all_xgb`,
#' `variable_names` and `variable_names_categ` from the calling environment
#' and calls `get_top_n_indiv()`.
#'
#' @param i row numbers of individuals (rows of `df_all_xgb`)
#' @param treeshap_res result obtained with treeshap
#' @param n top n variables (default to 10)
#' @param reference_data reference dataset on which the average value of each
#'  variable is computed (defaults to `reference_data_xgb`)
#' @return a tibble with one row per variable: its `label`, its average in
#'  the reference data (`val_ref`) and, for each individual, the columns
#'  `value_indiv_<i>` (value of the variable) and `in_top_n_<i>` (check mark
#'  if the variable is in the top n of that individual)
get_table_charact <- function(i,
                              treeshap_res,
                              n = 10,
                              reference_data = reference_data_xgb) {
  # Top n variables for individuals
  top_n_indiv <- map(
    i,
    \(id) get_top_n_indiv(i = id, treeshap_res = treeshap_res, n = n)
  )
  top_n_indiv_variables <- list_c(top_n_indiv) |> unique()

  # Average values for those variables in the reference sample
  values_reference <-
    model.matrix(formula, reference_data) |>
    as_tibble() |>
    select(all_of(top_n_indiv_variables)) |>
    summarise(across(everything(), mean)) |>
    pivot_longer(
      cols = everything(), names_to = "variable", values_to = "val_ref"
    )

  # Design matrix of all observations, restricted to the variables of
  # interest. It does not depend on the individual, so it is built once.
  # NOTE(review): `slice(id)` below assumes model.matrix() keeps one row per
  # row of `df_all_xgb` (no row dropped because of missing values) - confirm.
  design_all <-
    model.matrix(formula, df_all_xgb) |>
    as_tibble() |>
    select(all_of(top_n_indiv_variables))

  # Characteristics for these individuals
  top_n_indiv_characteristics <- map2(
    .x = i, .y = top_n_indiv,
    \(id, top_vars) {
      values_indiv <- design_all |>
        dplyr::slice(id) |>
        pivot_longer(
          cols = everything(),
          names_to = "variable", values_to = "value_indiv"
        )
      tibble(variable = top_n_indiv_variables, id_indiv = id) |>
        left_join(values_indiv, by = "variable") |>
        mutate(
          in_top_n = ifelse(
            variable %in% top_vars, yes = "$\\checkmark$", no = ""
          )
        )
    }
  ) |>
    list_rbind()

  top_n_indiv_characteristics |>
    left_join(variable_names, by = "variable") |>
    left_join(variable_names_categ, by = "variable") |>
    left_join(values_reference, by = "variable") |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    select(label, id_indiv, value_indiv, in_top_n, val_ref) |>
    pivot_wider(names_from = id_indiv, values_from = c(value_indiv, in_top_n))
}

Let us identify two individuals:

# Row numbers of the two individuals and their XGBoost scores (rounded)
i_low <- 111
i_high <- 125
round(c(pred_val_xgb[i_low], pred_val_xgb[i_high]), 3)
[1] 0.176 0.793
Code
# Decomposition for the individual with the low score (top 10 variables).
# `min_max` is left commented out, so no axis limits are imposed.
plot_indiv_low_pred <- 
  plot_individual_variable_effect(
  i = i_low, 
  treeshap_res = treeshap_xgb, 
  predicted_val_i = pred_val_xgb[i_low], 
  mean_pred_ref = mean_pred_xgb_ref, n = 10
  # min_max = c(0.2, .55)
)
plot_indiv_low_pred
Figure 20.3: SHAP values for an individual with a low predicted score, when the latter is estimated with XGBoost. Positive class: Imaginary healthy patient.
Code
# Decomposition for the individual with the high score (top 10 variables).
# `min_max` is left commented out, so no axis limits are imposed.
plot_indiv_high_pred <- 
  plot_individual_variable_effect(
  i = i_high, 
  treeshap_res = treeshap_xgb, 
  predicted_val_i = pred_val_xgb[i_high], 
  mean_pred_ref = mean_pred_xgb_ref, n = 10
  # min_max = c(0.4, .95)
)
plot_indiv_high_pred
Figure 20.4: SHAP values for an individual with a high predicted score, when the latter is estimated with XGBoost. Positive class: Imaginary healthy patient.

We retrieve the characteristics of the identified most important variables for both individuals, as well as the average values in the reference sample (train set):

# Values of the union of both individuals' top 10 variables, together with
# the averages in the reference sample
tb_two_indivs_example <- get_table_charact(
    i = c(i_low, i_high), 
    treeshap_res = treeshap_xgb, 
    n = 10
)

And then, we display them in a nice table:

Code
# Column names are derived from `i_low` and `i_high`, so the table follows the
# individuals selected above instead of hard-coding their row numbers.
tb_two_indivs_example |>
  select(
    label,
    all_of(c(
      str_c(c("value_indiv_", "in_top_n_"), i_low),
      str_c(c("value_indiv_", "in_top_n_"), i_high)
    )),
    val_ref
  ) |>
  mutate(val_ref = round(val_ref, 2)) |>
  kableExtra::kable(booktabs = TRUE) |>
  kableExtra::kable_classic(full_width = FALSE) |>
  kableExtra::kable_styling() |>
  unclass() |> cat()
Table 20.1: Characteristics for the Most Important Variables for Two Individuals According to their SHAP Values.
label value_indiv_111 in_top_n_111 value_indiv_125 in_top_n_125 val_ref
Net Income per Cons. Unit 3071.11 \(\checkmark\) 606.92 \(\checkmark\) 1622.82
Have to Hurry to Do Job = Sometimes 1.00 \(\checkmark\) 0.00 0.20
Gender = Male 1.00 \(\checkmark\) 0.00 \(\checkmark\) 0.47
Waiver Dental Care = No 1.00 \(\checkmark\) 0.00 \(\checkmark\) 0.66
Frequency Meeting with People in Organizations = Never 0.00 \(\checkmark\) 1.00 0.50
Very Little Freedom to Do Job = Never 1.00 \(\checkmark\) 0.00 0.21
Deduct. Pharmacy 0.50 \(\checkmark\) 9.50 12.64
Neck Pain = Yes 0.00 \(\checkmark\) 0.00 0.14
Low Back Pain = Yes 0.00 \(\checkmark\) 0.00 0.19
Waiver Appointment Delay Too Long = No 1.00 \(\checkmark\) 0.00 0.67
Age 54.00 47.00 \(\checkmark\) 47.59
No. Medical Sessions General Pract. 2.00 18.00 \(\checkmark\) 4.55
Waiver General Practitioner = No answer 0.00 0.00 \(\checkmark\) 0.21
Reimbursement Pharmacy 0.30 652.07 \(\checkmark\) 273.31
Reimbursement General Practitioner 30.20 337.10 \(\checkmark\) 83.16
Co-payment Pharmacy 1.85 251.22 \(\checkmark\) 91.03
Occupation = Administrative employee 0.00 1.00 \(\checkmark\) 0.14

20.4 Variable importance

We compute the average absolute SHAP value for each variable and order the results by descending values. We filter the results to focus on the predictions of being classified as an imaginary healthy patient.

# Mean(| SHAP |): one-row table, one column per variable
var_imp_rf_mean_abs_shap <- summarise(
  treeshap_rf_opposite$shaps,
  across(everything(), \(x) mean(abs(x)))
)

# Mean(SHAP): same layout, with the signed values
var_imp_rf_mean_shap <- summarise(
  treeshap_rf_opposite$shaps,
  across(everything(), \(x) mean(x))
)

We order the variables by descending values of the mean absolute SHAP values.

# One row per variable, by decreasing mean absolute SHAP value (`value`),
# with the mean signed SHAP value (`mean`) joined on the variable name.
# The row order comes from the left table only.
order_variables_shap_rf <- left_join(
  enframe(sort(unlist(var_imp_rf_mean_abs_shap), decreasing = TRUE)),
  enframe(unlist(var_imp_rf_mean_shap), value = "mean"),
  by = "name"
)

The top 10 variables:

# head() instead of `[1:10]`: no `NA` padding if fewer than 10 variables exist
top_n_variables_rf <- head(order_variables_shap_rf$name, 10)
top_n_variables_rf
 [1] "MENAGE_revucinsee"    "PERSONNE_sexe"        "QST_ct_depech"       
 [4] "OPINION1_renonc_dent" "SOINS_remomn"         "SOINS_pf_frpha"      
 [7] "PERSONNE_age"         "SOINS_tmpha"          "OPINION1_renonc_long"
[10] "PERSONNE_pb_cervical"
# Mean(| SHAP |): one-row table, one column per variable
var_imp_xgb_mean_abs_shap <- summarise(
  treeshap_xgb$shaps,
  across(everything(), \(x) mean(abs(x)))
)

# Mean(SHAP): same layout, with the signed values
var_imp_xgb_mean_shap <- summarise(
  treeshap_xgb$shaps,
  across(everything(), \(x) mean(x))
)

We order the variables by descending values of the mean absolute SHAP values.

# One row per variable, by decreasing mean absolute SHAP value (`value`),
# with the mean signed SHAP value (`mean`) joined on the variable name.
# The row order comes from the left table only.
order_variables_shap_xgb <- left_join(
  enframe(sort(unlist(var_imp_xgb_mean_abs_shap), decreasing = TRUE)),
  enframe(unlist(var_imp_xgb_mean_shap), value = "mean"),
  by = "name"
)
# The ranking is written to disk
save(order_variables_shap_xgb, file = "../data/out/order_variables_shap_xgb_sah.rda")

The top 10 variables:

# head() instead of `[1:10]`: no `NA` padding if fewer than 10 variables exist
top_n_variables_xgb <- head(order_variables_shap_xgb$name, 10)
top_n_variables_xgb
 [1] "PERSONNE_age"            "PERSONNE_sexeMale"      
 [3] "MENAGE_revucinsee"       "OPINION1_renonc_dentNo" 
 [5] "PERSONNE_pb_cervicalYes" "QES_tpsassoNever"       
 [7] "SOINS_pf_frpha"          "PERSONNE_pb_lombalgiYes"
 [9] "QST_ct_depechSometimes"  "PERSONNE_pb_allergiYes" 

Let us plot the top 10 variables by importance according to the SHAP values.

Code
# Bar chart of the mean absolute SHAP values (random forest): one bar per
# top-10 variable, plus one bar summing all the other variables.
p_variable_importance_rf <-
  order_variables_shap_rf |>
  left_join(variable_names, by = c("name" = "variable")) |>
  mutate(
    label = ifelse(
      name %in% .env$top_n_variables_rf, label, "Other Variables"
    )
  ) |>
  group_by(label) |>
  summarise(value = sum(value), .groups = "drop") |>
  mutate(
    # Bars by increasing importance, "Other Variables" as the first level
    label = fct_reorder(label, value),
    label = fct_relevel(label, "Other Variables")
  ) |>
  ggplot(mapping = aes(x = label, y = value)) +
  geom_col() +
  coord_flip() +
  labs(y = "Mean(|SHAP Value|)", x = NULL)

p_variable_importance_rf
Figure 20.5: Variable importance according to SHAP values (with the random forest classifier)
Code
# Bar chart of the mean absolute SHAP values (XGBoost): one bar per top-10
# variable, plus one bar summing all the other variables. Dummy variables
# take their label from `variable_names_categ`.
p_variable_importance_xgb <-
  order_variables_shap_xgb |>
  left_join(variable_names, by = c("name" = "variable")) |>
  left_join(variable_names_categ, by = c("name" = "variable")) |>
  mutate(
    label = ifelse(is.na(label), label_categ, label),
    label = ifelse(
      name %in% .env$top_n_variables_xgb, label, "Other Variables"
    )
  ) |>
  group_by(label) |>
  summarise(value = sum(value), .groups = "drop") |>
  mutate(
    # Bars by increasing importance, "Other Variables" as the first level
    label = fct_reorder(label, value),
    label = fct_relevel(label, "Other Variables")
  ) |>
  ggplot(mapping = aes(x = label, y = value)) +
  geom_col() +
  coord_flip() +
  labs(y = "Mean(|SHAP Value|)", x = NULL)

p_variable_importance_xgb
Figure 20.6: Variable importance according to SHAP values (with the XGBoost classifier)

The complete ranking, for both models:

# Both rankings side by side. The two tables are joined on their only common
# column (`name`); every other column is renamed with a model suffix first.
full_join(
  order_variables_shap_xgb |>
    rename(mean_abs_shap_xgb = value, mean_shap_xgb = mean) |>
    mutate(
      across(c(mean_abs_shap_xgb, mean_shap_xgb), \(x) round(x, 4)),
      rank_xgb = row_number()
    ),
  order_variables_shap_rf |>
    rename(mean_abs_shap_rf = value, mean_shap_rf = mean) |>
    mutate(
      across(c(mean_abs_shap_rf, mean_shap_rf), \(x) round(x, 4)),
      rank_rf = row_number()
    )
) |>
  DT::datatable()
Joining with `by = join_by(name)`

In Python, the summary_plot() function is used to offer a quick view of the effect of each variable on the prediction. We can reproduce such a plot. For quantitative variables, this type of graph seems advantageous. For categorical variables, on the other hand, we prefer a different type of graph (see below).

Let us define a function to rescale values (min-max rescaling).

#' std1
#' a function to standardize feature values into same range
#' (min-max rescaling onto [0, 1])
#' Source: https://github.com/pablo14/shap-values/blob/master/shap.R
#'
#' @param x numeric vector; missing values are ignored when the minimum and
#'  the maximum are computed, and remain `NA` in the result
#' @return numeric vector of the same length as `x`; `NaN` for the non-missing
#'  entries when all of them are equal (zero range)
std1 <- function(x) {
  # `TRUE` is spelled out: `T` is an ordinary variable that can be reassigned
  x_range <- range(x, na.rm = TRUE)
  (x - x_range[1]) / (x_range[2] - x_range[1])
}

Then, we need to create a table that contains for each individual and each variable, the Shap value, and the standardized value of the characteristics.

# Minimum, maximum, median and quartiles of each top-10 variable over the
# train and test observations (random forest: numeric-coded data).
values_min_max_rf <-
  df_all_rf |>
  select(all_of(top_n_variables_rf)) |>
  pivot_longer(cols = everything(), names_to = "name", values_to = "value") |>
  group_by(name) |>
  summarise(
    min_value = min(value),
    max_value = max(value),
    med_value = median(value),
    q1_value = quantile(value, probs = .25),
    q3_value = quantile(value, probs = .75)
  )
# Same statistics for XGBoost, computed on the columns of the design matrix
values_min_max_xgb <-
  model.matrix(formula, df_all_xgb) |>
  as_tibble() |>
  select(all_of(top_n_variables_xgb)) |>
  pivot_longer(cols = everything(), names_to = "name", values_to = "value") |>
  group_by(name) |>
  summarise(
    min_value = min(value),
    max_value = max(value),
    med_value = median(value),
    q1_value = quantile(value, probs = .25),
    q3_value = quantile(value, probs = .75)
  )

We format the data to be able to create the plot.

# Long table for the random-forest summary plot: one row per
# (individual, top-10 variable) pair, with the sign-flipped SHAP value and the
# value of the variable for that individual.
df_plot_rf <- 
  treeshap_rf_opposite$shaps |> 
  select(!!top_n_variables_rf) |> 
  # `id_row` is the row position. It is the key used to match the SHAP table
  # with `df_all_rf`, so both tables must list the individuals in the same order.
  mutate(id_row = row_number()) |> 
  pivot_longer(cols = -id_row, values_to = "shap") |> 
  left_join(
    df_all_rf |> 
      select(!!top_n_variables_rf) |> 
      mutate(id_row = row_number()) |> 
      pivot_longer(cols = -id_row, values_to = "value"),
    by = c("id_row", "name")
  ) |> 
  # Per-variable summaries of the SHAP values
  group_by(name) |> 
  mutate(
    std_shap_value = std1(shap),
    mean_shap_value = mean(shap),
    mean_abs_shap_value = mean(abs(shap))
  ) |> 
  ungroup() |> 
  left_join(
    values_min_max_rf,
    by = "name"
  ) |> 
  # Min-max rescaling of the variable's value (used for the colour scale)
  mutate(
    value_variable_std = (value - min_value) / (max_value - min_value)
  ) |> 
  left_join(variable_names, by = c("name" = "variable")) |> 
  # Labels ordered by mean absolute SHAP value
  mutate(
    label = fct_reorder(label, mean_abs_shap_value)
  )
# Long table for the XGBoost summary plot: one row per
# (individual, top-10 variable) pair, with the SHAP value and the value of
# the corresponding design-matrix column for that individual.
df_plot_xgb <- 
  treeshap_xgb$shaps |> 
  select(!!top_n_variables_xgb) |> 
  # `id_row` is the row position and is the key used to match both tables.
  # NOTE(review): this assumes model.matrix() returns one row per row of
  # `df_all_xgb`, in the same order as the SHAP table - confirm.
  mutate(id_row = row_number()) |> 
  pivot_longer(cols = -id_row, values_to = "shap") |> 
  left_join(
    as_tibble(model.matrix(formula, df_all_xgb)) |> 
      select(!!top_n_variables_xgb) |> 
      mutate(id_row = row_number()) |> 
      pivot_longer(cols = -id_row, values_to = "value"),
    by = c("id_row", "name")
  ) |> 
  # Per-variable summaries of the SHAP values
  group_by(name) |> 
  mutate(
    std_shap_value = std1(shap),
    mean_shap_value = mean(shap),
    mean_abs_shap_value = mean(abs(shap))
  ) |> 
  ungroup() |> 
  left_join(
    values_min_max_xgb,
    by = "name"
  ) |> 
  # Min-max rescaling of the variable's value (used for the colour scale)
  mutate(
    value_variable_std = (value - min_value) / (max_value - min_value)
  ) |> 
  left_join(variable_names, by = c("name" = "variable")) |> 
  left_join(variable_names_categ, by = c("name" = "variable")) |> 
  # Dummy variables have no entry in `variable_names`: use the level label.
  # Labels are then ordered by mean absolute SHAP value.
  mutate(
    label = ifelse(is.na(label), label_categ, label),
    label = fct_reorder(label, mean_abs_shap_value)
  )

Then, using the geom_sina() function from {ggforce}, we can create the summary plot. Each dot, for each variable given on the y-axis, represents an individual. The x-axis gives the estimated SHAP value of the variable for each individual. The colours state whether the value of the variable for a specific individual is low (blue) or high (red), using the standardized value as the reference. The variables on the y-axis appear according to their relative importance in explaining the prediction. For example, the graph shows that pharmacy deductible amounts are relatively important in explaining the probability of being classified as an imaginary healthy patient. When the value of this deductible is low (blue dots), this variable influences the prediction downwards (since the blue dots are mostly associated with negative SHAP values); conversely, when the pharmacy deductible is high (red dots), the probability of being classified as an imaginary healthy patient increases.

Code
library(ggforce)
# Summary plot (random forest): one dot per individual and variable, placed at
# its SHAP value and coloured by the rescaled value of the variable.
p_rf <- ggplot(
  data = df_plot_rf,
  mapping = aes(
    y = shap,
    x = label,
    colour = value_variable_std
  )
) +
  geom_sina(alpha = .5) +
  scale_color_gradient(
    "Feature value",
    low = "#0081BC", high = "#EE324E",
    breaks = c(0, 1), labels = c("Low", "High"),
    # `midpoint` was dropped: it is not an argument of guide_colourbar(), so
    # it had no effect on the legend (recent ggplot2 releases may report it
    # as an unknown argument).
    guide = guide_colourbar(
      direction = "horizontal",
      barwidth = 40,
      title.position = "bottom",
      title.hjust = .5
    )
  ) +
  labs(y = "SHAP value (impact on model output)", x = NULL) +
  theme(
    legend.position = "bottom",
    legend.key.height   = unit(1, "line"),
    legend.key.width    = unit(1.5, "line")
  ) +
  coord_flip() +
  geom_hline(yintercept = 0, linetype = "dashed")
p_rf
Figure 20.7: Estimated effect of variables on the probability of being predicted an imaginary healthy patient, for each individual (random forest)
Code
# Summary plot of the SHAP values (XGBoost): one sina strip per variable,
# dots coloured by the standardized value of the variable.
p_xgb <- ggplot(
  data = df_plot_xgb,
  mapping = aes(
    y = shap,
    x = label,
    colour = value_variable_std
  )
) +
  geom_sina(alpha = .5) +
  scale_color_gradient(
    "Feature value",
    low = "#0081BC", high = "#EE324E",
    breaks = c(0, 1), labels = c("Low", "High"),
    # `midpoint` is not an argument of guide_colourbar() (it belongs to
    # scale_*_gradient2()) and was silently ignored, so it is not passed here.
    guide = guide_colourbar(
      direction = "horizontal",
      barwidth = 40,
      title.position = "bottom",
      title.hjust = .5
    )
  ) +
  labs(y = "SHAP value (impact on model output)", x = NULL) +
  theme(
    legend.position = "bottom",
    legend.key.height = unit(1, "line"),
    legend.key.width = unit(1.5, "line")
  ) +
  # Variables on the vertical axis, SHAP values on the horizontal axis
  coord_flip() +
  geom_hline(yintercept = 0, linetype = "dashed")
p_xgb
Figure 20.8: Estimated effect of variables on the probability of being predicted an imaginary healthy patient, for each individual (XGBoost)

Now, instead of a summary plot, let us create a plot for each of the variables in the top n. To that end, we create a function which plots the estimated SHAP values of a single variable for each individual. When the variable is categorical, we want the graph to be split according to the categories.

For continuous variables, we may want to make the colours of the dots depend on the quantiles. This is the case for net income, because of the shape of its distribution:

hist(df_all_rf$MENAGE_revucinsee, breaks = 100)

summary(df_all_rf$MENAGE_revucinsee)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
     74    1000    1400    1620    2000   25300 

The code of the plot function is a bit long.

Code
#' Plots the SHAP values of each individual, for a single variable
#' (random forest version)
#'
#' Categorical variables (those whose name appears in the global list
#' `corresp_factors`) are drawn with one sina/violin per level. Other
#' variables are drawn as a single sina/violin whose dots are coloured by the
#' value of the variable, or by its decile when `quantile = TRUE`.
#'
#' @param variable_name name of the variable (character string)
#' @param treeshap_res treeshap results (only the element `shaps` is used)
#' @param df_all data frame with observations from train and test sets
#'   (its rows are assumed to be aligned with `treeshap_res$shaps` -- to be
#'   confirmed by the caller)
#' @param quantile if TRUE, uses the empirical deciles as the colour code
#'   (default to FALSE)
#' @param size_dots size of the dots
#' @param bar_width width of the colour bar of the legend (only used for
#'   non-categorical variables)
#' @return a ggplot object
plot_shap_all_indiv_variable_rf <- function(variable_name,
                                            treeshap_res,
                                            df_all,
                                            quantile = FALSE,
                                            size_dots = .6,
                                            bar_width = 30) {
  # Extracted before the pipeline: inside mutate(), `variable_name` would
  # refer to the column of the same name created below.
  observed_values <- df_all[[variable_name]]

  df_plot <-
    treeshap_res$shaps |>
    select(all_of(variable_name)) |>
    rename(SHAP_value = all_of(variable_name)) |>
    mutate(variable_name = variable_name) |>
    as_tibble() |>
    # Add value of the variable from the dataset
    mutate(variable_value = observed_values)

  is_factor <- variable_name %in% names(corresp_factors)

  if (is_factor) {
    # Map the numerical codes back to the levels of the factor
    level_names <- corresp_factors[[variable_name]]
    df_plot <-
      df_plot |>
      left_join(
        tibble(
          variable_value = seq_along(level_names),
          variable_value_level = level_names
        ),
        by = "variable_value"
      )

    if ("No answer" %in% level_names) {
      display_names <- replace(
        level_names, level_names == "No answer", "Did not answer"
      )

      # Known ordinal scales are displayed in a meaningful order
      display_order <- if ("At least once a month" %in% level_names) {
        c(
          "Every day or almost every day",
          "At least once a week",
          "At least once a month",
          "Less than once a month", "Never", "Did not answer"
        )
      } else if ("Sometimes" %in% level_names) {
        c("Always", "Often", "Sometimes", "Never", "Did not answer")
      } else {
        NULL
      }

      if (!is.null(display_order)) {
        # Descending order: after coord_flip(), the first category of the
        # scale ends up at the top of the graph.
        ranked_levels <-
          tibble(
            raw = level_names,
            shown = factor(display_names, levels = display_order)
          ) |>
          arrange(desc(shown))
        level_names <- ranked_levels$raw
        display_names <- as.character(ranked_levels$shown)
      }

      # Replace the raw level names with the names to display
      df_plot <-
        df_plot |>
        left_join(
          tibble(
            variable_value_level = level_names,
            variable_value_level_new = display_names
          ),
          by = "variable_value_level"
        ) |>
        select(-variable_value_level) |>
        rename(variable_value_level = variable_value_level_new)

      level_names <- display_names
    }

    df_plot <- df_plot |>
      mutate(
        variable_value_level = factor(variable_value_level, levels = level_names)
      )

    p <-
      ggplot(
        data = df_plot,
        mapping = aes(y = SHAP_value, x = variable_value_level)
      ) +
      geom_sina(size = size_dots) +
      geom_violin(alpha = 0) +
      coord_flip() +
      labs(y = "SHAP value (impact on model output)", x = NULL)
  } else {
    colour_bar <- guide_colourbar(
      direction = "horizontal",
      barwidth = bar_width,
      title.position = "bottom",
      title.hjust = .5
    )

    if (quantile) {
      decile_labels <- c(
        "$\\leq$ D1", "D1 to D2", "D2 to D3", "D3 to D4",
        "D4 to D5", "D5 to D6", "D6 to D7", "D7 to D8",
        "D8 to D9", "> D9"
      )
      deciles <- unname(
        stats::quantile(df_all[[variable_name]], probs = seq(.1, .9, .1))
      )
      # Decile bin of each observation: 1 for values <= D1, ..., 10 for
      # values > D9. Unlike cut(), findInterval() accepts tied deciles and
      # does not need an arbitrary lower bound (the former lower bound of 0
      # turned values <= 0 into NA).
      decile_id <- findInterval(
        df_plot$variable_value, deciles, left.open = TRUE
      ) + 1L
      df_plot$quantile_variable <- factor(
        decile_labels[decile_id], levels = decile_labels
      )
      df_plot$quantile_revenu_num <- as.numeric(decile_id)

      mapping <- aes(y = SHAP_value, x = x, colour = quantile_revenu_num)
      colour_scale <- scale_color_gradient(
        "Variable value",
        breaks = c(1, 4, 7, 10),
        labels = c("$\\leq$ D1", "D3 to D4", "D6 to D7", "$>$ D9"),
        # Fixed limits: a colour always refers to the same decile
        limits = c(1, 10),
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    } else {
      # Not using quantile
      mapping <- aes(x = x, y = SHAP_value, colour = variable_value)
      colour_scale <- scale_color_gradient(
        "Variable value",
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    }

    p <-
      # All observations are drawn in a single strip (x = 1)
      ggplot(data = df_plot |> mutate(x = 1), mapping = mapping) +
      geom_sina(alpha = .5, size = size_dots) +
      geom_violin(alpha = 0) +
      colour_scale +
      labs(y = "SHAP value (impact on model output)", x = NULL) +
      theme(
        legend.position = "bottom",
        legend.key.height = unit(1, "line"),
        legend.key.width = unit(1.5, "line"),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()
      ) +
      coord_flip()
  }

  p + theme(
    plot.title.position = "plot",
    plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold"),
    legend.text = element_text(size = rel(.8)),
    legend.title = element_text(size = rel(.8))
  )
}
Net Income per Cons. Unit

Gender

Have to Hurry to Do Job

Waiver Dental Care

Reimbursement General Practitioner

Deduct. Pharmacy

Age

Co-payment Pharmacy

Waiver Appointment Delay Too Long

Neck Pain

Code
#' Plots the SHAP values of each individual, for a single variable
#' (XGBoost version)
#'
#' Variables listed in the global table `variable_names_categ` (column
#' `variable`) are drawn with a two-colour legend labelled 0 / 1. Other
#' variables are coloured by the value of the variable, or by its decile when
#' `quantile = TRUE`.
#'
#' @param variable_name name of the variable (character string)
#' @param treeshap_res treeshap results (only the element `shaps` is used)
#' @param df_all data frame with observations from train and test sets
#'   (its rows are assumed to be aligned with `treeshap_res$shaps`, and its
#'   columns presumably follow the model matrix naming -- to be confirmed by
#'   the caller)
#' @param quantile if TRUE, uses the empirical deciles as the colour code
#'   (default to FALSE)
#' @param size_dots size of the dots
#' @param bar_width width of the colour bar of the legend (only used for
#'   variables not listed in `variable_names_categ`)
#' @return a ggplot object
plot_shap_all_indiv_variable_xgb <- function(variable_name,
                                             treeshap_res,
                                             df_all,
                                             quantile = FALSE,
                                             size_dots = .6,
                                             bar_width = 30) {
  # Extracted before the pipeline: inside mutate(), `variable_name` would
  # refer to the column of the same name created below.
  observed_values <- df_all[[variable_name]]

  df_plot <-
    treeshap_res$shaps |>
    select(all_of(variable_name)) |>
    rename(SHAP_value = all_of(variable_name)) |>
    mutate(variable_name = variable_name) |>
    as_tibble() |>
    # Add value of the variable from the dataset
    mutate(variable_value = observed_values)

  is_factor <- variable_name %in% variable_names_categ$variable

  # Theme shared by all the variants of the graph
  theme_single_strip <- theme(
    legend.position = "bottom",
    legend.key.height = unit(1, "line"),
    legend.key.width = unit(1.5, "line"),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  )

  if (is_factor) {
    # Variable listed as categorical: two colours, legend keys labelled 0/1
    p <-
      ggplot(
        data = df_plot |> mutate(x = 1),
        mapping = aes(x = x, y = SHAP_value, colour = variable_value)
      ) +
      geom_sina(alpha = .5, size = size_dots) +
      geom_violin(alpha = 0) +
      scale_color_gradient(
        "Variable value",
        low = "#0081BC", high = "#EE324E", n.breaks = 2
      ) +
      labs(y = "SHAP value (impact on model output)", x = NULL) +
      theme_single_strip +
      coord_flip() +
      guides(color = guide_legend(override.aes = list(
        color = c("#0081BC", "#EE324E"),
        label = c("0", "1")
      )))
  } else {
    colour_bar <- guide_colourbar(
      direction = "horizontal",
      barwidth = bar_width,
      title.position = "bottom",
      title.hjust = .5
    )

    if (quantile) {
      decile_labels <- c(
        "$\\leq$ D1", "D1 to D2", "D2 to D3", "D3 to D4",
        "D4 to D5", "D5 to D6", "D6 to D7", "D7 to D8",
        "D8 to D9", "> D9"
      )
      deciles <- unname(
        stats::quantile(df_all[[variable_name]], probs = seq(.1, .9, .1))
      )
      # Decile bin of each observation: 1 for values <= D1, ..., 10 for
      # values > D9. findInterval() accepts tied deciles, so the bins keep
      # their decile number (dropping duplicated breaks shifted the codes
      # while the legend breaks 1/4/7/10 stayed fixed), and values <= 0 are
      # no longer turned into NA by an arbitrary lower bound of 0.
      decile_id <- findInterval(
        df_plot$variable_value, deciles, left.open = TRUE
      ) + 1L
      df_plot$quantile_variable <- factor(
        decile_labels[decile_id], levels = decile_labels
      )
      df_plot$quantile_revenu_num <- as.numeric(decile_id)

      mapping <- aes(y = SHAP_value, x = x, colour = quantile_revenu_num)
      colour_scale <- scale_color_gradient(
        "Variable value",
        breaks = c(1, 4, 7, 10),
        labels = c("$\\leq$ D1", "D3 to D4", "D6 to D7", "$>$ D9"),
        # Fixed limits: a colour always refers to the same decile
        limits = c(1, 10),
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    } else {
      # Not using quantile
      mapping <- aes(x = x, y = SHAP_value, colour = variable_value)
      colour_scale <- scale_color_gradient(
        "Variable value",
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    }

    p <-
      # All observations are drawn in a single strip (x = 1)
      ggplot(data = df_plot |> mutate(x = 1), mapping = mapping) +
      geom_sina(alpha = .5, size = size_dots) +
      geom_violin(alpha = 0) +
      colour_scale +
      labs(y = "SHAP value (impact on model output)", x = NULL) +
      theme_single_strip +
      coord_flip()
  }

  p + theme(
    plot.title.position = "plot",
    plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold"),
    legend.text = element_text(size = rel(.8)),
    legend.title = element_text(size = rel(.8))
  )
}
Age

Gender = Male

Net Income per Cons. Unit

Waiver Dental Care = No

Neck Pain = Yes

Frequency Meeting with People in Organizations = Never

Deduct. Pharmacy

Low Back Pain = Yes

Have to Hurry to Do Job = Sometimes

Allergy = Yes

20.5 Clustering

It may be possible to group individuals in the dataset depending on the profile of their SHAP values. Let us perform a hierarchical clustering to group the individuals according to their Shapley values. We will only keep a few variables to represent each individual, and will focus only on people classified as imaginary healthy patients. First, let us compute the average Shapley value for each variable, among the predicted imaginary healthy patients.

20.5.1 Note

We only focus on the SHAP values computed based on the predictions made with XGBoost in this part.

Recall we put in a tibble (order_variables_shap_xgb) the average absolute SHAP values (column value) and the average SHAP values (column mean):

order_variables_shap_xgb
# A tibble: 188 × 3
   name                     value     mean
   <chr>                    <dbl>    <dbl>
 1 PERSONNE_age            0.152  -0.0146 
 2 PERSONNE_sexeMale       0.129  -0.00171
 3 MENAGE_revucinsee       0.119  -0.00860
 4 OPINION1_renonc_dentNo  0.102  -0.0146 
 5 PERSONNE_pb_cervicalYes 0.0863 -0.00842
 6 QES_tpsassoNever        0.0810 -0.00288
 7 SOINS_pf_frpha          0.0792 -0.00322
 8 PERSONNE_pb_lombalgiYes 0.0751 -0.00566
 9 QST_ct_depechSometimes  0.0744 -0.00724
10 PERSONNE_pb_allergiYes  0.0697 -0.00710
# ℹ 178 more rows

We can display the average of the absolute SHAP values on a plot. A vertical red line is added at the mean, over all variables, of these average absolute Shapley values. Red dots indicate variables with negative average Shapley values, green dots indicate positive average values, and gray dots indicate zero average values.

# Mean absolute SHAP value of each variable; the colour of a dot gives the
# sign of the (signed) mean SHAP value of the variable.
ggplot(
  data = order_variables_shap_xgb |>
    left_join(variable_names, by = c("name" = "variable")) |>
    left_join(variable_names_categ, by = c("name" = "variable")) |>
    # Dummy variables get their label from `variable_names_categ`
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    rename(mean_abs_shap_value = value) |>
    mutate(
      sign = sign(mean),
      sign = factor(sign, levels = c(-1, 0, 1), labels = c("-", "0", "+"))
    ),
  mapping = aes(
    y = fct_reorder(label, desc(mean_abs_shap_value)),
    x = mean_abs_shap_value
  )
) +
  geom_point(mapping = aes(colour = sign)) +
  # Mean, over the variables, of the mean absolute SHAP values
  geom_vline(
    xintercept = mean(order_variables_shap_xgb$value), colour = "red"
  ) +
  scale_colour_manual(
    "Feature impact",
    values = c("-" = "#EE324E", "+" = "#009D57", "0" = "gray"),
    # Every level gets an explicit label (the "0" level used to be shown
    # with its raw name)
    labels = c("-" = "Negative", "0" = "Zero", "+" = "Positive")
  ) +
  labs(x = "Mean(|Shap value|)", y = "Variable")
Figure 20.9: Mean absolute SHAP values for each variable

Let us create a table in which we only keep the individuals predicted as imaginary healthy patients:

# SHAP values of the observations whose predicted value exceeds 0.5;
# `id_row` keeps track of the position of each observation in the full table.
df_clustering_Not_D_and_inf_Q1 <-
  treeshap_xgb$shaps |>
  mutate(
    id_row = row_number(),
    predicted = pred_val_all_xgb
  ) |>
  filter(predicted > .5) |>
  as_tibble() |>
  select(!predicted)

Let us only keep the 14 variables with the highest mean absolute Shapley values:

# Names of the `nb_top` variables with the largest `value` (mean absolute
# SHAP value), in decreasing order.
nb_top <- 14
variables_to_keep_cluster <-
  order_variables_shap_xgb |>
  arrange(desc(value)) |>
  head(nb_top) |>
  pull(name)
variables_to_keep_cluster
 [1] "PERSONNE_age"                  "PERSONNE_sexeMale"            
 [3] "MENAGE_revucinsee"             "OPINION1_renonc_dentNo"       
 [5] "PERSONNE_pb_cervicalYes"       "QES_tpsassoNever"             
 [7] "SOINS_pf_frpha"                "PERSONNE_pb_lombalgiYes"      
 [9] "QST_ct_depechSometimes"        "PERSONNE_pb_allergiYes"       
[11] "SOINS_remomn"                  "OPINION1_renonc_consNo answer"
[13] "SOINS_seac_omn"                "PERSONNE_coupleYes"           

We can visualize which variables were kept:

# Mean absolute SHAP value of each variable; the colour of a dot tells
# whether the variable is kept for the clustering.
ggplot(
  data = order_variables_shap_xgb |>
    left_join(variable_names, by = c("name" = "variable")) |>
    left_join(variable_names_categ, by = c("name" = "variable")) |>
    # Dummy variables get their label from `variable_names_categ`
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    rename(mean_abs_shap_value = value) |>
    mutate(
      in_top_kept = name %in% variables_to_keep_cluster
    ),
  mapping = aes(
    y = fct_reorder(label, desc(mean_abs_shap_value)),
    x = mean_abs_shap_value
  )
) +
  geom_point(mapping = aes(colour = in_top_kept)) +
  # Mean, over the variables, of the mean absolute SHAP values
  geom_vline(
    xintercept = mean(order_variables_shap_xgb$value), colour = "red"
  ) +
  scale_colour_manual(
    "Variable kept for clustering",
    values = c("FALSE" = "gray", "TRUE" = "#009D57"),
    # The names must match the values of `in_top_kept`: with the former key
    # "+", the label "Kept" was never displayed.
    labels = c("FALSE" = "Discarded", "TRUE" = "Kept")
  ) +
  labs(x = "Mean(|Shap value|)", y = "Variable")
Figure 20.10: Variables kept for the clustering and their average absolute Shapley value.

We need to select a number of clusters K. Instead of arbitrarily picking a value, we vary the number of clusters and look at the average silhouette width for each value of K.

# Distances between individuals, based on the SHAP values of the variables
# kept for the clustering
dist_m <- dist(
  as.data.frame(
    df_clustering_Not_D_and_inf_Q1 |>
      select(all_of(variables_to_keep_cluster))
  )
)
# NOTE(review): per the hclust() documentation, "ward.D" applied to
# unsquared distances does not implement Ward's criterion ("ward.D2" does).
# The method is left unchanged here -- confirm that this is intended.
hierarchical_clust <- hclust(dist_m, method = "ward.D")

# Candidate numbers of clusters, defined once for the scores and the plot
k_candidates <- 2:15
# Average silhouette width for each candidate number of clusters
sil_val <- vapply(
  k_candidates,
  function(k) {
    sil <- cluster::silhouette(cutree(hierarchical_clust, k), dist_m)
    # Third column of a silhouette object: the silhouette widths
    mean(sil[, 3])
  },
  numeric(1)
)
p_silhouette <-
  ggplot(data = tibble(K = k_candidates, y = sil_val), aes(x = K, y = y)) +
  geom_line() +
  geom_point() +
  labs(x = "K", y = "Silhouette score")
p_silhouette
Figure 20.11: Silhouette scores

We hence pick 4 as the desired number of clusters, and then assign each observation to its cluster.

# Cut the dendrogram into K groups and store the group of each individual
K <- 4
clusters_km_Not_D_and_inf_Q <- hierarchical_clust
df_clustering_Not_D_and_inf_Q1[["cluster"]] <-
  cutree(clusters_km_Not_D_and_inf_Q, K)

We create a table to plot the Shap values of these individuals, and order the observations according to their distance from each other.

df_plot <- df_clustering_Not_D_and_inf_Q1[clusters_km_Not_D_and_inf_Q$order, ]

For variables that are not in the top, we create an “Other” category.

# Long format: one row per individual (`x`) and variable. Variables outside
# the top are pooled under "Other", whose value is the mean of their SHAP
# values.
df_plot <-
  df_plot |>
  select(-id_row) |>
  mutate(x = row_number()) |>
  pivot_longer(
    cols = -c(cluster, x),
    names_to = "variable",
    values_to = "value"
  ) |>
  mutate(
    variable = if_else(
      variable %in% variables_to_keep_cluster,
      true = variable,
      false = "Other"
    )
  ) |>
  group_by(cluster, x, variable) |>
  summarise(value = mean(value), .groups = "drop") |>
  left_join(
    # Labels of numerical variables, of dummy variables, and of "Other"
    bind_rows(
      variable_names,
      variable_names_categ |> select(variable, label = label_categ),
      tibble(variable = "Other", label = "Other", type = NA)
    ),
    by = c("variable" = "variable")
  ) |>
  arrange(x)

The order of the variables:

# Labels of the variables kept for the clustering, in the order of
# `variables_to_keep_cluster`, followed by "Other".
labels_order <- c(
  tibble(variable = variables_to_keep_cluster) |>
    left_join(variable_names, by = "variable") |>
    left_join(variable_names_categ, by = "variable") |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    pull(label),
  "Other"
)

We apply this order to the variables in the data plot.

# Turn the labels into a factor whose levels follow `labels_order`
df_plot <- mutate(
  df_plot,
  label = factor(label, levels = labels_order)
)

Then, we can plot the graph to visualize the contribution of each of the top variables, for each individual, within each cluster.

# Number of distinct variables in the plot data (including "Other")
nb_vars <- df_plot$variable |> unique() |> length()

# Largest `x` within each cluster, used to draw vertical separators
x_lines <-
  df_plot |>
  group_by(cluster) |>
  summarise(x_max = max(x)) |>
  pull(x_max) |>
  sort()
x_lines <- c(1, x_lines)

# colours <- Polychrome::createPalette(nb_top,  c("#ff0000", "#00ff00", "#0000ff"))
# names(colours) <- NULL

# The "Paired" palette provides between 3 and 12 colours: the requested
# number is kept within these bounds, which yields the same colours as
# before without the "n too large" warning. Three extra colours follow.
fill_colours <- c(
  RColorBrewer::brewer.pal(min(max(nb_vars - 1, 3), 12), "Paired"),
  "green", "black", "gray"
)

plot_clusters <-
  ggplot(data = df_plot) +
  geom_vline(xintercept = x_lines) +
  geom_col(aes(x = x, y = value, fill = label), width = 1, alpha = 0.9) +
  geom_hline(yintercept = 0, col = "black") +
  scale_fill_manual("", values = fill_colours) +
  # scale_fill_manual("", values = c(colours, "gray")) +
  labs(x = "Individuals", y = "Output value") +
  # theme_paper() +
  theme(
    legend.position = "bottom",
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank()
  )
plot_clusters + guides(fill = guide_legend(ncol = 3))
Figure 20.12: Decomposition of the Effect of the Most Important Variables on the Probability of Being Predicted as an Imaginary Healthy Patient for Individuals Predicted as Such. All individuals.

With aggregated mean SHAP values at the cluster level:

# Mean SHAP value of each variable within each cluster. The levels of
# `cluster` follow the order in which the clusters appear in `df_plot`.
df_plot_cluster <-
  df_plot |>
  group_by(cluster, variable, label) |>
  # Explicit `.groups`: no regrouping message, and the result is ungrouped
  summarise(value = mean(value), .groups = "drop") |>
  mutate(cluster = factor(cluster, levels = unique(df_plot$cluster)))
`summarise()` has grouped output by 'cluster', 'variable'. You can override
using the `.groups` argument.
# Stacked bars: mean SHAP value of each variable, for each cluster.
plot_clusters_2 <-
  ggplot(
    data = df_plot_cluster |>
      mutate(
        # The level names are replaced with 1, 2, ... in the order of the
        # levels. NOTE(review): if the levels are not already in increasing
        # order, these labels differ from the original cluster ids --
        # confirm that this is intended.
        cluster = factor(
          cluster,
          levels = levels(df_plot_cluster$cluster),
          labels = seq_along(levels(df_plot_cluster$cluster))
        )
      ),
    mapping = aes(x = cluster, fill = label, y = value)
  ) +
  geom_bar(position = "stack", stat = "identity") +
  labs(x = "Cluster", y = "Output value") +
  # scale_fill_manual("", values = c(colours, "gray")) +
  scale_fill_manual(
    "",
    # The "Paired" palette provides between 3 and 12 colours: the requested
    # number is kept within these bounds, which yields the same colours as
    # before without the "n too large" warning.
    values = c(
      RColorBrewer::brewer.pal(min(max(nb_vars - 1, 3), 12), "Paired"),
      "green", "black", "gray"
    )
  ) +
  theme(
    legend.position = "bottom"
  )
Warning in RColorBrewer::brewer.pal(nb_vars - 1, "Paired"): n too large, allowed maximum for palette Paired is 12
Returning the palette you asked for with that many colors
plot_clusters_2 + guides(fill = guide_legend(ncol = 3)) +
  geom_hline(yintercept = 0, linetype = "dashed")
Figure 20.13: Decomposition of the Effect of the Most Important Variables on the Probability of Being Predicted as an Imaginary Healthy Patient for Individuals Predicted as Such. Average per Cluster.

20.5.2 Descriptive Statistics of the Clusters

Let us now get some descriptive statistics at the cluster level, for each of these variables. We need to get the values of the variables for the individuals predicted as imaginary healthy patients.

For categorical variables, we would like to get the most frequent value, in each cluster. For numerical variables, we would like to get the average and standard deviation. Let us isolate these two types of variables.

The categorical variables:

# Raw names (`variable_raw`) of the entries of `variable_names_categ` whose
# `variable` is among the top variables
top_name_categ <-
  variable_names_categ$variable_raw[
    variable_names_categ$variable %in% top_n_variables_xgb
  ]

The numerical variables:

# Top variables that are not listed in `variable_names_categ`
top_name_num <- discard(
  top_n_variables_xgb,
  \(v) v %in% variable_names_categ$variable
)

Let us create a dataset with the characteristics and the cluster in which the individuals belong:

# All observations, with the cluster of those that were clustered
# (`cluster` is NA for the others)
df_stat_des <-
  as_tibble(df_all_xgb) |>
  mutate(id_row = row_number()) |>
  left_join(
    select(df_clustering_Not_D_and_inf_Q1, id_row, cluster),
    by = "id_row"
  )

We define a function, get_stats() which computes the desired statistics:

#' Descriptive statistics of the top variables
#'
#' Uses the global vectors `top_name_num` and `top_name_categ` to select the
#' columns to summarise.
#'
#' @param df data frame containing the columns named in `top_name_num` and
#'   `top_name_categ`
#' @return a tibble with one row per variable and two columns: `name` (name
#'   of the variable) and `value` (a string: "mean (sd)" for numerical
#'   variables, "most frequent value (share%)" for categorical variables)
get_stats <- function(df) {
  # Desc. Stat on numerical variables: mean and standard deviation
  df_num <-
    df |>
    select(all_of(top_name_num)) |>
    pivot_longer(cols = everything()) |>
    group_by(name) |>
    summarise(mean = mean(value), sd = sd(value)) |>
    mutate(value = str_c(round(mean, 2), " (", round(sd, 2), ")")) |>
    select(name, value)

  # Desc. Stat on categorical variables: most frequent value and its share
  df_categ <-
    df |>
    select(all_of(top_name_categ)) |>
    pivot_longer(cols = everything()) |>
    group_by(name) |>
    count(value) |>
    mutate(pct = round(100 * n / sum(n), 1)) |>
    # Exactly one row per variable, even when several values are tied
    # (several rows for a variable would break the reshaping to wide
    # format); ranking on the counts avoids ties created by rounding.
    slice_max(order_by = n, n = 1, with_ties = FALSE) |>
    ungroup() |>
    mutate(value = str_c(value, " (", pct, "%)")) |>
    select(name, value)

  df_num |> bind_rows(df_categ)
}

We will compute the statistics on the following datasets:

  • whole dataset
  • sample of individuals predicted as imaginary healthy patients
  • each of the clusters (i.e., sub sample of the sample of individuals predicted as imaginary healthy patients).

On the whole dataset:

stat_des_whole <- get_stats(df_stat_des) |> mutate(type = "Whole sample")

On the dataset of individuals predicted as imaginary healthy patients by the classifier:

# Statistics restricted to the observations present in the clustering table
stat_des_imaginary <-
  df_stat_des |>
  semi_join(df_clustering_Not_D_and_inf_Q1, by = "id_row") |>
  get_stats() |>
  mutate(type = "Imaginary healthy")

And on each of the clusters:

# Descriptive statistics computed separately on each cluster; `type` holds
# the cluster number (as a string). seq_len() avoids iterating over c(1, 0)
# if K were 0, and the tables are bound once instead of being grown in a
# loop.
df_stat_des_clusters <-
  seq_len(K) |>
  map(
    \(i_clust) {
      df_stat_des |>
        filter(cluster == !!i_clust) |>
        get_stats() |>
        mutate(type = as.character(i_clust))
    }
  ) |>
  bind_rows()

We also compare the clusters by computing the p-values of an ANOVA for quantitative variables, and of a chi-squared test for qualitative variables.

For qualitative variables:

# p-value of a chi-squared test between the cluster and each categorical
# variable. Observations without a cluster (NA) are ignored by table().
p_values_categ <- tibble(
  name = top_name_categ,
  pvalue = vapply(
    top_name_categ,
    function(var_name) {
      # contingency table: clusters in rows, categories in columns
      tab <- table(df_stat_des$cluster, df_stat_des[[var_name]])
      chisq.test(tab, correct = TRUE)$p.value
    },
    numeric(1),
    USE.NAMES = FALSE
  )
)

For numerical variables:

# p-value of a one-way ANOVA comparing the mean of each numerical variable
# across clusters.
p_values_num <- tibble(
  name = top_name_num,
  pvalue = vapply(
    top_name_num,
    function(var_name) {
      # `cluster` is stored as an integer: it must be turned into a factor,
      # otherwise aov() fits a linear trend in the cluster number (1 degree
      # of freedom) instead of comparing the cluster means.
      fit <- aov(
        as.numeric(df_stat_des[[var_name]]) ~ factor(df_stat_des$cluster)
      )
      # First row of the ANOVA table: the cluster term
      anova(fit)$`Pr(>F)`[1]
    },
    numeric(1),
    USE.NAMES = FALSE
  )
)

We merge the p-values in a single object.

tb_p_values <- p_values_categ |> bind_rows(p_values_num)

To build a prettier table, let us build a tibble with the labels of the variables.

# Reference table of the top variables: entries found in
# `variable_names_categ` are replaced by their raw variable name before
# looking up the label in `variable_names`.
table_ref <-
  left_join(
    tibble(variable = top_n_variables_xgb),
    variable_names_categ,
    by = "variable"
  ) |>
  mutate(variable = ifelse(is.na(variable_raw), variable, variable_raw)) |>
  left_join(variable_names, by = "variable")

Let us merge the three tables with the descriptive statistics and add the p-values of the clusters comparisons.

# One row per variable; one column per cluster, plus the sample of
# individuals predicted as imaginary healthy, the whole sample and the
# p-value of the comparison between clusters.
tb_desc_stats_clusters <-
  df_stat_des_clusters |>
  bind_rows(stat_des_whole) |>
  bind_rows(stat_des_imaginary) |>
  mutate(
    type = factor(
      type,
      levels = c(seq_len(K), "Imaginary healthy", "Whole sample")
    )
  ) |>
  # names_sort = TRUE: the columns follow the order of the levels of `type`
  # (by default they follow the order of appearance, which made the factor
  # above ineffective)
  pivot_wider(names_from = type, values_from = value, names_sort = TRUE) |>
  left_join(
    table_ref |> select(variable, label),
    by = c("name" = "variable")
  ) |>
  left_join(distinct(tb_p_values), by = "name") |>
  relocate(label, .before = name) |>
  select(-name) |>
  distinct() |>
  # Rows in the order of the top variables
  mutate(label = factor(label, levels = unique(table_ref$label))) |>
  arrange(label)
tb_desc_stats_clusters |>
  select(-pvalue) |>
  add_column(
    p_value = format.pval(tb_desc_stats_clusters$pvalue, digits = 2)
  ) |>
  kableExtra::kable() |>
  kableExtra::kable_classic(full_width = FALSE, html_font = "Cambria") |>
  kableExtra::kable_styling(
    bootstrap_options = c("striped", "hover", "condensed", "responsive")
  )
Table 20.2: Average individual in each cluster and in the sample
label 1 2 3 4 Whole sample Imaginary healthy p_value
Age 51.98 (16.51) 40.84 (10.32) 37.05 (12.29) 71.38 (9.88) 47.57 (18.17) 47.48 (17) < 2e-16
Gender Female (98.9%) Female (66.2%) Female (63.6%) Female (61.3%) Female (52.9%) Female (71.6%) < 2e-16
Net Income per Cons. Unit 1501.28 (707.97) 1577.88 (745.3) 590.34 (160.11) 1357.57 (735.06) 1620.05 (1004.37) 1322.78 (759.48) < 2e-16
Waiver Dental Care No (71.2%) No (62.3%) No (47.4%) No (59.9%) No (66.3%) No (60.7%) < 2e-16
Neck Pain Yes (100%) No (88.1%) No (94.8%) No (85.7%) No (82.2%) No (70.9%) < 2e-16
Frequency Meeting with People in Organizations Never (48.5%) Never (63.2%) Never (66.7%) Never (54.9%) Never (50.4%) Never (59.5%) 5.1e-13
Deduct. Pharmacy 19.73 (16.58) 15.44 (13) 4.55 (9.49) 34.19 (14.28) 12.55 (14.7) 17.23 (16.25) 2.3e-14
Low Back Pain Yes (50.9%) No (71.4%) No (82.5%) No (56.5%) No (77.5%) No (66.6%) < 2e-16
Have to Hurry to Do Job No answer (50.9%) Often (37%) No answer (79.9%) No answer (93.4%) No answer (49.7%) No answer (51.8%) < 2e-16
Allergy No (67.8%) No (70.9%) No (81.4%) No (80.6%) No (81.9%) No (74%) 8.6e-12

Number of observations per cluster:

table(df_clustering_Not_D_and_inf_Q1$cluster)

  1   2   3   4 
466 949 462 377 

Number of observations in the whole sample:

nrow(df_stat_des)
[1] 5380

Number of individuals predicted imaginary healthy:

nrow(df_clustering_Not_D_and_inf_Q1)
[1] 2254

Remaining obs (sanity check):

df_stat_des |>
  filter(!id_row %in% df_clustering_Not_D_and_inf_Q1$id_row) |> 
  nrow()
[1] 3126

Predicted class by XGB:

# Predictions returned by predict() for the model object `grid_search_xgb`,
# on all the observations of `df_all_xgb` (compared with `status` below)
pred_class_xgb <- predict(
  grid_search_xgb, newdata = df_all_xgb
)

Overview of predictions in each cluster:

# Per cluster: size, share of correct predictions, and number / share of
# observations whose status is "Not_D_and_inf_Q1"
df_stat_des |>
  mutate(
    predicted_val = pred_val_all_xgb,
    pred_class = pred_class_xgb,
    correctly_pred = pred_class == status
  ) |>
  # Keep only the observations that were assigned to a cluster
  drop_na(cluster) |>
  group_by(cluster) |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 4 × 5
  cluster     n accuracy no_im_healthy pct_im_healthy
    <int> <int>    <dbl>         <int>          <dbl>
1       1   466     51.9           242           51.9
2       2   949     46.6           442           46.6
3       3   462     47             217           47  
4       4   377     40.8           154           40.8

Overall accuracy:

# Whole sample: size, share of correct predictions (0.5 threshold on the
# predicted value), and number / share of "Not_D_and_inf_Q1"
df_stat_des |>
  mutate(
    predicted_val = pred_val_all_xgb,
    pred_class = if_else(
      predicted_val > .5,
      true = "Not_D_and_inf_Q1",
      false = "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  5380     67.7          1592           29.6

Percentage of individuals correctly predicted Not_D_and_inf_Q1, i.e., imaginary healthy:

# Same summary, restricted to the observations predicted "Not_D_and_inf_Q1"
df_stat_des |>
  mutate(
    predicted_val = pred_val_all_xgb,
    pred_class = if_else(
      predicted_val > .5,
      true = "Not_D_and_inf_Q1",
      false = "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |>
  filter(pred_class == "Not_D_and_inf_Q1") |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  2254     46.8          1055           46.8

Percentage of individuals correctly predicted Not_D_and_sup_Q1, i.e., healthy:

# Same summary, restricted to the observations predicted "Not_D_and_sup_Q1"
df_stat_des |>
  mutate(
    predicted_val = pred_val_all_xgb,
    pred_class = if_else(
      predicted_val > .5,
      true = "Not_D_and_inf_Q1",
      false = "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |>
  filter(pred_class == "Not_D_and_sup_Q1") |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  3126     82.8           537           17.2