8  Results

Objectives

This chapter presents the results obtained with the SHAP values estimated in Chapter 7.

library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ ggplot2   3.5.1     ✔ tibble    3.2.1
✔ lubridate 1.9.3     ✔ tidyr     1.3.1
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(caret)
Loading required package: lattice

Attaching package: 'caret'

The following object is masked from 'package:purrr':

    lift
library(ranger)
library(treeshap)

Let us define a theme for graphs.

Definition of theme_paper()
#' Theme for ggplot2
#'
#' Builds on `ggthemes::theme_base()` and places a framed, transparent legend
#' horizontally at the bottom of the plot.
#'
#' @param ... additional theme settings passed to [ggplot2::theme()]. They are
#'   added after the defaults, so they take precedence over them (e.g.
#'   `theme_paper(legend.position = "top")`).
#' @return A ggplot2 theme object.
#' @export
#' @importFrom ggplot2 theme element_rect element_text element_blank
#'   element_line unit rel
theme_paper <- function(...) {
  ggthemes::theme_base() +
    theme(
      legend.background = element_rect(
        fill = "transparent", linetype = "solid", colour = "black"
      ),
      legend.position = "bottom",
      legend.direction = "horizontal",
      legend.box = "horizontal",
      legend.key = element_blank()
    ) +
    # User-supplied settings are added last so that they override the
    # defaults above (with no arguments this adds an empty theme).
    theme(...)
}

8.1 Load data and estimated models

8.1.1 Data

Let us load the train and test data:

# Train and test sets saved by an earlier chapter. `load()` restores the
# objects stored in each .rda file directly into the current environment;
# the code below relies on them providing `df_train` and `df_test`.
load("../data/out/df_train.rda")
load("../data/out/df_test.rda")

For random forests, we need to convert categorical variables to numerical variables:

# Numeric versions of the train and test sets: each factor column is replaced
# by its integer level codes (as.numeric() on a factor returns the codes).
df_train_num <- mutate(df_train, across(where(is.factor), as.numeric))

df_test_num <- mutate(df_test, across(where(is.factor), as.numeric))

We keep track of the levels of factor data:

# Levels of every factor column of the train set, as a named list
# (column name -> character vector of levels). It is used to map numeric
# codes and dummy-variable names back to readable categories.
# select(where()) replaces the superseded select_if().
corresp_factors <- df_train |>
  select(where(is.factor)) |>
  map(levels)

8.1.2 Classifiers

The estimated models (see Chapter 5):

# Grid-search results from Chapter 5. `load()` restores the saved objects
# into the current environment; the code below relies on them providing
# `grid_search_rf` and `grid_search_xgb`.
# The estimated random forests (and the grid)
load("../data/out/estim/v3/grid_search_rf.rda")
# The estimated xgb (and the grid)
load("../data/out/estim/v3/grid_search_xgb.rda")

Let us get the best classifiers obtained both for the random forest and XGBoost. Let us begin with random forests:

# Columns that are not used as predictors: the response, the identifier,
# and PERSONNE_statut.
ind_remove <- which(
  colnames(df_train) %in% c("status", "id", "PERSONNE_statut")
)
# Predictors are obtained with setdiff() rather than `df_train[-ind_remove]`:
# if `ind_remove` were integer(0), `df_train[-integer(0)]` would silently
# select NO column. reformulate() also backquotes non-syntactic names.
formula <- reformulate(
  setdiff(colnames(df_train), colnames(df_train)[ind_remove]),
  response = "status"
)

# Random forest refitted with the hyperparameters selected by the grid search.
# NOTE(review): no seed is passed to ranger(), so the fitted forest depends on
# the current RNG state; consider `seed =` for reproducibility.
final_model_rf <- ranger(
  formula,
  data = df_train_num,
  mtry = grid_search_rf$bestTune$mtry,
  splitrule = "gini",
  min.node.size = grid_search_rf$bestTune$min.node.size,
  classification = TRUE
)

And for the extreme gradient boosting algorithm:

# Final model stored by caret in the train object (fitted with the selected
# tuning parameters).
final_model_xgb <- grid_search_xgb[["finalModel"]]

8.1.3 SHAP Values

We load the SHAP values estimated in Chapter 7:

# SHAP values estimated in Chapter 7. `load()` restores the saved objects into
# the current environment; the code below relies on them providing
# `treeshap_rf` and `treeshap_xgb` (each with a `$shaps` table).
load("../data/out/treeSHAP/v3/treeshap_rf.rda")
load("../data/out/treeSHAP/v3/treeshap_xgb.rda")

The explainers used only the train set as reference data to estimate the SHAP values:

# Reference (background) data of the explainers: the train set, in its
# original form for XGBoost and numeric-coded for the random forest.
reference_data_xgb <- df_train
reference_data_rf <- df_train_num

And the SHAP values were estimated on both the train set and the test set:

# Train rows followed by test rows, for each data encoding.
df_all_rf <- df_train_num |> bind_rows(df_test_num)
df_all_xgb <- df_train |> bind_rows(df_test)

8.2 Predicted Values (classifier)

Let us get the predicted values for both classifiers.

8.2.1 Random Forest

# Class predicted by majority vote over the trees, on the reference data.
rf_prediction <- predict(final_model_rf, data = reference_data_rf)
pred_val_ref_rf <- rf_prediction[["predictions"]]
rm(rf_prediction)
head(pred_val_ref_rf)
[1] 1 1 1 0 1 1

The issue is that this gives the class predicted by majority vote, not the associated score. Let us compute the scores instead, by aggregating the votes of all the trees (share of trees voting for the class).

# Get the scores instead:
#' Score of a ranger classification forest as the share of tree votes
#'
#' `predict()` on a ranger classification forest returns the majority-vote
#' class only. Here the prediction of every tree is requested
#' (`predict.all = TRUE`), and the score of an observation is the proportion
#' of trees whose prediction equals `positive_class`.
#'
#' @param object a fitted ranger classification forest
#' @param data data on which to compute the scores
#' @param positive_class per-tree prediction value counted as a vote for the
#'   scored class (default `1`)
#' @return numeric vector with one score in [0, 1] per row of `data`
predict_score_ranger <- function(object, data, positive_class = 1) {
  pred_all <- predict(object, data, predict.all = TRUE)
  # Rows = observations, columns = trees; as.matrix() keeps that layout and
  # turns a plain vector into a single-column matrix.
  tree_preds <- as.matrix(pred_all$predictions)
  # Share of trees voting for the class, i.e. number of votes / number of trees
  rowMeans(tree_preds == positive_class)
}
# Random forest scores (vote shares) on the reference data.
pred_val_rf <- final_model_rf |> predict_score_ranger(reference_data_rf)
head(pred_val_rf)
[1] 0.794 0.816 0.786 0.316 0.778 0.876

8.2.2 XGBoost

And now, for XGBoost:

# caret returns one probability column per class; keep the one of
# "Not_D_and_inf_Q1".
pred_val_xgb <- predict(
  grid_search_xgb,
  newdata = reference_data_xgb,
  type = "prob"
)[, "Not_D_and_inf_Q1"]
head(pred_val_xgb)
[1] 0.4802239 0.4298899 0.4921511 0.7234144 0.5549585 0.4536980

Let us get the predicted values on both the train set and the test set as well:

# Same probabilities, on train + test.
pred_val_all_xgb <- predict(
  grid_search_xgb,
  newdata = df_all_xgb,
  type = "prob"
)[, "Not_D_and_inf_Q1"]

8.3 Understanding the Output

To explain how the results should be read, let us have a look at the first individual.

id_indiv <- 1
# Characteristics of this individual (numeric-coded data).
df_all_rf[id_indiv, , drop = FALSE]
  PERSONNE_pb_asthm PERSONNE_pb_bronchit PERSONNE_pb_infarctus
1                 1                    1                     1
  PERSONNE_pb_coronair PERSONNE_pb_hypertens PERSONNE_pb_avc
1                    1                     1               1
  PERSONNE_pb_arthros PERSONNE_pb_lombalgi PERSONNE_pb_cervical
1                   1                    1                    1
  PERSONNE_pb_diabet PERSONNE_pb_allergi PERSONNE_pb_cirrhos
1                  1                   1                   1
  PERSONNE_pb_urinair PERSONNE_age PERSONNE_sexe PERSONNE_couple
1                   1           32             2               3
  PERSONNE_statut PERSONNE_ss PERSONNE_regime PERSONNE_rap_pcs8 PERSONNE_ald
1               2           1               1                 4            2
  SOINS_ald_am MENAGE_revucinsee MENAGE_tu MENAGE_nbpers ensol_2011
1            2            1561.9         2             4   1812.192
  MUTUELLE_assu MUTUELLE_typcc SOINS_remomn SOINS_remspe SOINS_rempha
1             2              5       178.28       114.08         9.38
  SOINS_remkin SOINS_reminf SOINS_remden SOINS_remmat SOINS_remtra SOINS_remopt
1       243.59            0            0            0            0            0
  SOINS_rempro SOINS_remurg SOINS_tmomn SOINS_tmspe SOINS_tmpha SOINS_tmkin
1         44.2            0       81.12        50.6        17.6      171.34
  SOINS_tminf SOINS_tmden SOINS_tmmat SOINS_tmtra SOINS_tmopt SOINS_tmpro
1           0           0           0           0           0       29.46
  SOINS_tmurg SOINS_dpaomn SOINS_dpaspe SOINS_dpapha SOINS_dpakin SOINS_dpainf
1           0         11.2            0            0            0            0
  SOINS_dpaden SOINS_dpamat SOINS_dpatra SOINS_dpaopt SOINS_dpapro SOINS_dpaurg
1            0            0            0            0        87.14            0
  SOINS_pf_fromn SOINS_pf_frspe SOINS_pf_frpha SOINS_pf_frkin SOINS_pf_frinf
1             11              4              4           13.5              0
  SOINS_pf_frden SOINS_pf_frtra SOINS_pf_frurg SOINS_seac_omn SOINS_seac_spe
1              0              0              0              9              4
  OPINION1_renonc_cons OPINION1_renonc_dent OPINION1_renonc_fin
1                    2                    1                   2
  OPINION1_renonc_loin OPINION1_renonc_long QST_ct_depech QST_ct_liberte
1                    2                    2             2              4
  QST_ct_apprend QST_ct_aidecol QST_ct_travnuit QST_ct_repet QST_ct_lourd
1              2              2               4            4            4
  QST_ct_posture QST_ct_produit QES_association QES_tpsami QES_tpsasso
1              3              3               1          3           2
  QES_tpscolleg QES_tpsfamil QES_mere_etude QES_pere_etude id status
1             4            3              3              3  3      2

The estimated SHAP values for this individual (with the random forest):

treeshap_rf[["shaps"]][id_indiv, , drop = FALSE]
  PERSONNE_pb_asthm PERSONNE_pb_bronchit PERSONNE_pb_infarctus
1       0.001902093          0.003492646          0.0002529194
  PERSONNE_pb_coronair PERSONNE_pb_hypertens PERSONNE_pb_avc
1          0.000118986           0.001184851    3.089125e-05
  PERSONNE_pb_arthros PERSONNE_pb_lombalgi PERSONNE_pb_cervical
1         0.004319534           0.00585481           0.00455173
  PERSONNE_pb_diabet PERSONNE_pb_allergi PERSONNE_pb_cirrhos
1        0.002328196          0.00256106        2.441598e-05
  PERSONNE_pb_urinair PERSONNE_age PERSONNE_sexe PERSONNE_couple  PERSONNE_ss
1         0.002037423 -0.006131559   0.007852905      0.00570197 0.0001545872
  PERSONNE_regime PERSONNE_rap_pcs8 PERSONNE_ald SOINS_ald_am MENAGE_revucinsee
1   -0.0008582498       0.003577031   0.00282703  0.000800785        0.02174846
    MENAGE_tu MENAGE_nbpers  ensol_2011 MUTUELLE_assu MUTUELLE_typcc
1 0.004915818 -0.0007619383 0.004986332    0.00264825   0.0005319078
  SOINS_remomn SOINS_remspe SOINS_rempha SOINS_remkin SOINS_reminf SOINS_remden
1  -0.02207328  0.004413408   0.02070351  -0.01199823  0.002504439 0.0005772186
  SOINS_remmat SOINS_remtra SOINS_remopt SOINS_rempro SOINS_remurg SOINS_tmomn
1  0.003017843  0.003690953  0.001046285 -0.003589319  0.001363076 -0.01104072
   SOINS_tmspe SOINS_tmpha SOINS_tmkin SOINS_tminf SOINS_tmden SOINS_tmmat
1 0.0004226593  0.01115561 -0.01379172 0.001213832 0.004498388 0.002988124
  SOINS_tmtra  SOINS_tmopt  SOINS_tmpro  SOINS_tmurg  SOINS_dpaomn SOINS_dpaspe
1 0.001284419 0.0005683362 -0.006030765 0.0005412069 -0.0002424793 0.0008339485
  SOINS_dpapha SOINS_dpakin SOINS_dpainf SOINS_dpaden SOINS_dpamat
1  7.28222e-05 0.0002880034 4.779914e-05  0.001299304  0.000450043
   SOINS_dpatra SOINS_dpaopt SOINS_dpapro SOINS_dpaurg SOINS_pf_fromn
1 -1.839553e-06  0.001222361  -0.01541312 1.367787e-05   -0.009509514
  SOINS_pf_frspe SOINS_pf_frpha SOINS_pf_frkin SOINS_pf_frinf SOINS_pf_frden
1   0.0006675115     0.01191832    -0.01954245   0.0009281199   3.111839e-05
  SOINS_pf_frtra SOINS_pf_frurg SOINS_seac_omn SOINS_seac_spe
1   0.0008339523   5.383237e-05  -0.0005813897    0.002417151
  OPINION1_renonc_cons OPINION1_renonc_dent OPINION1_renonc_fin
1          0.006141836          -0.02335377         0.001513655
  OPINION1_renonc_loin OPINION1_renonc_long QST_ct_depech QST_ct_liberte
1           0.00110833          0.006720838   -0.01829684    0.004776366
  QST_ct_apprend QST_ct_aidecol QST_ct_travnuit QST_ct_repet QST_ct_lourd
1    0.004302032    0.002739294    0.0006418666  0.002642911  0.001627902
  QST_ct_posture QST_ct_produit QES_association   QES_tpsami QES_tpsasso
1     0.00184207  -0.0009190289      0.00732533 -0.004240796  0.01258614
  QES_tpscolleg QES_tpsfamil QES_mere_etude QES_pere_etude
1   0.000587658  0.001452621    0.001350946    0.002939407

Each column gives the SHAP value for a variable, for this individual. Each row of treeshap_rf gives the SHAP values for a single individual.

In the reference data (the train set), the average predicted score is:

mean_pred_rf_ref <- mean(pred_val_rf) # random forest
mean_pred_rf_ref
[1] 0.7366059
mean_pred_xgb_ref <- mean(pred_val_xgb) # XGBoost
mean_pred_xgb_ref
[1] 0.4760346

For the first individual, the predicted score with the random forest is:

# Random forest score (vote share) of the selected individual.
pred_indiv_rf <- df_all_rf[id_indiv, ] |>
  predict_score_ranger(object = final_model_rf)
pred_indiv_rf
[1] 0.794

Note that for the random forest, this corresponds to the score associated with the class “Not_D_and_sup_Q1”. The probability of being classified as an imaginary healthy patient is therefore:

# Complementary score (imaginary healthy class).
1 - pred_indiv_rf
[1] 0.206

And with XGB:

# XGBoost probability of "Not_D_and_inf_Q1" for the selected individual.
pred_indiv_xgb <- predict(
  grid_search_xgb,
  df_all_xgb[id_indiv, ],
  type = "prob"
)[, "Not_D_and_inf_Q1"]
pred_indiv_xgb
[1] 0.4802239

For the record, the actual class is:

df_all_rf$status[id_indiv]
[1] 2
df_all_xgb$status[id_indiv]
[1] Not_D_and_sup_Q1
Levels: Not_D_and_inf_Q1 Not_D_and_sup_Q1

The sum of the SHAP values of an individual corresponds to the deviation of their predicted score from the average score estimated on the reference data:

# SHAP values of the selected individual, for each model.
shap_indiv_xgb <- treeshap_xgb[["shaps"]][id_indiv, ] # XGBoost
shap_indiv_rf <- treeshap_rf[["shaps"]][id_indiv, ] # random forest

For the random forest:

# Sum of SHAP values vs deviation of the prediction from the reference mean.
c(
  shap = sum(shap_indiv_rf),
  pred = pred_indiv_rf,
  mean_pred = mean_pred_rf_ref,
  diff = pred_indiv_rf - mean_pred_rf_ref
)
     shap      pred mean_pred      diff 
0.0573941 0.7940000 0.7366059 0.0573941 

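and for XGBoost. Note that here the sum of the SHAP values does not match the difference between predicted probabilities: the SHAP values obtained with treeshap for XGBoost are most likely expressed on the scale of the raw model output (log-odds) rather than on the probability scale: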

# Same comparison for XGBoost.
c(
  shap = sum(shap_indiv_xgb),
  pred = pred_indiv_xgb,
  mean_pred = mean_pred_xgb_ref,
  diff = pred_indiv_xgb - mean_pred_xgb_ref
)
        shap         pred    mean_pred         diff 
-0.083964645  0.480223864  0.476034586  0.004189278 

Let us now visualize the SHAP values for a specific individual.

# Table of readable variable labels. `load()` restores it into the current
# environment; the code below relies on it providing `variable_names`, with
# (at least) the columns `variable` and `label`.
load("../data/out/variable_names.rda")

We build a table with variable names.

# One row per (factor variable, level). `variable` pastes the column name and
# the level (e.g. "PERSONNE_pb_asthmYes", matching the dummy-column names of
# model.matrix()); `label_categ` reads "Label = level".
variable_names_categ <-
  corresp_factors |>
  enframe(name = "name", value = "value") |>
  left_join(variable_names, by = join_by(name == variable)) |>
  unnest(cols = value) |>
  mutate(
    variable = str_c(name, value),
    label_categ = str_c(label, " = ", value)
  ) |>
  select(
    variable, label_categ,
    variable_raw = name, variable_label_val = value
  )
variable_names_categ
# A tibble: 176 × 4
   variable                 label_categ          variable_raw variable_label_val
   <chr>                    <chr>                <chr>        <chr>             
 1 PERSONNE_pb_asthmNo      Asthma = No          PERSONNE_pb… No                
 2 PERSONNE_pb_asthmYes     Asthma = Yes         PERSONNE_pb… Yes               
 3 PERSONNE_pb_bronchitNo   Bronchitis = No      PERSONNE_pb… No                
 4 PERSONNE_pb_bronchitYes  Bronchitis = Yes     PERSONNE_pb… Yes               
 5 PERSONNE_pb_infarctusNo  Heart Attack = No    PERSONNE_pb… No                
 6 PERSONNE_pb_infarctusYes Heart Attack = Yes   PERSONNE_pb… Yes               
 7 PERSONNE_pb_coronairNo   Artery Disease = No  PERSONNE_pb… No                
 8 PERSONNE_pb_coronairYes  Artery Disease = Yes PERSONNE_pb… Yes               
 9 PERSONNE_pb_hypertensNo  Hypertension = No    PERSONNE_pb… No                
10 PERSONNE_pb_hypertensYes Hypertension = Yes   PERSONNE_pb… Yes               
# ℹ 166 more rows

We define a function, plot_individual_variable_effect(), to understand why, according to the estimated SHAP values, the predicted value for a single individual deviates from the average prediction in the reference sample.

#' Plot the decomposition of the deviation from the average prediction
#' for an individual, using SHAP values
#'
#' The `n` variables with the largest absolute SHAP values are shown
#' individually. The SHAP values of all the remaining variables are summed
#' into a single "Other Variables" bar. Each bar starts at the average
#' prediction in the reference data and has the length of the SHAP value.
#'
#' Labels come from the global tables `variable_names` (raw variables) and
#' `variable_names_categ` (dummy variables). When a variable is found in
#' neither, its raw name is used as label.
#'
#' @param i row number to designate individual
#' @param treeshap_res result obtained with treeshap
#' @param predicted_val_i predicted value for the individual
#' @param mean_pred_ref average predicted value in the reference data
#' @param n top n variables (default to 10)
#' @param min_max if not `NULL`, limits for the x axis (Shap values), in a
#'  vector of length 2 containing c(min, max)
#' @return a ggplot object
plot_individual_variable_effect <- function(i,
                                            treeshap_res,
                                            predicted_val_i,
                                            mean_pred_ref,
                                            n = 10,
                                            min_max = NULL) {

  df_plot <-
    treeshap_res$shaps |>
    dplyr::slice(i) |>
    unlist() |>
    enframe(name = "variable", value = "shap") |>
    left_join(variable_names, by = "variable") |>
    left_join(variable_names_categ, by = "variable") |>
    # Raw variable name as last resort, so that no bar gets an NA label
    mutate(label = coalesce(label, label_categ, variable)) |>
    arrange(desc(abs(shap))) |>
    mutate(in_top_n = row_number() <= !!n) |>
    mutate(
      label_2 = ifelse(in_top_n, label, "Other Variables")
    ) |>
    group_by(label_2) |>
    summarise(
      shap = sum(shap),
      .groups = "drop"
    ) |>
    mutate(label_2 = fct_reorder(label_2, abs(shap)))

  # "Other Variables" as first level (bottom of the plot once coordinates are
  # flipped); only when that group exists, i.e. when `n` does not already
  # cover every variable (fct_relevel() warns on unknown levels).
  if ("Other Variables" %in% levels(df_plot$label_2)) {
    df_plot$label_2 <- fct_relevel(df_plot$label_2, "Other Variables")
  }

  df_plot <-
    df_plot |>
    mutate(
      mean_pred_ref = !!mean_pred_ref,
      shap_rounded = str_c("  ", round(shap, 3)),
      sign = sign(shap),
      sign = factor(sign, levels = c(-1, 1), labels = c("-", "+"))
    )

  p <- ggplot(
    data = df_plot,
    mapping = aes(
      y =  mean_pred_ref + pmax(shap, 0),
      x = label_2,
      ymin = mean_pred_ref,
      ymax = mean_pred_ref + shap,
      colour = sign)
  ) +
    geom_linerange(linewidth = 8) +
    geom_hline(yintercept = mean_pred_ref) +
    geom_text(aes(label = shap_rounded, hjust = 0)) +
    coord_flip() +
    scale_colour_manual(
      "Variable impact",
      values = c("-" = "#EE324E", "+" = "#009D57"),
      labels = c("-" = "Negative", "+" = "Positive")
    ) +
    labs(
      x = NULL,
      y = "Shap values",
      title=str_c("Predicted value: ", round(predicted_val_i, 3)),
      subtitle = str_c("Base value: ", round(mean_pred_ref, 3))
    ) +
    theme(
      panel.grid.major.y = element_blank(), axis.ticks.y = element_blank(),
      legend.position = "bottom",
      plot.title.position = "plot",
      plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold")
    )

  if (!is.null(min_max)) {
    p <- p + scale_y_continuous(limits = min_max)
  }
  p
}

Let us look at the first individual:

i <- 1
# The random forest SHAP values explain the RF score (vote share returned by
# predict_score_ranger()). The plot shows 1 - score (imaginary healthy
# patient), so the SHAP values are negated too. They then sum to
# (1 - score) - (1 - mean score), consistently with the plotted base and
# predicted values.
treeshap_rf_opposite <- treeshap_rf
treeshap_rf_opposite$shaps <- -treeshap_rf_opposite$shaps
plot_individual_variable_effect(
  i = i,
  treeshap_res = treeshap_rf_opposite,
  predicted_val_i = 1 - pred_val_rf[i],
  mean_pred_ref = 1 - mean_pred_rf_ref, n = 20,
  # Base value is about 0.263; with negated SHAP values the largest bars go
  # below it, so the limits extend downwards.
  min_max = c(0.12, 0.34)
)
Warning: Removed 6 rows containing missing values or values outside the scale range
(`geom_segment()`).
Figure 8.1: SHAP values for the first individual, when the score is estimated with a Random Forest. Positive class: Imaginary healthy patient.
# NOTE(review): the check above shows that the XGBoost SHAP values do not sum
# to pred - mean_pred on the probability scale; the bars may therefore be on
# a different scale than the base value. Confirm with the treeshap docs.
plot_individual_variable_effect(
  i = i,
  treeshap_res = treeshap_xgb,
  predicted_val_i = pred_val_xgb[i],
  mean_pred_ref = mean_pred_xgb_ref,
  n = 20,
  min_max = c(0.2, 1)
)
Figure 8.2: SHAP values for the first individual, when the score is estimated with XGBoost. Positive class: Imaginary healthy patient.

Now, let us pick two individuals: one with a low predicted score, and another one with a high predicted score. We define two functions: get_top_n_indiv(), to get the top n variables explaining the deviation of the score of an individual from the average predicted score in the reference set, and get_table_charact() to extract the characteristics of specific individuals given a set of characteristics.

#' Names of the n variables with the largest absolute SHAP values
#'
#' @param i row number of individual
#' @param treeshap_res result obtained with treeshap
#' @param n top n variables (default to 10)
#' @return character vector of (at most) `n` variable names, ordered by
#'   decreasing absolute SHAP value
get_top_n_indiv <- function(i, treeshap_res, n = 10) {
  shap_values <- unlist(dplyr::slice(treeshap_res$shaps, i))
  ranking <- enframe(shap_values, name = "variable", value = "shap")
  ranking <- arrange(ranking, desc(abs(shap)))
  top_rows <- filter(ranking, row_number() <= !!n)
  top_rows[["variable"]]
}



#' Get the characteristics of the union of top n variables among individuals
#' (using absolute SHAP values), for xgboost only
#'
#' For each individual in `i`, the `n` variables with the largest absolute
#' SHAP values are identified. The union of these variables is then described
#' for every individual (value in the model matrix, and whether the variable
#' belongs to that individual's own top n), together with the average value
#' in the reference data.
#'
#' Also relies on the global objects `formula`, `df_all_xgb`,
#' `variable_names` and `variable_names_categ`.
#'
#' @param i row numbers of individuals (rows of `df_all_xgb`)
#' @param treeshap_res result obtained with treeshap
#' @param n top n variables (default to 10)
#' @param reference_data reference dataset used to compute the average values
#'   (default: the global `reference_data_xgb`)
#' @return a tibble with one row per variable: `label`, `val_ref`, and for
#'   each individual `value_indiv_<i>` and `in_top_n_<i>`
get_table_charact <- function(i,
                              treeshap_res,
                              n = 10,
                              reference_data = reference_data_xgb) {
  # Top n variables for each individual, and their union
  top_n_indiv <- map(
    i,
    ~get_top_n_indiv(i = .x, treeshap_res = treeshap_res, n = n)
  )
  top_n_indiv_variables <- list_c(top_n_indiv) |> unique()

  # Average values for those variables in the reference sample
  values_reference <-
    model.matrix(formula, reference_data) |>
    as_tibble() |>
    select(all_of(top_n_indiv_variables)) |>
    summarise(across(everything(), mean)) |>
    pivot_longer(
      cols = everything(), names_to = "variable", values_to = "val_ref"
    )

  # Model matrix of the whole sample, computed once instead of once per
  # individual.
  # NOTE(review): model.matrix() drops rows with missing values by default;
  # if df_all_xgb has NAs, row numbers would no longer match `i`.
  model_matrix_all <-
    model.matrix(formula, df_all_xgb) |>
    as_tibble() |>
    select(all_of(top_n_indiv_variables))

  # Characteristics for these individuals
  top_n_indiv_characteristics <- map2(
    .x = i, .y = top_n_indiv,
    ~tibble(variable = top_n_indiv_variables, id_indiv = .x) |>
      left_join(
        model_matrix_all |>
          dplyr::slice(.x) |>
          pivot_longer(
            cols = everything(),
            names_to = "variable", values_to = "value_indiv"
          ),
        by = "variable"
      ) |>
      mutate(in_top_n = ifelse(variable %in% .y, yes = "$\\checkmark$", no = ""))
  ) |>
    list_rbind()

  top_n_indiv_characteristics |>
    left_join(variable_names, by = "variable") |>
    left_join(variable_names_categ, by = "variable") |>
    left_join(values_reference, by = "variable") |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    select(label, id_indiv, value_indiv, in_top_n, val_ref) |>
    pivot_wider(names_from = id_indiv, values_from = c(value_indiv, in_top_n))
}

Let us identify two individuals:

# Two individuals with contrasting XGBoost scores.
i_low <- 23
i_high <- 3300
round(pred_val_xgb[c(i_low, i_high)], digits = 3)
[1] 0.217 0.714
Code
# Decomposition for the individual with a low predicted score.
plot_indiv_low_pred <- plot_individual_variable_effect(
  i = i_low,
  treeshap_res = treeshap_xgb,
  predicted_val_i = pred_val_xgb[i_low],
  mean_pred_ref = mean_pred_xgb_ref,
  n = 10,
  min_max = c(0.15, 0.55)
)
plot_indiv_low_pred
Warning: Removed 1 row containing missing values or values outside the scale range
(`geom_segment()`).
Figure 8.3: SHAP values for an individual with a low predicted score, when the latter is estimated with XGBoost. Positive class: Imaginary healthy patient.
Code
# Decomposition for the individual with a high predicted score.
plot_indiv_high_pred <- plot_individual_variable_effect(
  i = i_high,
  treeshap_res = treeshap_xgb,
  predicted_val_i = pred_val_xgb[i_high],
  mean_pred_ref = mean_pred_xgb_ref,
  n = 10,
  min_max = c(0.3, 0.85)
)
plot_indiv_high_pred
Warning: Removed 1 row containing missing values or values outside the scale range
(`geom_segment()`).
Figure 8.4: SHAP values for an individual with a high predicted score, when the latter is estimated with XGBoost. Positive class: Imaginary healthy patient.

We retrieve the characteristics of the identified most important variables for both individuals, as well as the average values in the reference sample (train set):

# Characteristics of the union of both individuals' top-10 variables.
tb_two_indivs_example <-
  c(i_low, i_high) |>
  get_table_charact(treeshap_res = treeshap_xgb, n = 10)
Joining with `by = join_by(variable)`
Joining with `by = join_by(variable)`
Joining with `by = join_by(variable)`

And then, we display them in a nice table:

Code
# Column names are built from i_low / i_high (same naming as the
# pivot_wider() output of get_table_charact()), so the table stays in sync if
# other individuals are chosen.
tb_two_indivs_example |>
  select(
    label,
    all_of(str_c(
      rep(c("value_indiv_", "in_top_n_"), times = 2),
      rep(c(i_low, i_high), each = 2)
    )),
    val_ref
  ) |>
  mutate(val_ref = round(val_ref, 2)) |>
  kableExtra::kable(booktabs = TRUE) |>
  kableExtra::kable_classic(full_width = FALSE) |>
  kableExtra::kable_styling() |>
  # Print the raw HTML so that it is embedded as-is in the rendered document
  unclass() |> cat()
Table 8.1: Characteristics for the Most Important Variables for Two Individuals According to their SHAP Values.
label value_indiv_23 in_top_n_23 value_indiv_3300 in_top_n_3300 val_ref
Net Income per Cons. Unit 2400.00 \(\checkmark\) 900.00 \(\checkmark\) 1610.26
Age 63.00 \(\checkmark\) 79.00 \(\checkmark\) 48.56
Reimbursement General Practitioner 60.40 \(\checkmark\) 206.20 \(\checkmark\) 87.52
Gender = Male 1.00 \(\checkmark\) 0.00 \(\checkmark\) 0.48
Frequency Meeting with People in Organizations = At least once a week 1.00 \(\checkmark\) 0.00 0.17
Frequency Meeting with People in Organizations = Never 0.00 \(\checkmark\) 0.00 \(\checkmark\) 0.51
Low Back Pain = Yes 0.00 \(\checkmark\) 0.00 \(\checkmark\) 0.20
Participation in Group Activities = No 0.00 \(\checkmark\) 1.00 0.63
Neck Pain = Yes 0.00 \(\checkmark\) 0.00 0.15
Reimbursement Pharmacy 396.26 \(\checkmark\) 1818.33 \(\checkmark\) 364.21
Long-term condition (Self-declared) = No 1.00 0.00 \(\checkmark\) 0.81
No. Medical Sessions General Pract. 4.00 10.00 \(\checkmark\) 4.64
Waiver Appointment Delay Too Long = No 1.00 0.00 \(\checkmark\) 0.68

8.4 Variable importance

We compute the average absolute SHAP value for each variable and order the results by descending values. For the random forest, we use the opposite of the SHAP values, so that the results relate to the probability of being classified as an imaginary healthy patient.

# Random forest, with negated SHAP values (imaginary healthy class).
# Mean(| SHAP |)
var_imp_rf_mean_abs_shap <- summarise(
  treeshap_rf_opposite$shaps,
  across(everything(), \(s) mean(abs(s)))
)

# Mean(SHAP)
var_imp_rf_mean_shap <- summarise(
  treeshap_rf_opposite$shaps,
  across(everything(), mean)
)

We order the variables by descending values of the mean absolute SHAP values.

# One row per variable (name), sorted by decreasing mean |SHAP| (value), with
# the mean SHAP (mean) attached. The right-hand side needs no sorting: row
# order comes from the left-hand side of the join.
order_variables_shap_rf <-
  var_imp_rf_mean_abs_shap |>
  unlist() |>
  sort(decreasing = TRUE) |>
  enframe() |>
  left_join(
    enframe(unlist(var_imp_rf_mean_shap), value = "mean"),
    by = "name"
  )

The top 10 variables:

# The 10 variables with the highest mean |SHAP| (random forest). head() rather
# than [1:10], which would pad the result with NA if fewer variables existed.
top_n_variables_rf <- head(order_variables_shap_rf$name, 10)
top_n_variables_rf
 [1] "MENAGE_revucinsee"    "SOINS_remomn"         "QST_ct_depech"       
 [4] "SOINS_rempha"         "SOINS_seac_omn"       "PERSONNE_pb_lombalgi"
 [7] "SOINS_pf_frpha"       "OPINION1_renonc_long" "OPINION1_renonc_dent"
[10] "PERSONNE_pb_cervical"
# XGBoost SHAP values.
# Mean(| SHAP |)
var_imp_xgb_mean_abs_shap <- summarise(
  treeshap_xgb$shaps,
  across(everything(), \(s) mean(abs(s)))
)

# Mean(SHAP)
var_imp_xgb_mean_shap <- summarise(
  treeshap_xgb$shaps,
  across(everything(), mean)
)

We order the variables by descending values of the mean absolute SHAP values.

# One row per variable (name), sorted by decreasing mean |SHAP| (value), with
# the mean SHAP (mean) attached; the ranking is saved for later use.
order_variables_shap_xgb <-
  var_imp_xgb_mean_abs_shap |>
  unlist() |>
  sort(decreasing = TRUE) |>
  enframe() |>
  left_join(
    enframe(unlist(var_imp_xgb_mean_shap), value = "mean"),
    by = "name"
  )
save(order_variables_shap_xgb, file = "../data/out/order_variables_shap_xgb.rda")

The top 10 variables:

# The 10 variables with the highest mean |SHAP| (XGBoost). head() rather than
# [1:10], which would pad the result with NA if fewer variables existed.
top_n_variables_xgb <- head(order_variables_shap_xgb$name, 10)
top_n_variables_xgb
 [1] "MENAGE_revucinsee"       "SOINS_remomn"           
 [3] "PERSONNE_age"            "PERSONNE_pb_lombalgiYes"
 [5] "SOINS_rempha"            "PERSONNE_pb_cervicalYes"
 [7] "QES_tpsassoNever"        "PERSONNE_sexeMale"      
 [9] "QST_ct_depechSometimes"  "QST_ct_liberteNever"    

Let us plot the top 10 variables by importance according to the SHAP values.

Code
# Bar chart of mean |SHAP| (random forest). The top-10 variables are shown
# individually; all the others are summed into "Other Variables", drawn at
# the bottom once the coordinates are flipped.
p_variable_importance_rf <-
  ggplot(
    data = order_variables_shap_rf |>
      left_join(variable_names, by = join_by(name == variable)) |>
      mutate(
        label = ifelse(
          name %in% !!top_n_variables_rf, label, "Other Variables"
        )
      ) |>
      group_by(label) |>
      summarise(value = sum(value), .groups = "drop") |>
      mutate(
        label = fct_reorder(label, value),
        label = fct_relevel(label, "Other Variables")
      ),
    mapping = aes(x = label, y = value)
  ) +
  geom_col() +
  coord_flip() +
  labs(y = "Mean(|SHAP Value|)", x = NULL)

p_variable_importance_rf
Figure 8.5: Variable importance according to SHAP values (with the random forest classifier)
Code
# Same chart for XGBoost. Dummy variables (e.g. "PERSONNE_sexeMale") get
# their label from `variable_names_categ` when absent from `variable_names`.
p_variable_importance_xgb <-
  ggplot(
    data = order_variables_shap_xgb |>
      left_join(variable_names, by = join_by(name == variable)) |>
      left_join(variable_names_categ, by = join_by(name == variable)) |>
      mutate(
        label = ifelse(is.na(label), label_categ, label),
        label = ifelse(
          name %in% !!top_n_variables_xgb, label, "Other Variables"
        )
      ) |>
      group_by(label) |>
      summarise(value = sum(value), .groups = "drop") |>
      mutate(
        label = fct_reorder(label, value),
        label = fct_relevel(label, "Other Variables")
      ),
    mapping = aes(x = label, y = value)
  ) +
  geom_col() +
  coord_flip() +
  labs(y = "Mean(|SHAP Value|)", x = NULL)

p_variable_importance_xgb
Figure 8.6: Variable importance according to SHAP values (with the XGBoost classifier)

The complete ranking, for both models:

# Complete ranking of both models, joined on the variable name (explicit key,
# so no "Joining with" message). A variable absent from one ranking gets NA
# for that model: XGBoost uses dummy columns (e.g. "PERSONNE_pb_lombalgiYes")
# whereas the random forest uses the numeric-coded variables.
order_variables_shap_xgb |>
  mutate(rank_xgb = row_number()) |>
  rename(mean_abs_shap_xgb = value, mean_shap_xgb = mean) |>
  full_join(
    order_variables_shap_rf |>
      mutate(rank_rf = row_number()) |>
      rename(mean_abs_shap_rf = value, mean_shap_rf = mean),
    by = "name"
  ) |>
  DT::datatable()
Joining with `by = join_by(name)`

In Python, the summary_plot() function of the shap library offers a quick view of the effect of each variable on the prediction. We can reproduce such a plot. For quantitative variables, this type of graph seems advantageous. For qualitative (categorical) variables, on the other hand, we prefer a different type of graph (see below).

Let us define a function to rescale values (min-max rescaling).

#' std1
#' Min-max rescaling of a numeric vector to [0, 1] (missing values ignored
#' when computing the bounds, kept as NA in the result)
#' Source: https://github.com/pablo14/shap-values/blob/master/shap.R
std1 <- function(x) {
  bounds <- range(x, na.rm = TRUE)
  (x - bounds[1]) / (bounds[2] - bounds[1])
}

Then, we create a table that contains, for each individual and each variable, the SHAP value and the standardized value of the variable.

# Per top-n variable (random forest, numeric-coded data over train + test):
# minimum, maximum, median and quartiles.
values_min_max_rf <-
  df_all_rf |>
  select(all_of(top_n_variables_rf)) |>
  pivot_longer(cols = everything()) |>
  group_by(name) |>
  summarise(
    min_value = min(value),
    max_value = max(value),
    med_value = median(value),
    q1_value = quantile(value, probs = 0.25),
    q3_value = quantile(value, probs = 0.75)
  )
# Same statistics for XGBoost, computed on the model matrix (dummy columns)
# of train + test.
values_min_max_xgb <-
  model.matrix(formula, df_all_xgb) |>
  as_tibble() |>
  select(all_of(top_n_variables_xgb)) |>
  pivot_longer(cols = everything()) |>
  group_by(name) |>
  summarise(
    min_value = min(value),
    max_value = max(value),
    med_value = median(value),
    q1_value = quantile(value, probs = 0.25),
    q3_value = quantile(value, probs = 0.75)
  )

We format the data to be able to create the plot.

# Long table for the random forest summary plot, one row per
# (individual, top-n variable):
#   - shap: SHAP value (negated version, `treeshap_rf_opposite`),
#   - value: value of the variable for that individual (numeric-coded data),
#   - std_shap_value / mean_shap_value / mean_abs_shap_value: per-variable
#     min-max rescaled SHAP values, mean SHAP and mean |SHAP|,
#   - value_variable_std: value min-max rescaled with the bounds from
#     `values_min_max_rf` (drives the colour scale),
#   - label: readable label, ordered by mean |SHAP| (plot order).
df_plot_rf <- 
  treeshap_rf_opposite$shaps |> 
  select(!!top_n_variables_rf) |> 
  # Row number used as the individual key for the join below
  # (assumes the rows of the SHAP table and of df_all_rf are aligned).
  mutate(id_row = row_number()) |> 
  pivot_longer(cols = -id_row, values_to = "shap") |> 
  left_join(
    df_all_rf |> 
      select(!!top_n_variables_rf) |> 
      mutate(id_row = row_number()) |> 
      pivot_longer(cols = -id_row, values_to = "value"),
    by = c("id_row", "name")
  ) |> 
  group_by(name) |> 
  mutate(
    std_shap_value = std1(shap),
    mean_shap_value = mean(shap),
    mean_abs_shap_value = mean(abs(shap))
  ) |> 
  ungroup() |> 
  left_join(
    values_min_max_rf,
    by = "name"
  ) |> 
  mutate(
    value_variable_std = (value - min_value) / (max_value - min_value)
  ) |> 
  left_join(variable_names, by = c("name" = "variable")) |> 
  mutate(
    label = fct_reorder(label, mean_abs_shap_value)
  )
# Long table for the XGBoost summary plot, built like the random forest one
# but with the SHAP values of `treeshap_xgb` and the values taken from the
# model matrix (dummy columns) of train + test. Labels of dummy variables
# come from `variable_names_categ` when missing from `variable_names`.
df_plot_xgb <- 
  treeshap_xgb$shaps |> 
  select(!!top_n_variables_xgb) |> 
  # Row number used as the individual key for the join below
  # (assumes the SHAP rows and the model-matrix rows are aligned).
  mutate(id_row = row_number()) |> 
  pivot_longer(cols = -id_row, values_to = "shap") |> 
  left_join(
    as_tibble(model.matrix(formula, df_all_xgb)) |> 
      select(!!top_n_variables_xgb) |> 
      mutate(id_row = row_number()) |> 
      pivot_longer(cols = -id_row, values_to = "value"),
    by = c("id_row", "name")
  ) |> 
  group_by(name) |> 
  mutate(
    std_shap_value = std1(shap),
    mean_shap_value = mean(shap),
    mean_abs_shap_value = mean(abs(shap))
  ) |> 
  ungroup() |> 
  left_join(
    values_min_max_xgb,
    by = "name"
  ) |> 
  mutate(
    value_variable_std = (value - min_value) / (max_value - min_value)
  ) |> 
  left_join(variable_names, by = c("name" = "variable")) |> 
  left_join(variable_names_categ, by = c("name" = "variable")) |> 
  mutate(
    label = ifelse(is.na(label), label_categ, label),
    label = fct_reorder(label, mean_abs_shap_value)
  )

Then, using the geom_sina() function from {ggforce}, we can create the summary plot. For each variable given on the y-axis, each dot represents an individual. The x-axis gives the estimated SHAP value of the variable for each individual. The colours show whether the value of the variable for a specific individual is low (blue) or high (red), using the standardized value as the reference. The variables on the y-axis appear according to their relative importance in explaining the prediction. For example, the graph shows that pharmacy deductible amounts are relatively important in explaining the probability of being classified as an imaginary healthy patient. When the value of this deductible is low (blue dots), this variable influences the prediction downwards (since the blue dots are mostly associated with negative SHAP values); conversely, when the pharmacy deductible is high (red dots), the probability of being classified as an imaginary healthy patient increases.

Code
library(ggforce)
# SHAP summary plot (random forest): one dot per individual and variable,
# spread with geom_sina(); the colour is the standardised value of the
# variable (value_variable_std), from low (blue) to high (red).
p_rf <- ggplot(
  data = df_plot_rf,
  mapping = aes(
    y = shap,
    x = label,
    colour = value_variable_std
  )
) +
  geom_sina(alpha = .5) +
  scale_color_gradient(
    "Feature value",
    low = "#0081BC", high = "#EE324E",
    breaks = c(0, 1), labels = c("Low", "High"),
    # `midpoint` removed: it is an argument of scale_colour_gradient2(),
    # not of guide_colourbar().
    guide = guide_colourbar(
      direction = "horizontal",
      barwidth = 40,
      title.position = "bottom",
      title.hjust = .5
    )
  ) +
  labs(y = "SHAP value (impact on model output)", x = NULL) +
  theme(
    legend.position = "bottom",
    legend.key.height = unit(1, "line"),
    legend.key.width = unit(1.5, "line")
  ) +
  coord_flip() +
  # reference line at a null contribution
  geom_hline(yintercept = 0, linetype = "dashed")
p_rf
Figure 8.7: Estimated effect of variables on the probability of being predicted as an imaginary healthy patient, for each individual (random forest)
Code
# SHAP summary plot (XGBoost): one dot per individual and variable, spread
# with geom_sina(); the colour is the min-max standardised value of the
# variable (value_variable_std), from low (blue) to high (red).
p_xgb <- ggplot(
  data = df_plot_xgb,
  mapping = aes(
    y = shap,
    x = label,
    colour = value_variable_std
  )
) +
  geom_sina(alpha = .5) +
  scale_color_gradient(
    "Feature value",
    low = "#0081BC", high = "#EE324E",
    breaks = c(0, 1), labels = c("Low", "High"),
    # `midpoint` removed: it is an argument of scale_colour_gradient2(),
    # not of guide_colourbar().
    guide = guide_colourbar(
      direction = "horizontal",
      barwidth = 40,
      title.position = "bottom",
      title.hjust = .5
    )
  ) +
  labs(y = "SHAP value (impact on model output)", x = NULL) +
  theme(
    legend.position = "bottom",
    legend.key.height = unit(1, "line"),
    legend.key.width = unit(1.5, "line")
  ) +
  coord_flip() +
  # reference line at a null contribution
  geom_hline(yintercept = 0, linetype = "dashed")
p_xgb
Figure 8.8: Estimated effect of variables on the probability of being predicted as an imaginary healthy patient, for each individual (XGBoost)

Now, instead of a summary plot, let us create a plot for each of the variables in the top n. To that end, we create a function which plots the estimated SHAP values of a single variable for each individual. When the variable is categorical, we want the graph to be split according to the categories.

For continuous variables, we may want the colours of the dots to depend on the quantiles. This is the case for net income, because of the shape of its distribution:

# Distribution of net income per consumption unit: the summary below shows a
# long right tail (median 1400, max 25300), which motivates colouring by
# deciles rather than by raw value.
hist(df_all_rf$MENAGE_revucinsee, breaks = 100)

summary(df_all_rf$MENAGE_revucinsee)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
     74     992    1400    1610    2000   25300 

The code of the plot function is a bit long.

Code
#' Plots the SHAP values for each individual, for a single variable (random
#' forest)
#'
#' Variables listed in the global `corresp_factors` are treated as
#' categorical: their numeric codes in `df_all` are mapped back to the factor
#' levels (code i = i-th level) and one sina/violin is drawn per level. Other
#' variables are drawn as a single sina/violin coloured either by the raw
#' value or by the decile class of the value.
#'
#' @param variable_name name of the variable
#' @param treeshap_res treeshap results
#' @param df_all data frame with observations from train and test sets
#' @param quantile if TRUE, uses the empirical deciles as the colour code
#'   (default to FALSE)
#' @param size_dots size of the dots
#' @param bar_width legend bar width for quantitative variables
#' @returns a ggplot object
plot_shap_all_indiv_variable_rf <- function(variable_name,
                                            treeshap_res,
                                            df_all,
                                            quantile = FALSE,
                                            size_dots = .6,
                                            bar_width = 30) {
  df_plot <-
    treeshap_res$shaps |>
    select(!!variable_name) |>
    rename(SHAP_value = !!variable_name) |>
    mutate(variable_name = variable_name) |>
    as_tibble() |>
    # Add value of the variable from the dataset (rows of `df_all` are
    # assumed to be in the same order as `treeshap_res$shaps`)
    mutate(variable_value = df_all |> pull(!!variable_name))

  is_factor <- variable_name %in% names(corresp_factors)

  if (is_factor) {
    # Levels of the variable; the numeric code i corresponds to niveaux[i]
    niveaux <- corresp_factors[[variable_name]]
    df_plot <-
      df_plot |>
      left_join(
        tibble(
          variable_value = seq_along(niveaux),
          variable_value_level = niveaux
        ),
        by = "variable_value"
      )

    if ("No answer" %in% niveaux) {
      recode_tbl <- .shap_recode_levels_rf(niveaux)
      df_plot <-
        df_plot |>
        left_join(recode_tbl, by = "variable_value_level") |>
        select(-variable_value_level) |>
        rename(variable_value_level = variable_value_level_new)
      niveaux <- recode_tbl$variable_value_level_new
    }

    df_plot <-
      df_plot |>
      mutate(
        variable_value_level = factor(variable_value_level, levels = niveaux)
      )

    p <-
      ggplot(
        data = df_plot,
        mapping = aes(
          y = SHAP_value,
          x = variable_value_level
        )
      ) +
      geom_sina(size = size_dots) +
      geom_violin(alpha = 0) +
      coord_flip() +
      labs(y = "SHAP value (impact on model output)", x = NULL)
  } else {
    colour_bar <- guide_colourbar(
      direction = "horizontal",
      barwidth = bar_width,
      title.position = "bottom",
      title.hjust = .5
    )

    if (isTRUE(quantile)) {
      df_plot$quantile_variable <- .shap_decile_classes_rf(
        df_plot$variable_value,
        df_all[[variable_name]]
      )
      # Position of the decile class (1 to 10), aligned with legend breaks
      df_plot <-
        df_plot |>
        mutate(quantile_revenu_num = as.numeric(quantile_variable))
      colour_var <- "quantile_revenu_num"
      colour_scale <- scale_color_gradient(
        "Variable value",
        breaks = c(1, 4, 7, 10),
        labels = c("$\\leq$ D1", "D3 to D4", "D6 to D7", "$>$ D9"),
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    } else {
      colour_var <- "variable_value"
      colour_scale <- scale_color_gradient(
        "Variable value",
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    }

    p <-
      ggplot(
        data = df_plot |> mutate(x = 1),
        mapping = aes(
          x = x,
          y = SHAP_value,
          colour = .data[[colour_var]]
        )
      ) +
      geom_sina(alpha = .5, size = size_dots) +
      geom_violin(alpha = 0) +
      colour_scale +
      labs(y = "SHAP value (impact on model output)", x = NULL) +
      theme(
        legend.position = "bottom",
        legend.key.height = unit(1, "line"),
        legend.key.width = unit(1.5, "line"),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()
      ) +
      coord_flip()
  }

  p + theme(
    plot.title.position = "plot",
    plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold"),
    legend.text = element_text(size = rel(.8)),
    legend.title = element_text(size = rel(.8))
  )
}

# Recoding table for a categorical variable containing "No answer": renames
# it "Did not answer" and, for the two known frequency scales, orders the
# levels from the last answer of the scale to the first (so that, once the
# axes are flipped, the first answer of the scale is drawn at the top).
# Returns a tibble (variable_value_level, variable_value_level_new) whose row
# order is the desired level order.
.shap_recode_levels_rf <- function(niveaux) {
  nouveaux <- niveaux
  nouveaux[nouveaux == "No answer"] <- "Did not answer"

  ordre <- if ("At least once a month" %in% niveaux) {
    c(
      "Every day or almost every day",
      "At least once a week",
      "At least once a month",
      "Less than once a month", "Never", "Did not answer"
    )
  } else if ("Sometimes" %in% niveaux) {
    c("Always", "Often", "Sometimes", "Never", "Did not answer")
  } else {
    NULL
  }

  if (!is.null(ordre)) {
    ordre_tbl <-
      tibble(
        niveaux = niveaux,
        nouveaux_niveaux = factor(nouveaux, levels = ordre)
      ) |>
      arrange(desc(nouveaux_niveaux))
    niveaux <- ordre_tbl$niveaux
    nouveaux <- as.character(ordre_tbl$nouveaux_niveaux)
  }

  tibble(variable_value_level = niveaux, variable_value_level_new = nouveaux)
}

# Decile class of each value of `x`, using the deciles of `reference` as cut
# points. Returns a factor whose levels are always the ten decile labels, so
# that as.numeric() gives the decile position (1 to 10) expected by the colour
# legend. Intervals are right-closed and the lowest one starts at -Inf (so
# zeros are not lost). When consecutive deciles are equal (e.g. many zeros),
# the duplicated cut point is dropped together with the label of the empty
# interval it would delimit (cut() requires unique breaks).
.shap_decile_classes_rf <- function(x, reference) {
  labels_all <- c(
    "$\\leq$ D1", "D1 to D2", "D2 to D3", "D3 to D4",
    "D4 to D5", "D5 to D6", "D6 to D7", "D7 to D8",
    "D8 to D9", "> D9"
  )
  breaks <- c(
    -Inf,
    stats::quantile(reference, probs = seq(.1, .9, .1), na.rm = TRUE),
    Inf
  )
  labels <- labels_all
  # Interval k is (breaks[k], breaks[k + 1]]: when breaks[i] equals
  # breaks[i - 1], interval i - 1 is empty.
  ind_remove <- which(duplicated(breaks))
  if (length(ind_remove) > 0) {
    breaks <- breaks[-ind_remove]
    labels <- labels[-(ind_remove - 1)]
  }
  classes <- cut(x, breaks = breaks, labels = labels)
  factor(as.character(classes), levels = labels_all)
}
Net Income per Cons. Unit

Reimbursement General Practitioner

Have to Hurry to Do Job

Reimbursement Pharmacy

No. Medical Sessions General Pract.

Low Back Pain

Deduct. Pharmacy

Waiver Appointment Delay Too Long

Waiver Dental Care

Neck Pain

Code
#' Plots the SHAP values for each individual, for a single variable (XGBoost)
#'
#' Columns listed in `variable_names_categ$variable` (categorical columns,
#' presumably 0/1 dummies of the model matrix) are drawn with a two-entry
#' colour legend. Other columns are drawn as a single sina/violin coloured
#' either by the raw value or by the decile class of the value.
#'
#' @param variable_name name of the variable
#' @param treeshap_res treeshap results
#' @param df_all data frame with observations from train and test sets
#' @param quantile if TRUE, uses the empirical deciles as the colour code
#'   (default to FALSE)
#' @param size_dots size of the dots
#' @param bar_width legend bar width for quantitative variables
#' @returns a ggplot object
plot_shap_all_indiv_variable_xgb <- function(variable_name,
                                             treeshap_res,
                                             df_all,
                                             quantile = FALSE,
                                             size_dots = .6,
                                             bar_width = 30) {
  df_plot <-
    treeshap_res$shaps |>
    select(!!variable_name) |>
    rename(SHAP_value = !!variable_name) |>
    mutate(variable_name = variable_name) |>
    as_tibble() |>
    # Add value of the variable from the dataset (rows of `df_all` are
    # assumed to be in the same order as `treeshap_res$shaps`)
    mutate(variable_value = df_all |> pull(!!variable_name))

  is_factor <- variable_name %in% variable_names_categ$variable

  if (is_factor) {
    # Categorical column: two-entry colour legend
    p <-
      ggplot(
        data = df_plot |> mutate(x = 1),
        mapping = aes(
          x = x,
          y = SHAP_value,
          colour = variable_value
        )
      ) +
      geom_sina(alpha = .5, size = size_dots) +
      geom_violin(alpha = 0) +
      scale_color_gradient(
        "Variable value",
        low = "#0081BC", high = "#EE324E", n.breaks = 2
      ) +
      labs(y = "SHAP value (impact on model output)", x = NULL) +
      theme(
        legend.position = "bottom",
        legend.key.height = unit(1, "line"),
        legend.key.width = unit(1.5, "line"),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()
      ) +
      coord_flip() +
      guides(color = guide_legend(override.aes = list(
        color = c("#0081BC", "#EE324E"),
        label = c("0", "1")
      )))
  } else {
    colour_bar <- guide_colourbar(
      direction = "horizontal",
      barwidth = bar_width,
      title.position = "bottom",
      title.hjust = .5
    )

    if (isTRUE(quantile)) {
      df_plot$quantile_variable <- .shap_decile_classes_xgb(
        df_plot$variable_value,
        df_all[[variable_name]]
      )
      # Position of the decile class (1 to 10), aligned with legend breaks
      # even when tied deciles were merged
      df_plot <-
        df_plot |>
        mutate(quantile_revenu_num = as.numeric(quantile_variable))
      colour_var <- "quantile_revenu_num"
      colour_scale <- scale_color_gradient(
        "Variable value",
        breaks = c(1, 4, 7, 10),
        labels = c("$\\leq$ D1", "D3 to D4", "D6 to D7", "$>$ D9"),
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    } else {
      colour_var <- "variable_value"
      colour_scale <- scale_color_gradient(
        "Variable value",
        low = "#0081BC", high = "#EE324E",
        guide = colour_bar
      )
    }

    p <-
      ggplot(
        data = df_plot |> mutate(x = 1),
        mapping = aes(
          x = x,
          y = SHAP_value,
          colour = .data[[colour_var]]
        )
      ) +
      geom_sina(alpha = .5, size = size_dots) +
      geom_violin(alpha = 0) +
      colour_scale +
      labs(y = "SHAP value (impact on model output)", x = NULL) +
      theme(
        legend.position = "bottom",
        legend.key.height = unit(1, "line"),
        legend.key.width = unit(1.5, "line"),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()
      ) +
      coord_flip()
  }

  p + theme(
    plot.title.position = "plot",
    plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold"),
    legend.text = element_text(size = rel(.8)),
    legend.title = element_text(size = rel(.8))
  )
}

# Decile class of each value of `x`, using the deciles of `reference` as cut
# points. Returns a factor whose levels are always the ten decile labels, so
# that as.numeric() gives the decile position (1 to 10) expected by the colour
# legend. Intervals are right-closed and the lowest one starts at -Inf (so
# zeros are not lost). When consecutive deciles are equal (e.g. many zeros),
# the duplicated cut point is dropped together with the label of the empty
# interval it would delimit.
.shap_decile_classes_xgb <- function(x, reference) {
  labels_all <- c(
    "$\\leq$ D1", "D1 to D2", "D2 to D3", "D3 to D4",
    "D4 to D5", "D5 to D6", "D6 to D7", "D7 to D8",
    "D8 to D9", "> D9"
  )
  breaks <- c(
    -Inf,
    stats::quantile(reference, probs = seq(.1, .9, .1), na.rm = TRUE),
    Inf
  )
  labels <- labels_all
  # Interval k is (breaks[k], breaks[k + 1]]: when breaks[i] equals
  # breaks[i - 1], interval i - 1 (not i) is the empty one.
  ind_remove <- which(duplicated(breaks))
  if (length(ind_remove) > 0) {
    breaks <- breaks[-ind_remove]
    labels <- labels[-(ind_remove - 1)]
  }
  classes <- cut(x, breaks = breaks, labels = labels)
  factor(as.character(classes), levels = labels_all)
}
Net Income per Cons. Unit

Reimbursement General Practitioner

Age

Low Back Pain = Yes

Reimbursement Pharmacy

Neck Pain = Yes

Frequency Meeting with People in Organizations = Never

Gender = Male

Have to Hurry to Do Job = Sometimes

Very Little Freedom to Do Job = Never

8.5 Clustering

It may be possible to group individuals in the dataset depending on the profile of their SHAP values. Let us perform a hierarchical clustering to group the individuals according to their Shapley values. We will only keep a few variables to represent each individual, and will focus only on people classified as imaginary healthy patients. First, let us compute the average Shapley value for each variable among the predicted imaginary healthy patients.

8.5.1 Note

We only focus on the SHAP values computed based on the predictions made with XGBoost in this part.

Recall that we stored in a tibble (order_variables_shap_xgb) the average absolute SHAP values (column value) and the average SHAP values (column mean):

# Variables ranked by mean absolute SHAP value (column `value`), with the mean
# SHAP value in column `mean`
order_variables_shap_xgb
# A tibble: 175 × 3
   name                     value     mean
   <chr>                    <dbl>    <dbl>
 1 MENAGE_revucinsee       0.181  -0.0190 
 2 SOINS_remomn            0.134  -0.00693
 3 PERSONNE_age            0.121  -0.0158 
 4 PERSONNE_pb_lombalgiYes 0.0971 -0.00476
 5 SOINS_rempha            0.0927 -0.00631
 6 PERSONNE_pb_cervicalYes 0.0760 -0.00664
 7 QES_tpsassoNever        0.0717 -0.00299
 8 PERSONNE_sexeMale       0.0694 -0.00195
 9 QST_ct_depechSometimes  0.0653 -0.00668
10 QST_ct_liberteNever     0.0557 -0.00462
# ℹ 165 more rows

We can display the average of the absolute SHAP values on a plot. A vertical red line marks the mean, across variables, of these average absolute Shapley values. Red dots indicate variables with a negative average Shapley value, green dots a positive one, and gray dots an average of exactly zero.

# Mean absolute SHAP value of every variable (XGBoost), coloured by the sign
# of its mean SHAP value; the red line is the mean over all variables.
ggplot(
  data = order_variables_shap_xgb |>
    left_join(variable_names, by = c("name" = "variable")) |>
    left_join(variable_names_categ, by = c("name" = "variable")) |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    rename(mean_abs_shap_value = value) |>
    mutate(
      sign = sign(mean),
      sign = factor(sign, levels = c(-1, 0, 1), labels = c("-", "0", "+"))
    ),
  mapping = aes(
    y = fct_reorder(label, desc(mean_abs_shap_value)),
    x = mean_abs_shap_value
  )
) +
  geom_point(mapping = aes(colour = sign)) +
  geom_vline(
    xintercept = mean(order_variables_shap_xgb$value), colour = "red"
  ) +
  scale_colour_manual(
    "Feature impact",
    values = c("-" = "#EE324E", "+" = "#009D57", "0" = "gray"),
    # every category gets a legend label, including the zero-mean one
    labels = c("-" = "Negative", "+" = "Positive", "0" = "Zero")
  ) +
  labs(x = "Mean(|Shap value|)", y = "Variable")
Figure 8.9: Mean absolute SHAP values for each variable

Let us create a table in which we only keep the individuals predicted as imaginary healthy patients:

# SHAP values of the individuals predicted as imaginary healthy (prediction
# above 0.5), keeping their row index in the full dataset in `id_row`.
df_clustering_Not_D_and_inf_Q1 <-
  treeshap_xgb$shaps |>
  mutate(id_row = row_number(), predicted = pred_val_all_xgb) |>
  filter(predicted > .5) |>
  select(-predicted) |>
  as_tibble()

Let us only keep the 13 variables with the highest mean absolute Shapley values:

nb_top <- 13
# Names of the nb_top variables with the largest mean absolute SHAP value
variables_to_keep_cluster <- head(
  pull(arrange(order_variables_shap_xgb, desc(value)), "name"),
  n = nb_top
)
variables_to_keep_cluster
 [1] "MENAGE_revucinsee"             "SOINS_remomn"                 
 [3] "PERSONNE_age"                  "PERSONNE_pb_lombalgiYes"      
 [5] "SOINS_rempha"                  "PERSONNE_pb_cervicalYes"      
 [7] "QES_tpsassoNever"              "PERSONNE_sexeMale"            
 [9] "QST_ct_depechSometimes"        "QST_ct_liberteNever"          
[11] "PERSONNE_pb_allergiYes"        "OPINION1_renonc_consNo answer"
[13] "PERSONNE_aldNo"               

We can visualize which variables were kept:

# Mean absolute SHAP value of every variable, highlighting those kept for
# the clustering; the red line is the mean over all variables.
ggplot(
  data = order_variables_shap_xgb |>
    left_join(variable_names, by = c("name" = "variable")) |>
    left_join(variable_names_categ, by = c("name" = "variable")) |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    rename(mean_abs_shap_value = value) |>
    mutate(
      in_top_kept = name %in% variables_to_keep_cluster
    ),
  mapping = aes(
    y = fct_reorder(label, desc(mean_abs_shap_value)),
    x = mean_abs_shap_value
  )
) +
  geom_point(mapping = aes(colour = in_top_kept)) +
  geom_vline(
    xintercept = mean(order_variables_shap_xgb$value), colour = "red"
  ) +
  scale_colour_manual(
    "Variable kept for clustering",
    values = c("FALSE" = "gray", "TRUE" = "#009D57"),
    # keys must match the logical levels of in_top_kept ("TRUE", not "+")
    labels = c("FALSE" = "Discarded", "TRUE" = "Kept")
  ) +
  labs(x = "Mean(|Shap value|)", y = "Variable")
Figure 8.10: Variables kept for the clustering and their average absolute Shapley value.

We need to select a number of clusters K. Instead of picking a value arbitrarily, we vary the number of clusters and look at the average silhouette width for each value of K.

# Euclidean distances between individuals, computed on the SHAP values of
# the kept variables only
dist_m <- dist(
  as.data.frame(
    df_clustering_Not_D_and_inf_Q1 |>
      select(all_of(variables_to_keep_cluster))
  )
)
# NOTE(review): per ?hclust, "ward.D" on unsquared distances does not
# implement Ward's criterion ("ward.D2" does). Kept unchanged so that the
# clusters reported below are reproduced.
hierarchical_clust <- hclust(dist_m, method = "ward.D")

# Average silhouette width of the partition in K groups, for K = 2, ..., 10
sil_val <- vapply(
  2:10,
  function(k) {
    sil <- cluster::silhouette(cutree(hierarchical_clust, k), dist_m)
    mean(sil[, 3])
  },
  numeric(1)
)

p_silhouette <-
  ggplot(data = tibble(K = 2:10, y = sil_val), aes(x = K, y = y)) +
  geom_line() +
  geom_point() +
  labs(x = "K", y = "Silhouette score")
p_silhouette
Figure 8.11: Silhouette scores

We hence pick 3 as the number of clusters, and then assign each observation to its cluster.

K <- 3
# Alias of the hierarchical tree (an hclust object despite the "km" in its
# name); its leaf order is used below to sort the individuals.
clusters_km_Not_D_and_inf_Q <- hierarchical_clust
df_clustering_Not_D_and_inf_Q1[["cluster"]] <- cutree(hierarchical_clust, k = K)

We create a table to plot the SHAP values of these individuals, ordering the observations according to the leaf order of the dendrogram, so that similar individuals are adjacent.

# Individuals sorted in dendrogram leaf order
df_plot <- dplyr::slice(
  df_clustering_Not_D_and_inf_Q1,
  clusters_km_Not_D_and_inf_Q$order
)

For variables that are not in the top, we create an “Other” category.

# Long table for the stacked bar chart: one row per individual position `x`
# (dendrogram order), cluster and variable. Variables outside the kept top
# are pooled into "Other" by averaging their SHAP values.
# NOTE(review): a sum instead of a mean would keep the bars an additive SHAP
# decomposition of the output; confirm which is intended.
df_plot <- local({
  label_lookup <-
    variable_names |>
    bind_rows(variable_names_categ |> select(variable, label = label_categ)) |>
    bind_rows(tibble(variable = "Other", label = "Other", type = NA))

  long <-
    df_plot |>
    select(-id_row) |>
    mutate(x = row_number()) |>
    pivot_longer(
      cols = -c(cluster, x), names_to = "variable", values_to = "value"
    )
  long$variable[!long$variable %in% variables_to_keep_cluster] <- "Other"

  long |>
    group_by(cluster, x, variable) |>
    summarise(value = mean(value), .groups = "drop") |>
    left_join(label_lookup, by = "variable") |>
    arrange(x)
})

The order of the variables:

# Labels of the kept variables (in decreasing mean |SHAP| order) followed by
# "Other": the level order of the fill aesthetic
labels_order <- c(
  tibble(variable = variables_to_keep_cluster) |>
    left_join(variable_names, by = "variable") |>
    left_join(variable_names_categ, by = "variable") |>
    mutate(label = ifelse(is.na(label), label_categ, label)) |>
    pull(label),
  "Other"
)

We apply this order to the variables in the data plot.

# Turn the labels into a factor following labels_order
df_plot$label <- factor(df_plot$label, levels = labels_order)

Then, we can plot the graph to visualize the contribution of each of the top variables for every individual, grouped by cluster.

# Number of distinct fill categories (kept variables + "Other")
nb_vars <- df_plot$variable |> unique() |> length()

# Vertical separators: x = 1 and the last individual of each cluster
x_lines <-
  df_plot |>
  group_by(cluster) |>
  summarise(x_max = max(x)) |>
  pull(x_max) |>
  sort()
x_lines <- c(1, x_lines)

plot_clusters <-
  ggplot(data = df_plot) +
  geom_vline(xintercept = x_lines) +
  geom_col(aes(x = x, y = value, fill = label), width = 1, alpha = 0.9) +
  geom_hline(yintercept = 0, col = "black") +
  # "Paired" has at most 12 colours; asking brewer.pal() for more warns and
  # returns 12 anyway, so take at most the first 12 explicitly. The next
  # levels get black and gray.
  scale_fill_manual(
    "",
    values = c(
      head(RColorBrewer::brewer.pal(12, "Paired"), nb_vars - 1),
      "black", "gray"
    )
  ) +
  labs(x = "Individuals", y = "Output value") +
  theme(
    legend.position = "bottom",
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank()
  )
plot_clusters + guides(fill = guide_legend(ncol = 3))
Figure 8.12: Decomposition of the Effect of the Most Important Variables on the Probability of Being Predicted as an Imaginary Healthy Patient for Individuals Predicted as Such. All individuals.

With aggregated mean SHAP values at the cluster level:

# Mean SHAP value per cluster and variable. `cluster` becomes a factor whose
# levels follow the order of appearance in df_plot (sorted by x). Dropping
# the grouping in summarise() avoids the "grouped output" message.
df_plot_cluster <-
  df_plot |>
  group_by(cluster, variable, label) |>
  summarise(value = mean(value), .groups = "drop") |>
  mutate(cluster = factor(cluster, levels = unique(df_plot$cluster)))
`summarise()` has grouped output by 'cluster', 'variable'. You can override
using the `.groups` argument.
# Stacked bars of the mean SHAP values per cluster; clusters are renumbered
# 1..K in their order of appearance along the x axis.
plot_clusters_2 <-
  ggplot(
    data = df_plot_cluster |>
      mutate(
        cluster = factor(
          cluster,
          levels = levels(df_plot_cluster$cluster),
          labels = seq_len(length(levels(df_plot_cluster$cluster)))
        )
      ),
    mapping = aes(x = cluster, fill = label, y = value)
  ) +
  geom_bar(position = "stack", stat = "identity") +
  labs(x = "Cluster", y = "Output value") +
  # At most the 12 "Paired" colours (requesting more only warns), then black
  # and gray for the following levels
  scale_fill_manual(
    "",
    values = c(
      head(RColorBrewer::brewer.pal(12, "Paired"), nb_vars - 1),
      "black", "gray"
    )
  ) +
  theme(
    legend.position = "bottom"
  )
Warning in RColorBrewer::brewer.pal(nb_vars - 1, "Paired"): n too large, allowed maximum for palette Paired is 12
Returning the palette you asked for with that many colors
# Cluster-level decomposition with a dashed reference line at zero
plot_clusters_2 +
  geom_hline(yintercept = 0, linetype = "dashed") +
  guides(fill = guide_legend(ncol = 3))
Figure 8.13: Decomposition of the Effect of the Most Important Variables on the Probability of Being Predicted as an Imaginary Healthy Patient for Individuals Predicted as Such. Average per Cluster.

8.5.2 Descriptive Statistics of the Clusters

Let us now get some descriptive statistics at the cluster level, for each of these variables. We need to get the values of the variables for the individuals predicted as imaginary healthy patients.

For categorical variables, we would like to get the most frequent value, in each cluster. For numerical variables, we would like to get the average and standard deviation. Let us isolate these two types of variables.

The categorical variables:

# Raw names of the categorical variables whose (dummy) columns were kept
top_name_categ <- variable_names_categ$variable_raw[
  variable_names_categ$variable %in% variables_to_keep_cluster
]

The numerical variables:

# Kept variables that are not categorical columns
top_name_num <- Filter(
  function(v) !(v %in% variable_names_categ$variable),
  variables_to_keep_cluster
)

Let us create a dataset with the characteristics of the individuals and the cluster to which they belong:

# All individuals with their characteristics, a row index and, for those
# predicted imaginary healthy, their cluster (NA otherwise)
df_stat_des <- local({
  with_ids <- mutate(tibble(df_all_xgb), id_row = row_number())
  cluster_of <- select(df_clustering_Not_D_and_inf_Q1, id_row, cluster)
  left_join(with_ids, cluster_of, by = "id_row")
})

We define a function, get_stats(), which computes the desired statistics:

#' Descriptive statistics of the kept variables
#'
#' @param df data frame containing the columns in `name_num` and `name_categ`
#' @param name_num numerical variables, summarised as "mean (sd)"
#' @param name_categ categorical variables, summarised by their most frequent
#'   value and its share in % (ties keep every modal value)
#' @returns a tibble with columns `name` and `value` (formatted string)
get_stats <- function(df,
                      name_num = top_name_num,
                      name_categ = top_name_categ) {
  df_num <-
    df |>
    select(all_of(name_num)) |>
    pivot_longer(cols = everything()) |>
    group_by(name) |>
    summarise(mean = mean(value), sd = sd(value)) |>
    mutate(value = str_c(round(mean, 2), " (", round(sd, 2), ")")) |>
    select(name, value)

  df_categ <-
    df |>
    select(all_of(name_categ)) |>
    pivot_longer(cols = everything()) |>
    count(name, value) |>
    mutate(pct = round(100 * n / sum(n), 1), .by = name) |>
    # modal value per variable; replaces the superseded top_n(), which also
    # printed "Selecting by pct"
    slice_max(pct, n = 1, by = name, with_ties = TRUE) |>
    mutate(value = str_c(value, " (", pct, "%)")) |>
    select(name, value)

  bind_rows(df_num, df_categ)
}

We will compute the statistics on the following datasets:

  • whole dataset
  • sample of individuals predicted as imaginary healthy patients
  • sample of individuals predicted as healthy patients
  • each of the clusters (i.e., sub sample of the sample of individuals predicted as imaginary healthy patients).

On the whole dataset:

# Statistics on the whole sample
stat_des_whole <- mutate(get_stats(df_stat_des), type = "Whole sample")
Selecting by pct

On the dataset of individuals predicted as imaginary healthy patients by the classifier:

# Statistics on the individuals predicted as imaginary healthy
stat_des_imaginary <-
  df_stat_des |>
  semi_join(df_clustering_Not_D_and_inf_Q1, by = "id_row") |>
  get_stats() |>
  mutate(type = "Imaginary healthy")
Selecting by pct

On the dataset of individuals predicted as healthy patients (i.e., not predicted as imaginary healthy) by the classifier:

# Statistics on the individuals not predicted as imaginary healthy
stat_des_non_imaginary <-
  df_stat_des |>
  anti_join(df_clustering_Not_D_and_inf_Q1, by = "id_row") |>
  get_stats() |>
  mutate(type = "Healthy")
Selecting by pct

And on each of the clusters:

# Statistics within each cluster; `type` holds the cluster number as text so
# that the rows can be combined with the other samples. Built with lapply()
# + bind_rows() instead of growing a data frame inside a loop.
df_stat_des_clusters <-
  seq_len(K) |>
  lapply(function(i_clust) {
    df_stat_des |>
      filter(cluster == !!i_clust) |>
      get_stats() |>
      mutate(type = as.character(i_clust))
  }) |>
  bind_rows()
Selecting by pct
Selecting by pct
Selecting by pct

We also compare the clusters by computing the p-values of an ANOVA for quantitative variables, and of a chi-squared test for qualitative variables.

For qualitative variables:

# Chi-squared test of independence between cluster membership and each
# categorical variable. table() drops the individuals whose cluster is NA
# (those not predicted as imaginary healthy).
p_values_categ <-
  top_name_categ |>
  lapply(function(name) {
    # contingency table: clusters x categories
    tab <- table(df_stat_des$cluster, df_stat_des[[name]])
    tibble(name = name, pvalue = chisq.test(tab, correct = TRUE)$p.value)
  }) |>
  bind_rows()

For numerical variables:

# One-way ANOVA of each numerical variable across clusters. `cluster` is an
# integer column: it must be turned into a factor, otherwise aov() fits a
# linear trend on the cluster number (1 df) instead of comparing the K
# cluster means. Rows with an NA cluster are dropped by aov().
p_values_num <-
  top_name_num |>
  lapply(function(name) {
    fit <- aov(as.numeric(df_stat_des[[name]]) ~ factor(df_stat_des$cluster))
    tibble(name = name, pvalue = anova(fit)$`Pr(>F)`[1])
  }) |>
  bind_rows()

We merge the p-values in a single object.

# All p-values (categorical first, then numerical) in one tibble
tb_p_values <- bind_rows(p_values_categ, p_values_num)

To build a prettier table, let us build a tibble with the labels of the variables.

# Lookup from each kept column to its label: categorical (dummy) columns are
# first mapped back to their raw variable through `variable_raw`. The join
# key is explicit (it was guessed before, with a message).
table_ref <-
  tibble(variable = variables_to_keep_cluster) |>
  left_join(variable_names_categ, by = "variable") |>
  mutate(
    variable = ifelse(is.na(variable_raw), variable, variable_raw)
  ) |>
  left_join(variable_names, by = "variable")
Joining with `by = join_by(variable)`

Let us merge the four tables with the descriptive statistics and add the p-values of the clusters comparisons.

# Wide table: one row per kept variable, one column per sample (clusters
# 1..K, imaginary healthy, healthy, whole sample) and the p-value of the
# between-cluster test, ordered as in table_ref. Join keys are explicit.
tb_desc_stats_clusters <-
  df_stat_des_clusters |>
  bind_rows(stat_des_whole) |>
  bind_rows(stat_des_imaginary) |>
  bind_rows(stat_des_non_imaginary) |>
  mutate(
    type = factor(
      type,
      levels = c(seq_len(K), "Imaginary healthy", "Healthy", "Whole sample")
    )
  ) |>
  pivot_wider(names_from = type, values_from = value) |>
  left_join(
    table_ref |> select(variable, label),
    by = c("name" = "variable")
  ) |>
  left_join(unique(tb_p_values), by = "name") |>
  relocate(label, .before = name) |>
  select(-name) |>
  unique() |>
  mutate(label = factor(label, levels = unique(table_ref$label))) |>
  arrange(label)
Joining with `by = join_by(name)`
# HTML table of the descriptive statistics, p-values formatted to 2 digits
tb_desc_stats_clusters |>
  select(-pvalue) |>
  add_column(
    p_value = format.pval(tb_desc_stats_clusters$pvalue, digits = 2)
  ) |>
  kableExtra::kable() |>
  kableExtra::kable_classic(full_width = FALSE, html_font = "Cambria") |>
  kableExtra::kable_styling(
    bootstrap_options = c("striped", "hover", "condensed", "responsive")
  )
Table 8.2: Average individual in each cluster and in the sample
label 1 2 3 Whole sample Imaginary healthy Healthy p_value
Net Income per Cons. Unit 1193.67 (687.48) 1640.95 (691.14) 643.45 (192.42) 1609.51 (1008.13) 1268.72 (728.7) 1865.21 (1108.24) < 2e-16
Reimbursement General Practitioner 228.85 (169.81) 131.15 (119.72) 85.79 (108.71) 89.44 (112.53) 146.79 (143.51) 46.41 (48.95) < 2e-16
Age 66.15 (15.23) 49.29 (16.48) 38.39 (12.82) 48.7 (18.58) 51.22 (18.39) 46.8 (18.49) < 2e-16
Low Back Pain No (81.6%) No (50.3%) No (75.6%) No (79.6%) No (65.3%) No (90.4%) < 2e-16
Reimbursement Pharmacy 1794.9 (3321.96) 306.26 (669.55) 99.92 (178.35) 361.5 (1431.61) 665.5 (1937.75) 133.41 (805.81) < 2e-16
Neck Pain No (88%) No (57.9%) No (86.9%) No (85.4%) No (73.5%) No (94.3%) < 2e-16
Frequency Meeting with People in Organizations Never (64.2%) Never (59.4%) Never (66.8%) Never (51.3%) Never (62.6%) Never (42.9%) 1.3e-05
Gender Female (51%) Female (71.7%) Female (62.9%) Female (52.2%) Female (63.8%) Male (56.4%) < 2e-16
Have to Hurry to Do Job No answer (85.8%) No answer (37.5%) No answer (71.6%) No answer (51.8%) No answer (59.3%) No answer (46.2%) < 2e-16
Very Little Freedom to Do Job No answer (85.9%) No answer (37.8%) No answer (72.1%) No answer (52.1%) No answer (59.6%) No answer (46.4%) < 2e-16
Allergy No (86.7%) No (69.5%) No (83.9%) No (85.4%) No (77.8%) No (91.1%) < 2e-16
Waiver General Practitioner No (81%) No (85.1%) No (64%) No (76%) No (78.7%) No (74%) < 2e-16
Long-term condition (Self-declared) Yes (74.1%) No (79.4%) No (89.9%) No (80.5%) No (67.2%) No (90.5%) < 2e-16

Number of observations per cluster:

# Size of each cluster
table(df_clustering_Not_D_and_inf_Q1[["cluster"]])

   1    2    3 
 626 1077  566 

Number of observations in the whole sample:

# Number of rows of the whole sample
dim(df_stat_des)[1L]
[1] 5293

Number of individuals predicted as imaginary healthy:

# Number of individuals predicted as imaginary healthy
dim(df_clustering_Not_D_and_inf_Q1)[1L]
[1] 2269

Remaining obs (sanity check):

# Individuals not predicted as imaginary healthy: should equal the whole
# sample minus those predicted imaginary healthy
df_stat_des |>
  anti_join(df_clustering_Not_D_and_inf_Q1, by = "id_row") |>
  nrow()
[1] 3024

Predicted class by XGB:

# Predicted class of every individual of df_all_xgb with the tuned XGBoost
# model (grid_search_xgb presumably is a caret `train` object, so the result
# would be a factor of class labels -- TODO confirm)
pred_class_xgb <- predict(
  grid_search_xgb, newdata = df_all_xgb
)

Overview of predictions in each cluster:

# Per cluster (individuals predicted as imaginary healthy only): size,
# accuracy of the predicted class, and number / share of individuals whose
# observed status is Not_D_and_inf_Q1
df_stat_des |>
  mutate(
    predicted_val = pred_val_all_xgb,
    pred_class = pred_class_xgb,
    correctly_pred = pred_class == status
  ) |>
  filter(!is.na(cluster)) |>
  group_by(cluster) |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 3 × 5
  cluster     n accuracy no_im_healthy pct_im_healthy
    <int> <int>    <dbl>         <int>          <dbl>
1       1   626     48.9           306           48.9
2       2  1077     48             517           48  
3       3   566     43.8           248           43.8

Overall accuracy:

# Overall accuracy on the whole sample. The predicted class is derived from
# the XGB score `pred_val_all_xgb` with a 0.5 cut-off (scores above 0.5 are
# labelled Not_D_and_inf_Q1, i.e. imaginary healthy).
df_stat_des |>
  mutate(
    pred_class = ifelse(
      pred_val_all_xgb > 0.5,
      yes = "Not_D_and_inf_Q1",
      no = "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  5293     67.4          1597           30.2

Among individuals predicted Not_D_and_inf_Q1 (i.e., imaginary healthy), percentage correctly classified:

# Restricted to individuals predicted Not_D_and_inf_Q1 (score > 0.5).
# Because every kept row has that predicted class, `accuracy` equals
# `pct_im_healthy`: it is the share of predicted imaginary healthy who truly
# are imaginary healthy.
df_stat_des |>
  mutate(
    pred_class = ifelse(
      pred_val_all_xgb > 0.5,
      yes = "Not_D_and_inf_Q1",
      no = "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |>
  filter(pred_class == "Not_D_and_inf_Q1") |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  2269     47.2          1071           47.2

Among individuals predicted Not_D_and_sup_Q1 (i.e., healthy), percentage correctly classified:

# Restricted to individuals predicted Not_D_and_sup_Q1, i.e. healthy
# (score <= 0.5). Here `accuracy` is the share of predicted healthy who are
# truly healthy. `no_im_healthy` / `pct_im_healthy` count the imaginary
# healthy individuals misclassified as healthy.
df_stat_des |>
  mutate(
    pred_class = ifelse(
      pred_val_all_xgb > 0.5,
      yes = "Not_D_and_inf_Q1",
      no = "Not_D_and_sup_Q1"
    ),
    correctly_pred = pred_class == status
  ) |>
  filter(pred_class == "Not_D_and_sup_Q1") |>
  summarise(
    n = n(),
    accuracy = round(100 * mean(correctly_pred), 1),
    no_im_healthy = sum(status == "Not_D_and_inf_Q1"),
    pct_im_healthy = round(100 * no_im_healthy / n, 1)
  )
# A tibble: 1 × 4
      n accuracy no_im_healthy pct_im_healthy
  <int>    <dbl>         <int>          <dbl>
1  3024     82.6           526           17.4