SHAP Interpretability method implementation #611
base: master
Conversation
examples/shap_stagenet_mimic4.py
Outdated
# Initialize SHAP explainer with custom parameters
shap_explainer = ShapExplainer(
    model,
    use_embeddings=True,  # Use embeddings for discrete features
I think SHAP should be compatible with discrete tokens like ICD codes here? Correct me if I'm wrong. Will look deeper into the full implementation of SHAP here later when I'm more cognitively sound.
I updated the ShapExplainer instance creation to pass just the model and default all other values, including "use_embeddings", inside the __init__ method. Yes, SHAP works for ICD codes, but it will use the embeddings from the input model.
jhnwu3 left a comment:
Some other nice-to-haves:
- Can we add an entry to docs/api/interpret/shap.rst, and add it to the interpretability index interpret.rst?
- Can we check that this is compatible when the device is a GPU? Maybe through a Colab notebook? (There's a way to install the branch/repo into the Colab environment.)
- I might be able to share some compute resources soon once NCSA gets back to me.
pyhealth/interpret/methods/shap.py
Outdated
if coalition_size == 0 or coalition_size == n_features:
    return torch.tensor(1000.0)  # Large weight for edge cases

comb_val = math.comb(n_features - 1, coalition_size - 1)
Wait, isn't it binom(M, |z|) here? Why do we take n_features - 1 and coalition_size - 1 instead of n_features and coalition_size?
I updated the code in the method _compute_kernel_weight to match the equation used in the original SHAP paper:
weight = (M - 1) / (binom(M, |z|) * |z| * (M - |z|))
- I also added the .rst file as requested.
- Added "examples/shap_stagenet_mimic4.ipynb" using colab with GPU
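For reference, that weight can be sketched as a standalone function (a hypothetical, torch-free version of the PR's _compute_kernel_weight; the 1000.0 edge-case constant mirrors the diff above):

```python
import math

def kernel_weight(n_features: int, coalition_size: int) -> float:
    """KernelSHAP weight: (M - 1) / (binom(M, |z|) * |z| * (M - |z|))."""
    # Empty and full coalitions have infinite weight in theory;
    # a large constant stands in, mirroring the PR's edge-case handling.
    if coalition_size == 0 or coalition_size == n_features:
        return 1000.0
    M, z = n_features, coalition_size
    return (M - 1) / (math.comb(M, z) * z * (M - z))
```

For example, with M = 4 features and |z| = 2, the weight is 3 / (6 * 2 * 2) = 0.125.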
coalition_vectors = []
coalition_weights = []
coalition_preds = []
Can you check whether we need to add the edge-case coalitions (full features and no features) explicitly to the prediction set used to train the kernel/linear model that estimates the Shapley values here?
I've linked some captum code examples here:
Handled the edge cases and updated the code accordingly in the method _compute_kernel_shap.
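A minimal sketch of what that fix implies (a hypothetical helper, not the PR's actual code): the empty and full coalitions are always included alongside the randomly sampled ones, so the weighted linear model is anchored at the base prediction f(empty) and the full prediction f(x):

```python
import random

def build_coalitions(n_features: int, n_samples: int, seed: int = 0):
    # Hypothetical sketch of the edge-case fix: prepend the empty
    # (all-zeros) and full (all-ones) coalitions to the random samples,
    # so the kernel/linear regression is anchored at both extremes.
    rng = random.Random(seed)
    coalitions = [[0] * n_features, [1] * n_features]  # edge cases first
    for _ in range(n_samples):
        coalitions.append([rng.randint(0, 1) for _ in range(n_features)])
    return coalitions
```

Together with the large edge-case kernel weight, this effectively constrains the surrogate model to match the model output at the two extremes.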
Contributor: Naveen Baskaran
Contribution Type: Interpretability method, Tests, Example
Description
This PR implements the SHAP (SHapley Additive exPlanations) interpretability method for PyHealth models, enabling users to understand which features contribute most to model predictions. SHAP is based on coalitional game theory and provides theoretically grounded feature importance scores with desirable properties like local accuracy, missingness, and consistency.
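As a sanity check on the local accuracy property mentioned above, here is a toy example (illustrative only, not using the PR's code) with a linear model, where the exact Shapley values under feature independence have the closed form phi_i = w_i * (x_i - E[x_i]):

```python
# Toy check of local accuracy: f(x) = base_value + sum(phi).
# For a linear model f(x) = w . x with independent features, the exact
# Shapley values are phi_i = w_i * (x_i - E[x_i]).
w = [0.5, -2.0, 1.0]
background = [[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]  # background dataset
x = [1.0, 3.0, -1.0]                             # instance to explain

mean_x = [sum(col) / len(background) for col in zip(*background)]
base_value = sum(wi * mi for wi, mi in zip(w, mean_x))      # E[f(X)]
phi = [wi * (xi - mi) for wi, xi, mi in zip(w, x, mean_x)]  # Shapley values

f_x = sum(wi * xi for wi, xi in zip(w, x))
assert abs(base_value + sum(phi) - f_x) < 1e-9  # local accuracy holds
```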
Files to Review
pyhealth/interpret/methods/__init__.py
pyhealth/interpret/methods/shap.py - Core SHAP method implementation. Supports embedding-based attribution and continuous features.
pyhealth/processors/tensor_processor.py - Minor fix to resolve a warning message.
examples/shap_stagenet_mimic4.py - Example script showing usage of the SHAP method.
tests/core/test_shap.py - Comprehensive test cases covering the main class, utility methods, and attribution methods.
Results on mimic4-demo dataset