SHAP Interpretability method implementation #611
Open

naveenkcb wants to merge 26 commits into sunlabuiuc:master from naveenkcb:master

+4,101 −1

Commits (26)
9c297e2  Initial attempt
198e320  vocab size correction
282d6e4  fix for capitalized code references
3fcc157  updated more capitalized
bb3da26  saving ddi score
eb492e4  updated example for micron model
e79c809  fixed size determination
697a050  Merge branch 'sunlabuiuc:master' into master
23153f5  Merge branch 'sunlabuiuc:master' into master
de7ee7b  Merge branch 'sunlabuiuc:master' into master
970a055  Merge branch 'sunlabuiuc:master' into master
aa10098  Merge branch 'sunlabuiuc:master' into master
e5755d8  SHAP implementation
6eefb43  Merge branch 'master' of https://github.com/naveenkcb/PyHealth
20d0a90  added SHAP test and example files
cfe7e6d  Merge branch 'sunlabuiuc:master' into master
2dc4d83  added example file
9d77ce5  Merge branch 'sunlabuiuc:master' into master
2f3c18f  shap implementation
71c4160  removed ipynb file
1d784ec  update
402c39f  address PR comments
5d9e9e3  added example notebook
47e5973  Merge branch 'master' into master
80b3caa  fixed interpret/__init__
0645efd  fix for failed CI test
The pull request adds the following example script, which runs the new ShapExplainer on a StageNet mortality model trained on the MIMIC-IV demo dataset:
# %% Loading MIMIC-IV dataset
from pathlib import Path

import polars as pl
import torch

from pyhealth.datasets import (
    MIMIC4EHRDataset,
    get_dataloader,
    load_processors,
    split_by_patient,
)
from pyhealth.interpret.methods import ShapExplainer
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4

# Configure dataset location and load cached processors
dataset = MIMIC4EHRDataset(
    root="~/data/physionet.org/files/mimic-iv-demo/2.2/",
    tables=[
        "patients",
        "admissions",
        "diagnoses_icd",
        "procedures_icd",
        "labevents",
    ],
)

# %% Setting StageNet Mortality Prediction Task
input_processors, output_processors = load_processors("../resources/")

sample_dataset = dataset.set_task(
    MortalityPredictionStageNetMIMIC4(),
    cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality",
    input_processors=input_processors,
    output_processors=output_processors,
)
print(f"Total samples: {len(sample_dataset)}")


def load_icd_description_map(dataset_root: str) -> dict:
    """Load ICD code → long title mappings from MIMIC-IV reference tables."""
    mapping = {}
    root_path = Path(dataset_root).expanduser()
    diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz"
    proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz"

    icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8}

    if diag_path.exists():
        diag_df = pl.read_csv(
            diag_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list())
        )

    if proc_path.exists():
        proc_df = pl.read_csv(
            proc_path,
            columns=["icd_code", "long_title"],
            dtypes=icd_dtype,
        )
        mapping.update(
            zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list())
        )

    return mapping


ICD_CODE_TO_DESC = load_icd_description_map(dataset.root)
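# Quick sanity check of the code→description map. "I10" (Essential (primary)
# hypertension) is just an illustrative ICD-10 code; it may or may not be
# present in the demo reference tables.
print(f"Loaded {len(ICD_CODE_TO_DESC)} ICD descriptions")
print(f"Example lookup: I10 -> {ICD_CODE_TO_DESC.get('I10', '<not found>')}")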

# %% Loading Pretrained StageNet Model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = StageNet(
    dataset=sample_dataset,
    embedding_dim=128,
    chunk_size=128,
    levels=3,
    dropout=0.3,
)

state_dict = torch.load("../resources/best.ckpt", map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print(model)

# %% Preparing dataloaders
_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42)
test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)


def move_batch_to_device(batch, target_device):
    """Move all tensors in batch to target device."""
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(target_device)
        elif isinstance(value, tuple):
            moved[key] = tuple(v.to(target_device) for v in value)
        else:
            moved[key] = value
    return moved


LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES


def decode_token(idx: int, processor, feature_key: str):
    """Decode token index to human-readable string."""
    if processor is None or not hasattr(processor, "code_vocab"):
        return str(idx)
    reverse_vocab = {index: token for token, index in processor.code_vocab.items()}
    token = reverse_vocab.get(idx, f"<UNK:{idx}>")

    if feature_key == "icd_codes" and token not in {"<unk>", "<pad>"}:
        desc = ICD_CODE_TO_DESC.get(token)
        if desc:
            return f"{token}: {desc}"

    return token


def unravel(flat_index: int, shape: torch.Size):
    """Convert flat index to multi-dimensional coordinates."""
    coords = []
    remaining = flat_index
    for dim in reversed(shape):
        coords.append(remaining % dim)
        remaining //= dim
    return list(reversed(coords))
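# Illustrative check of unravel against PyTorch's row-major layout: in a
# tensor of shape (2, 3), flat index 5 is the element at [1, 2]. (Recent
# PyTorch versions also ship torch.unravel_index for the same purpose.)
assert unravel(5, torch.Size([2, 3])) == [1, 2]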

def print_top_attributions(
    attributions,
    batch,
    processors,
    top_k: int = 10,
):
    """Print top-k most important features from SHAP attributions."""
    for feature_key, attr in attributions.items():
        attr_cpu = attr.detach().cpu()
        if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:
            continue

        feature_input = batch[feature_key]
        if isinstance(feature_input, tuple):
            feature_input = feature_input[1]
        feature_input = feature_input.detach().cpu()

        flattened = attr_cpu[0].flatten()
        if flattened.numel() == 0:
            continue

        print(f"\nFeature: {feature_key}")
        print(f"  Shape: {attr_cpu[0].shape}")
        print(f"  Total attribution sum: {flattened.sum().item():+.6f}")
        print(f"  Mean attribution: {flattened.mean().item():+.6f}")

        k = min(top_k, flattened.numel())
        top_values, top_indices = torch.topk(flattened.abs(), k=k)
        processor = processors.get(feature_key) if processors else None
        is_continuous = torch.is_floating_point(feature_input)

        print(f"\n  Top {k} most important features:")
        for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):
            attribution_value = flattened[flat_idx].item()
            coords = unravel(flat_idx.item(), attr_cpu[0].shape)

            if is_continuous:
                actual_value = feature_input[0][tuple(coords)].item()
                label = ""
                if feature_key == "labs" and len(coords) >= 1:
                    lab_idx = coords[-1]
                    if lab_idx < len(LAB_CATEGORY_NAMES):
                        label = f"{LAB_CATEGORY_NAMES[lab_idx]} "
                print(
                    f"  {rank:2d}. idx={coords} {label}value={actual_value:.4f} "
                    f"SHAP={attribution_value:+.6f}"
                )
            else:
                token_idx = int(feature_input[0][tuple(coords)].item())
                token = decode_token(token_idx, processor, feature_key)
                print(
                    f"  {rank:2d}. idx={coords} token='{token}' "
                    f"SHAP={attribution_value:+.6f}"
                )


# %% Run SHAP on a held-out sample
print("\n" + "="*80)
print("Initializing SHAP Explainer")
print("="*80)

# Initialize SHAP explainer with custom parameters
shap_explainer = ShapExplainer(
    model,
    use_embeddings=True,      # Use embeddings for discrete features
    n_background_samples=50,  # Number of background samples
    max_coalitions=200,       # Number of feature coalitions to sample
    random_seed=42,           # For reproducibility
)

print("\nSHAP Configuration:")
print(f"  Use embeddings: {shap_explainer.use_embeddings}")
print(f"  Background samples: {shap_explainer.n_background_samples}")
print(f"  Max coalitions: {shap_explainer.max_coalitions}")
print(f"  Regularization: {shap_explainer.regularization}")
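# %% Aside: why max_coalitions matters (illustrative; not part of the PyHealth
# API). Exact Shapley values require evaluating all 2^M feature coalitions, so
# Kernel SHAP samples a limited number of coalitions and weights a coalition
# of size s out of M features by the Shapley kernel:
#     w(s) = (M - 1) / (C(M, s) * s * (M - s)),  for 0 < s < M
from math import comb


def shapley_kernel_weight(num_features: int, coalition_size: int) -> float:
    """Standard Kernel SHAP weight; the empty and full coalitions receive
    infinite weight and are handled as hard constraints in practice."""
    m, s = num_features, coalition_size
    if s == 0 or s == m:
        return float("inf")
    return (m - 1) / (comb(m, s) * s * (m - s))


# Small and near-complete coalitions dominate, e.g. for M = 10 features:
print([round(shapley_kernel_weight(10, s), 5) for s in range(1, 10)])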

# Get a sample from test set
sample_batch = next(iter(test_loader))
sample_batch_device = move_batch_to_device(sample_batch, device)

# Get model prediction
with torch.no_grad():
    output = model(**sample_batch_device)
probs = output["y_prob"]
label_key = model.label_key
true_label = sample_batch_device[label_key]

# Handle binary classification (single probability output)
if probs.shape[-1] == 1:
    prob_death = probs[0].item()
    prob_survive = 1 - prob_death
    preds = (probs > 0.5).long()
else:
    # Multi-class classification
    preds = torch.argmax(probs, dim=-1)
    prob_survive = probs[0][0].item()
    prob_death = probs[0][1].item()

print("\n" + "="*80)
print("Model Prediction for Sampled Patient")
print("="*80)
print(f"  True label: {int(true_label.item())} {'(Deceased)' if true_label.item() == 1 else '(Survived)'}")
print(f"  Predicted class: {int(preds.item())} {'(Deceased)' if preds.item() == 1 else '(Survived)'}")
print(f"  Probabilities: [Survive={prob_survive:.4f}, Death={prob_death:.4f}]")

# Compute SHAP values
print("\n" + "="*80)
print("Computing SHAP Attributions (this may take a minute...)")
print("="*80)

attributions = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)

print("\n" + "="*80)
print("SHAP Attribution Results")
print("="*80)
print("\nSHAP values explain the contribution of each feature to the model's")
print("prediction of MORTALITY (class 1). Positive values increase the")
print("mortality prediction, negative values decrease it.")

print_top_attributions(attributions, sample_batch_device, input_processors, top_k=15)
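# %% Aside: aggregate lab attributions per category (illustrative). This
# assumes the last dimension of attributions["labs"] indexes lab categories,
# consistent with how print_top_attributions decodes coordinates above.
if "labs" in attributions:
    lab_attr = attributions["labs"].detach().cpu()[0]
    per_lab = lab_attr.abs().reshape(-1, lab_attr.shape[-1]).sum(dim=0)
    ranked = torch.argsort(per_lab, descending=True)
    print("\nLab categories ranked by total |SHAP| across time steps:")
    for lab_idx in ranked[:5].tolist():
        name = (
            LAB_CATEGORY_NAMES[lab_idx]
            if lab_idx < len(LAB_CATEGORY_NAMES)
            else f"lab_{lab_idx}"
        )
        print(f"  {name}: {per_lab[lab_idx].item():.6f}")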

# %% Compare different baseline strategies
print("\n\n" + "="*80)
print("Testing Different Baseline Strategies")
print("="*80)

# 1. Automatic baseline (default)
print("\n1. Automatic baseline generation (recommended):")
attr_auto = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)
print(f"  Total attribution (icd_codes): {attr_auto['icd_codes'][0].sum().item():+.6f}")
print(f"  Total attribution (labs): {attr_auto['labs'][0].sum().item():+.6f}")

# Note: Custom baselines for discrete features (like ICD codes) require careful
# construction to avoid invalid sequences. The automatic baseline generation
# handles this by sampling from the observed data distribution.
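# For intuition, "sampling from the observed data distribution" amounts to
# collecting real patient batches as the background set, e.g. (illustrative
# only; the explainer builds its background internally, and this sketch does
# not assume any ShapExplainer API for passing baselines):
background_batches = [
    move_batch_to_device(b, device)
    for b, _ in zip(test_loader, range(shap_explainer.n_background_samples))
]
print(f"Collected {len(background_batches)} candidate background samples")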

# %% Test callable interface
print("\n" + "="*80)
print("Testing Callable Interface")
print("="*80)

# Both methods should produce identical results (due to random_seed)
attr_from_attribute = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)
attr_from_call = shap_explainer(**sample_batch_device, target_class_idx=1)

print("\nVerifying that explainer(**data) and explainer.attribute(**data) produce")
print("identical results when random_seed is set...")

all_close = True
for key in attr_from_attribute.keys():
    if not torch.allclose(attr_from_attribute[key], attr_from_call[key], atol=1e-6):
        all_close = False
        print(f"  ❌ {key}: Results differ!")
    else:
        print(f"  ✓ {key}: Results match")

if all_close:
    print("\n✓ All attributions match! Callable interface works correctly.")
else:
    print("\n❌ Some attributions differ. Check random seed configuration.")

print("\n" + "="*80)
print("SHAP Analysis Complete")
print("="*80)

# %%
Review comments
Reviewer: I think SHAP should be compatible with discrete tokens like ICD codes here? Correct me if I'm wrong. I'll look deeper into the full SHAP implementation later, when I'm more cognitively sound.
Author reply: I updated the ShapExplainer instantiation to pass just the model and default all the other values, including use_embeddings, inside the __init__ method. Yes, SHAP works for ICD codes, but it will use the embeddings from the input model.
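A minimal sketch of the simplified usage described above, assuming the defaults (including use_embeddings) are now set inside ShapExplainer.__init__:

# Pass only the model; all other options fall back to the __init__ defaults.
shap_explainer = ShapExplainer(model)
attributions = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)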