
Conversation

@naveenkcb
Contributor

Contributor: Naveen Baskaran

Contribution Type: Interpretability method, Tests, Example

Description

This PR implements the SHAP (SHapley Additive exPlanations) interpretability method for PyHealth models, enabling users to understand which features contribute most to model predictions. SHAP is based on coalitional game theory and provides theoretically grounded feature importance scores with desirable properties like local accuracy, missingness, and consistency.
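The "local accuracy" property mentioned above can be demonstrated with the exact (exponential-time) Shapley computation below. This is an illustrative standalone sketch, not the PR's implementation — the PR uses the Kernel SHAP sampling approximation instead, since exact enumeration is infeasible beyond a handful of features.

```python
import itertools
import math

def exact_shapley(f, x, baseline):
    """Exact Shapley values for a toy model f over len(x) features.

    Illustrative only: enumerates all 2^(n-1) coalitions per feature.
    Local accuracy means the values sum to f(x) - f(baseline).
    """
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for r in range(n):
            for subset in itertools.combinations(others, r):
                # Shapley weight |S|! (n - |S| - 1)! / n!
                w = math.factorial(r) * math.factorial(n - r - 1) / math.factorial(n)
                with_i = [x[j] if (j in subset or j == i) else baseline[j] for j in range(n)]
                without_i = [x[j] if j in subset else baseline[j] for j in range(n)]
                phi[i] += w * (f(with_i) - f(without_i))
    return phi
```

For a linear model such as `f = sum`, each feature's Shapley value is exactly its own contribution, and the values sum to `f(x) - f(baseline)`.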

Files to Review

pyhealth/interpret/methods/__init__.py
pyhealth/interpret/methods/shap.py - Core SHAP method implementation. Supports embedding-based attribution and continuous features.
pyhealth/processors/tensor_processor.py - minor fix to resolve warning message
examples/shap_stagenet_mimic4.py - Example script showing the usage of SHAP method
tests/core/test_shap.py - Added comprehensive test cases covering the main class, utility methods, and attribution methods.

Results on mimic4-demo dataset

(screenshot of results omitted)

# Initialize SHAP explainer with custom parameters
shap_explainer = ShapExplainer(
    model,
    use_embeddings=True,  # Use embeddings for discrete features
)
Collaborator


I think SHAP should be compatible with discrete tokens like ICD codes here? Correct me if I'm wrong. I will look deeper into the full implementation of SHAP here later when I'm more cognitively sound.

Contributor Author

@naveenkcb Nov 17, 2025


I updated the ShapExplainer instance creation to pass just the model and to default all other values, including use_embeddings, inside the __init__ method. Yes, SHAP works for ICD codes, but it will use the embeddings from the input model.

Collaborator

@jhnwu3 left a comment


Some other nice-to-haves:

  1. Can we add an entry to docs/api/interpret/shap.rst? And add its entry to the interpretability index interpret.rst?
  2. Can we check that this is compatible when the device is a GPU? Maybe through a Colab notebook? (There's a way to install the branch/repo into the Colab environment.)
  3. I might be able to share some compute resources soon once NCSA gets back to me.

if coalition_size == 0 or coalition_size == n_features:
    return torch.tensor(1000.0)  # Large weight for edge cases

comb_val = math.comb(n_features - 1, coalition_size - 1)
Collaborator


Wait, isn't it binom(M, |z|) here? Why do we take n_features - 1 and coalition_size - 1 instead of n_features and coalition_size?

Contributor Author


I updated the code in the _compute_kernel_weight method to match the equation used in the original SHAP paper:

weight = (M - 1) / (binom(M, |z|) * |z| * (M - |z|))

  1. I also added the .rst file as requested.
  2. Added "examples/shap_stagenet_mimic4.ipynb", created in Colab with a GPU runtime.
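As a sanity check, the corrected weight equation above can be written as a small standalone function. This is a sketch returning a plain float (where the PR's _compute_kernel_weight wraps it in a torch tensor); the large constant for the edge cases follows the snippet quoted earlier in the thread.

```python
import math

def kernel_shap_weight(n_features: int, coalition_size: int) -> float:
    """Kernel SHAP weight from the original paper:
    pi(z) = (M - 1) / (binom(M, |z|) * |z| * (M - |z|)).

    |z| = 0 and |z| = M would divide by zero, so they get a large
    finite weight, effectively acting as constraints on the regression.
    """
    if coalition_size == 0 or coalition_size == n_features:
        return 1000.0  # Large weight for the edge-case coalitions
    m = n_features
    return (m - 1) / (math.comb(m, coalition_size) * coalition_size * (m - coalition_size))
```

Note the weight is symmetric: coalitions of size |z| and M - |z| receive the same weight, as the equation requires.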

coalition_vectors = []
coalition_weights = []
coalition_preds = []

Collaborator


Can you check whether we need to add the edge-case coalitions (full feature set and empty set) explicitly to the prediction set used for training the kernel/linear model that predicts the Shapley values here?

I've linked some captum code examples here:

https://github.com/meta-pytorch/captum/blob/master/captum/attr/_core/kernel_shap.py
https://github.com/meta-pytorch/captum/blob/master/captum/attr/_core/lime.py

Contributor Author


Handled the edge cases and updated the code accordingly in the _compute_kernel_shap method.
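For reference, a minimal sketch of how the two edge-case coalitions can be prepended to the sampled set before fitting the linear model. The function and parameter names here are illustrative, not the PR's exact API; Captum's KernelShap includes these same two masks in its prediction set.

```python
import random

def sample_coalitions(n_features: int, n_samples: int, seed: int = 0):
    """Build binary coalition masks for the Kernel SHAP regression,
    always including the two edge cases: the empty coalition (all
    zeros) and the full coalition (all ones). Sampled coalitions only
    take interior sizes, since the edge cases are added explicitly.
    """
    rng = random.Random(seed)
    coalitions = [[0] * n_features, [1] * n_features]  # edge cases first
    for _ in range(n_samples):
        size = rng.randint(1, n_features - 1)
        mask = [0] * n_features
        for idx in rng.sample(range(n_features), size):
            mask[idx] = 1
        coalitions.append(mask)
    return coalitions
```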
