diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..e69de29bb diff --git a/docs/api/interpret/pyhealth.interpret.methods.shap.rst b/docs/api/interpret/pyhealth.interpret.methods.shap.rst new file mode 100644 index 000000000..aedd5ce46 --- /dev/null +++ b/docs/api/interpret/pyhealth.interpret.methods.shap.rst @@ -0,0 +1,21 @@ +pyhealth.interpret.methods.ShapExplainer +======================================== + +Overview +-------- + +The SHAP (SHapley Additive exPlanations) method computes feature attributions for PyHealth models +based on coalitional game theory. This helps identify which features (e.g., diagnosis codes, +lab values) most influenced a model's prediction. + +For a complete working example, see: +``examples/shap_mortality_mimic4_stagenet.py`` + +API Reference +------------- + +.. autoclass:: pyhealth.interpret.methods.ShapExplainer + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/examples/shap_stagenet_mimic4.ipynb b/examples/shap_stagenet_mimic4.ipynb new file mode 100644 index 000000000..a871d7326 --- /dev/null +++ b/examples/shap_stagenet_mimic4.ipynb @@ -0,0 +1,1693 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "14fe2649", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "14fe2649", + "outputId": "c7d5f834-b9ac-45b9-d67e-65e2e73e2924" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "PyTorch version: 2.9.0+cu126\n", + "CUDA available: True\n", + "CUDA version: 12.6\n", + "GPU Device: Tesla T4\n", + "GPU Memory: 15.83 GB\n" + ] + } + ], + "source": [ + "# Check GPU availability\n", + "import torch\n", + "\n", + "print(f\"PyTorch version: {torch.__version__}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", + "if torch.cuda.is_available():\n", + " print(f\"CUDA version: {torch.version.cuda}\")\n", + " print(f\"GPU Device: 
{torch.cuda.get_device_name(0)}\")\n", + " print(f\"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")\n", + "else:\n", + " print(\"⚠️ GPU not available. Please enable GPU: Runtime > Change runtime type > GPU\")" + ] + }, + { + "cell_type": "markdown", + "id": "0428f9a4", + "metadata": { + "id": "0428f9a4" + }, + "source": [ + "## 1. Installation\n", + "\n", + "Install PyHealth and required dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c349da42", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "c349da42", + "outputId": "eb83c82a-9472-4365-e849-d7fdc89d3f54" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting git+https://github.com/naveenkcb/PyHealth.git\n", + " Cloning https://github.com/naveenkcb/PyHealth.git to /tmp/pip-req-build-u5cek8co\n", + " Running command git clone --filter=blob:none --quiet https://github.com/naveenkcb/PyHealth.git /tmp/pip-req-build-u5cek8co\n", + " Resolved https://github.com/naveenkcb/PyHealth.git to commit 402c39fa6cf3509dbcd83810f812d4afc1dcd44f\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.11.0)\n", + "Collecting mne~=1.10.0 (from pyhealth==2.0a8)\n", + " Downloading mne-1.10.2-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (3.5)\n", + "Collecting numpy~=1.26.4 (from pyhealth==2.0a8)\n", + " Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ogb>=1.3.5 (from pyhealth==2.0a8)\n", + " Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)\n", + "Collecting pandarallel~=1.6.5 (from pyhealth==2.0a8)\n", + " Downloading pandarallel-1.6.5.tar.gz (14 kB)\n", + " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + "Collecting pandas~=2.3.1 (from pyhealth==2.0a8)\n", + " Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: peft in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.18.0)\n", + "Requirement already satisfied: polars~=1.31.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (1.31.0)\n", + "Requirement already satisfied: pydantic~=2.11.7 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.11.10)\n", + "Collecting rdkit (from pyhealth==2.0a8)\n", + " Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)\n", + "Collecting scikit-learn~=1.7.0 (from pyhealth==2.0a8)\n", + " Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (0.24.0+cu126)\n", + "Collecting torch~=2.7.1 (from pyhealth==2.0a8)\n", + " Downloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (4.67.1)\n", + "Collecting transformers~=4.53.2 (from pyhealth==2.0a8)\n", + " Downloading transformers-4.53.3-py3-none-any.whl.metadata (40 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: urllib3~=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pyhealth==2.0a8) (2.5.0)\n", + "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from 
mne~=1.10.0->pyhealth==2.0a8) (4.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.1.6)\n", + "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (0.4)\n", + "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (3.10.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (25.0)\n", + "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.8.2)\n", + "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne~=1.10.0->pyhealth==2.0a8) (1.16.3)\n", + "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.12/dist-packages (from ogb>=1.3.5->pyhealth==2.0a8) (1.17.0)\n", + "Collecting outdated>=0.2.0 (from ogb>=1.3.5->pyhealth==2.0a8)\n", + " Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)\n", + "Requirement already satisfied: dill>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (0.3.8)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from pandarallel~=1.6.5->pyhealth==2.0a8) (5.9.5)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas~=2.3.1->pyhealth==2.0a8) (2025.2)\n", + "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from 
pydantic~=2.11.7->pyhealth==2.0a8) (0.7.0)\n", + "Requirement already satisfied: pydantic-core==2.33.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (2.33.2)\n", + "Requirement already satisfied: typing-extensions>=4.12.2 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (4.15.0)\n", + "Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from pydantic~=2.11.7->pyhealth==2.0a8) (0.4.2)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (1.5.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn~=1.7.0->pyhealth==2.0a8) (3.6.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (3.20.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.13.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.80)\n", + "Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch~=2.7.1->pyhealth==2.0a8)\n", + " Downloading 
nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (10.3.7.77)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.5.4.2)\n", + "Collecting nvidia-cusparselt-cu12==0.6.3 (from torch~=2.7.1->pyhealth==2.0a8)\n", + " Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)\n", + "Collecting nvidia-nccl-cu12==2.26.2 (from torch~=2.7.1->pyhealth==2.0a8)\n", + " Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch~=2.7.1->pyhealth==2.0a8) (1.11.1.6)\n", + "Collecting triton==3.3.1 (from torch~=2.7.1->pyhealth==2.0a8)\n", + " Downloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in 
/usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.36.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (6.0.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2024.11.6)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (2.32.4)\n", + "Collecting tokenizers<0.22,>=0.21 (from transformers~=4.53.2->pyhealth==2.0a8)\n", + " Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", + "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers~=4.53.2->pyhealth==2.0a8) (0.7.0)\n", + "Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from rdkit->pyhealth==2.0a8) (11.3.0)\n", + "INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. 
This could take a while.\n", + "Collecting torchvision (from pyhealth==2.0a8)\n", + " Downloading torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", + " Downloading torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)\n", + " Downloading torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", + " Downloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers~=4.53.2->pyhealth==2.0a8) (1.2.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.3.3)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (4.60.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (1.4.9)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne~=1.10.0->pyhealth==2.0a8) (3.2.5)\n", + "Collecting littleutils (from outdated>=0.2.0->ogb>=1.3.5->pyhealth==2.0a8)\n", + " Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)\n", + "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne~=1.10.0->pyhealth==2.0a8) (4.5.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in 
/usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (3.11)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers~=4.53.2->pyhealth==2.0a8) (2025.11.12)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch~=2.7.1->pyhealth==2.0a8) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne~=1.10.0->pyhealth==2.0a8) (3.0.3)\n", + "Downloading mne-1.10.2-py3-none-any.whl (7.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m68.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.0/18.0 MB\u001b[0m \u001b[31m139.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading ogb-1.3.6-py3-none-any.whl (78 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.8/78.8 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m140.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.5/9.5 MB\u001b[0m \u001b[31m149.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl (821.0 MB)\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m821.0/821.0 MB\u001b[0m \u001b[31m843.9 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl (571.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m571.0/571.0 MB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl (156.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m156.8/156.8 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (201.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m201.3/201.3 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.7/155.7 MB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading transformers-4.53.3-py3-none-any.whl (10.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m112.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.2/36.2 MB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading torchvision-0.22.1-cp312-cp312-manylinux_2_28_x86_64.whl (7.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m 
\u001b[31m92.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)\n", + "Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m96.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading littleutils-0.2.4-py3-none-any.whl (8.1 kB)\n", + "Building wheels for collected packages: pyhealth, pandarallel\n", + " Building wheel for pyhealth (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pyhealth: filename=pyhealth-2.0a8-py3-none-any.whl size=418794 sha256=fec988935069784916af4a5994a25563675c3763c4db3a6f63f4fe3aa657a053\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-9qply4d1/wheels/e9/10/11/3146f609c6b24edf823d697c4a93da2e447bada2d1fb3fb819\n", + " Building wheel for pandarallel (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pandarallel: filename=pandarallel-1.6.5-py3-none-any.whl size=16674 sha256=ee68b46c34bfa8e757d6cc354498688a15c3becf3ab8801a60484f18526e9fec\n", + " Stored in directory: /root/.cache/pip/wheels/46/f9/0d/40c9cd74a7cb8dc8fe57e8d6c3c19e2c730449c0d3f2bf66b5\n", + "Successfully built pyhealth pandarallel\n", + "Installing collected packages: nvidia-cusparselt-cu12, triton, nvidia-nccl-cu12, nvidia-cudnn-cu12, numpy, littleutils, rdkit, pandas, outdated, torch, tokenizers, scikit-learn, pandarallel, transformers, torchvision, ogb, mne, pyhealth\n", + " Attempting uninstall: nvidia-cusparselt-cu12\n", + " Found existing installation: nvidia-cusparselt-cu12 0.7.1\n", + " Uninstalling nvidia-cusparselt-cu12-0.7.1:\n", + " Successfully uninstalled nvidia-cusparselt-cu12-0.7.1\n", + " Attempting uninstall: triton\n", + " Found existing installation: triton 3.5.0\n", + " Uninstalling triton-3.5.0:\n", + " Successfully uninstalled triton-3.5.0\n", + " Attempting 
uninstall: nvidia-nccl-cu12\n", + " Found existing installation: nvidia-nccl-cu12 2.27.5\n", + " Uninstalling nvidia-nccl-cu12-2.27.5:\n", + " Successfully uninstalled nvidia-nccl-cu12-2.27.5\n", + " Attempting uninstall: nvidia-cudnn-cu12\n", + " Found existing installation: nvidia-cudnn-cu12 9.10.2.21\n", + " Uninstalling nvidia-cudnn-cu12-9.10.2.21:\n", + " Successfully uninstalled nvidia-cudnn-cu12-9.10.2.21\n", + " Attempting uninstall: numpy\n", + " Found existing installation: numpy 2.0.2\n", + " Uninstalling numpy-2.0.2:\n", + " Successfully uninstalled numpy-2.0.2\n", + " Attempting uninstall: pandas\n", + " Found existing installation: pandas 2.2.2\n", + " Uninstalling pandas-2.2.2:\n", + " Successfully uninstalled pandas-2.2.2\n", + " Attempting uninstall: torch\n", + " Found existing installation: torch 2.9.0+cu126\n", + " Uninstalling torch-2.9.0+cu126:\n", + " Successfully uninstalled torch-2.9.0+cu126\n", + " Attempting uninstall: tokenizers\n", + " Found existing installation: tokenizers 0.22.1\n", + " Uninstalling tokenizers-0.22.1:\n", + " Successfully uninstalled tokenizers-0.22.1\n", + " Attempting uninstall: scikit-learn\n", + " Found existing installation: scikit-learn 1.6.1\n", + " Uninstalling scikit-learn-1.6.1:\n", + " Successfully uninstalled scikit-learn-1.6.1\n", + " Attempting uninstall: transformers\n", + " Found existing installation: transformers 4.57.1\n", + " Uninstalling transformers-4.57.1:\n", + " Successfully uninstalled transformers-4.57.1\n", + " Attempting uninstall: torchvision\n", + " Found existing installation: torchvision 0.24.0+cu126\n", + " Uninstalling torchvision-0.24.0+cu126:\n", + " Successfully uninstalled torchvision-0.24.0+cu126\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.\n", + "jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "torchaudio 2.9.0+cu126 requires torch==2.9.0, but you have torch 2.7.1 which is incompatible.\n", + "opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.\n", + "jax 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.\n", + "opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\n", + "opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= \"3.9\", but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed littleutils-0.2.4 mne-1.10.2 numpy-1.26.4 nvidia-cudnn-cu12-9.5.1.17 nvidia-cusparselt-cu12-0.6.3 nvidia-nccl-cu12-2.26.2 ogb-1.3.6 outdated-0.2.2 pandarallel-1.6.5 pandas-2.3.3 pyhealth-2.0a8 rdkit-2025.9.1 scikit-learn-1.7.2 tokenizers-0.21.4 torch-2.7.1 torchvision-0.22.1 transformers-4.53.3 triton-3.3.1\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "numpy", + "torch", + "torchgen" + ] + }, + "id": "a21cf9550c9c4d1a916e50ccaf894bf2" + } + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.7.1)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.12/dist-packages (1.7.2)\n", + "Requirement already satisfied: pandas in 
/usr/local/lib/python3.12/dist-packages (2.3.3)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (1.26.4)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch) (4.15.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch) (3.5)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch) (3.1.6)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.80)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.5.1.17 in /usr/local/lib/python3.12/dist-packages (from torch) (9.5.1.17)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch) (10.3.7.77)\n", + 
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch) (12.5.4.2)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /usr/local/lib/python3.12/dist-packages (from torch) (0.6.3)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /usr/local/lib/python3.12/dist-packages (from torch) (2.26.2)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch) (1.11.1.6)\n", + "Requirement already satisfied: triton==3.3.1 in /usr/local/lib/python3.12/dist-packages (from torch) (3.3.1)\n", + "Requirement already satisfied: scipy>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.16.3)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (1.5.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn) (3.6.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas) (2025.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in 
/usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch) (3.0.3)\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "# Uninstall existing pandas to avoid conflicts\n", + "#!pip uninstall -y pandas\n", + "\n", + "# Install a compatible pandas version (e.g., 2.2.2 as required by google-colab)\n", + "# Then reinstall pyhealth and polars to ensure they are built against this pandas version.\n", + "#!pip install pandas==2.2.2 pyhealth polars -q\n", + "\n", + "# If using development version from GitHub:\n", + "!pip install git+https://github.com/naveenkcb/PyHealth.git\n", + "!pip install torch scikit-learn pandas numpy tqdm\n" + ] + }, + { + "cell_type": "markdown", + "id": "a41f4ba9", + "metadata": { + "id": "a41f4ba9" + }, + "source": [ + "## 2. Download MIMIC-IV Demo Dataset\n", + "\n", + "Download the MIMIC-IV demo dataset. You'll need PhysioNet credentials." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fc9b20b8", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fc9b20b8", + "outputId": "7d9ebe62-a243-4487-876e-8e251ab5c901" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n", + "Data directory: /content/mimic-iv-demo/2.2\n", + "\n", + "⚠️ Please download MIMIC-IV demo dataset from:\n", + "https://physionet.org/content/mimic-iv-demo/2.2/\n", + "\n", + "Or mount Google Drive if you have the dataset stored there.\n" + ] + } + ], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "# Mount Google Drive\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "\n", + "# Create data directory\n", + "data_dir = Path(\"/content/mimic-iv-demo/2.2\")\n", + "data_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# Download MIMIC-IV demo dataset\n", + "# Note: Replace with actual download method or mount Google Drive with dataset\n", + "print(f\"Data directory: {data_dir}\")\n", + "print(\"\\n⚠️ Please download MIMIC-IV demo dataset from:\")\n", + "print(\"https://physionet.org/content/mimic-iv-demo/2.2/\")\n", + "print(\"\\nOr mount Google Drive if you have the dataset stored there.\")" + ] + }, + { + "cell_type": "markdown", + "id": "ae78c2a5", + "metadata": { + "id": "ae78c2a5" + }, + "source": [ + "## 3. Load Pre-trained Model Checkpoint\n", + "\n", + "Upload or download the pre-trained StageNet model checkpoint." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e5992a83", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "e5992a83", + "outputId": "ff6c5262-74fc-451e-faef-1968ca2d21ee" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model checkpoint should be at: /content/resources/best.ckpt\n", + "Checkpoint exists: True\n" + ] + } + ], + "source": [ + "# Create resources directory\n", + "resources_dir = Path(\"/content/resources\")\n", + "resources_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# Upload model checkpoint\n", + "# You can use Google Colab's file upload or download from URL\n", + "# from google.colab import files\n", + "# uploaded = files.upload()\n", + "\n", + "checkpoint_path = resources_dir / \"best.ckpt\"\n", + "print(f\"Model checkpoint should be at: {checkpoint_path}\")\n", + "print(f\"Checkpoint exists: {checkpoint_path.exists()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a6898e63", + "metadata": { + "id": "a6898e63" + }, + "source": [ + "## 4. 
Load Dataset and Processors" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "11338065", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 682 + }, + "id": "11338065", + "outputId": "d78e0412-d952-47f8-b09f-41588bf8a469" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using default EHR config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_ehr.yaml\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.mimic4:Using default EHR config: /usr/local/lib/python3.12/dist-packages/pyhealth/datasets/configs/mimic4_ehr.yaml\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Memory usage Before initializing mimic4_ehr: 1574.5 MB\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.mimic4:Memory usage Before initializing mimic4_ehr: 1574.5 MB\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Duplicate table names in tables list. Removing duplicates.\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:pyhealth.datasets.base_dataset:Duplicate table names in tables list. 
Removing duplicates.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz\n" 
+ ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Memory usage After 
initializing mimic4_ehr: 1574.9 MB\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.mimic4:Memory usage After initializing mimic4_ehr: 1574.9 MB\n" + ] + }, + { + "output_type": "error", + "ename": "TypeError", + "evalue": "object of type 'MIMIC4EHRDataset' has no len()", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipython-input-64345272.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m )\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Dataset loaded: {len(dataset)} patients\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: object of type 'MIMIC4EHRDataset' has no len()" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "import polars as pl\n", + "import torch\n", + "import sys\n", + "import subprocess\n", + "import importlib\n", + "\n", + "\n", + "from pyhealth.datasets import (\n", + " MIMIC4EHRDataset,\n", + " get_dataloader,\n", + " load_processors,\n", + " split_by_patient,\n", + ")\n", + "from pyhealth.interpret.methods import ShapExplainer\n", + "from pyhealth.models import StageNet\n", + "from pyhealth.tasks import MortalityPredictionStageNetMIMIC4\n", + "\n", + "MIMIC4_PATH = \"https://physionet.org/files/mimic-iv-demo/2.2/\"\n", + "# Configure dataset location\n", + "dataset = MIMIC4EHRDataset(\n", + " #root=\"/content/mimic-iv-demo/2.2/\", # Adjust path as needed\n", + " root=MIMIC4_PATH,\n", + " tables=[\n", + " \"patients\",\n", + " \"admissions\",\n", + " \"diagnoses_icd\",\n", + " \"procedures_icd\",\n", + " \"labevents\",\n", + " ],\n", + ")\n", + "\n", + "print(f\"Dataset loaded: {len(dataset)} patients\")" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 6, + "id": "1a4d785d", + "metadata": { + "id": "1a4d785d", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "25686b72-cfd7-4766-c527-3e5e920da519" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✓ Loaded input processors from /content/resources/input_processors.pkl\n", + "✓ Loaded output processors from /content/resources/output_processors.pkl\n", + "Setting task MortalityPredictionStageNetMIMIC4 for mimic4_ehr base dataset...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Setting task MortalityPredictionStageNetMIMIC4 for mimic4_ehr base dataset...\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Generating samples with 1 worker(s)...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Generating samples with 1 worker(s)...\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting global event dataframe...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Collecting global event dataframe...\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collected dataframe with shape: (113470, 39)\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pyhealth.datasets.base_dataset:Collected dataframe with shape: (113470, 39)\n", + "Generating samples for MortalityPredictionStageNetMIMIC4 with 1 worker: 100%|██████████| 100/100 [00:16<00:00, 6.18it/s]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Caching samples to /content/.cache/pyhealth/mimic4_stagenet_mortality/MortalityPredictionStageNetMIMIC4.parquet\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n", + "INFO:pyhealth.datasets.base_dataset:Caching 
samples to /content/.cache/pyhealth/mimic4_stagenet_mortality/MortalityPredictionStageNetMIMIC4.parquet\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Failed to cache samples: failed to determine supertype of list[f64] and list[list[str]]\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:pyhealth.datasets.base_dataset:Failed to cache samples: failed to determine supertype of list[f64] and list[list[str]]\n", + "Processing samples: 100%|██████████| 100/100 [00:00<00:00, 1923.08it/s]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Generated 100 samples for task MortalityPredictionStageNetMIMIC4\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n", + "INFO:pyhealth.datasets.base_dataset:Generated 100 samples for task MortalityPredictionStageNetMIMIC4\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Total samples: 100\n" + ] + } + ], + "source": [ + "# Load processors and set task\n", + "input_processors, output_processors = load_processors(\"/content/resources/\")\n", + "\n", + "sample_dataset = dataset.set_task(\n", + " MortalityPredictionStageNetMIMIC4(),\n", + " cache_dir=\"/content/.cache/pyhealth/mimic4_stagenet_mortality\",\n", + " input_processors=input_processors,\n", + " output_processors=output_processors,\n", + ")\n", + "print(f\"Total samples: {len(sample_dataset)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "d3c5f116", + "metadata": { + "id": "d3c5f116" + }, + "source": [ + "## 5. 
Load ICD Code Descriptions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4594eea4", + "metadata": { + "id": "4594eea4", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "a5f32db8-60fb-4fcb-c024-f1a3e6ce861e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Loaded 0 ICD code descriptions\n" + ] + } + ], + "source": [ + "def load_icd_description_map(dataset_root: str) -> dict:\n", + " \"\"\"Load ICD code → long title mappings from MIMIC-IV reference tables.\"\"\"\n", + " mapping = {}\n", + " root_path = Path(dataset_root).expanduser()\n", + " diag_path = root_path / \"hosp\" / \"d_icd_diagnoses.csv.gz\"\n", + " proc_path = root_path / \"hosp\" / \"d_icd_procedures.csv.gz\"\n", + "\n", + " icd_dtype = {\"icd_code\": pl.Utf8, \"long_title\": pl.Utf8}\n", + "\n", + " if diag_path.exists():\n", + " diag_df = pl.read_csv(\n", + " diag_path,\n", + " columns=[\"icd_code\", \"long_title\"],\n", + " dtypes=icd_dtype,\n", + " )\n", + " mapping.update(\n", + " zip(diag_df[\"icd_code\"].to_list(), diag_df[\"long_title\"].to_list())\n", + " )\n", + "\n", + " if proc_path.exists():\n", + " proc_df = pl.read_csv(\n", + " proc_path,\n", + " columns=[\"icd_code\", \"long_title\"],\n", + " dtypes=icd_dtype,\n", + " )\n", + " mapping.update(\n", + " zip(proc_df[\"icd_code\"].to_list(), proc_df[\"long_title\"].to_list())\n", + " )\n", + "\n", + " return mapping\n", + "\n", + "ICD_CODE_TO_DESC = load_icd_description_map(dataset.root)\n", + "print(f\"Loaded {len(ICD_CODE_TO_DESC)} ICD code descriptions\")" + ] + }, + { + "cell_type": "markdown", + "id": "b4274bd9", + "metadata": { + "id": "b4274bd9" + }, + "source": [ + "## 6. 
Load Pre-trained StageNet Model on GPU" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "22f70a91", + "metadata": { + "id": "22f70a91", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "4ce86eee-b4f6-4dd6-bc64-900dfbda1f52" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using device: cuda:0\n", + "\n", + "Model loaded successfully on cuda:0\n", + "Model parameters: 9,337,777\n" + ] + } + ], + "source": [ + "# Set device to GPU\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Initialize model\n", + "model = StageNet(\n", + " dataset=sample_dataset,\n", + " embedding_dim=128,\n", + " chunk_size=128,\n", + " levels=3,\n", + " dropout=0.3,\n", + ")\n", + "\n", + "# Load checkpoint\n", + "state_dict = torch.load(\"/content/resources/best.ckpt\", map_location=device)\n", + "model.load_state_dict(state_dict)\n", + "model = model.to(device)\n", + "model.eval()\n", + "\n", + "print(f\"\\nModel loaded successfully on {device}\")\n", + "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5d5741ab", + "metadata": { + "id": "5d5741ab" + }, + "source": [ + "## 7. 
Prepare Test Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6cbda428", + "metadata": { + "id": "6cbda428", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "54f18aa7-fd73-4acc-bda9-4bc28658d998" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Test samples: 20\n" + ] + } + ], + "source": [ + "# Split dataset\n", + "_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42)\n", + "test_loader = get_dataloader(test_data, batch_size=1, shuffle=False)\n", + "\n", + "print(f\"Test samples: {len(test_data)}\")\n", + "\n", + "def move_batch_to_device(batch, target_device):\n", + " \"\"\"Move all tensors in batch to target device.\"\"\"\n", + " moved = {}\n", + " for key, value in batch.items():\n", + " if isinstance(value, torch.Tensor):\n", + " moved[key] = value.to(target_device)\n", + " elif isinstance(value, tuple):\n", + " moved[key] = tuple(v.to(target_device) for v in value)\n", + " else:\n", + " moved[key] = value\n", + " return moved" + ] + }, + { + "cell_type": "markdown", + "id": "abe56d5e", + "metadata": { + "id": "abe56d5e" + }, + "source": [ + "## 8. 
Define Helper Functions for Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0cce5100", + "metadata": { + "id": "0cce5100" + }, + "outputs": [], + "source": [ + "LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES\n", + "\n", + "def decode_token(idx: int, processor, feature_key: str):\n", + " \"\"\"Decode token index to human-readable string.\"\"\"\n", + " if processor is None or not hasattr(processor, \"code_vocab\"):\n", + " return str(idx)\n", + " reverse_vocab = {index: token for token, index in processor.code_vocab.items()}\n", + " token = reverse_vocab.get(idx, f\"\")\n", + "\n", + " if feature_key == \"icd_codes\" and token not in {\"\", \"\"}:\n", + " desc = ICD_CODE_TO_DESC.get(token)\n", + " if desc:\n", + " return f\"{token}: {desc}\"\n", + "\n", + " return token\n", + "\n", + "def unravel(flat_index: int, shape: torch.Size):\n", + " \"\"\"Convert flat index to multi-dimensional coordinates.\"\"\"\n", + " coords = []\n", + " remaining = flat_index\n", + " for dim in reversed(shape):\n", + " coords.append(remaining % dim)\n", + " remaining //= dim\n", + " return list(reversed(coords))\n", + "\n", + "def print_top_attributions(\n", + " attributions,\n", + " batch,\n", + " processors,\n", + " top_k: int = 10,\n", + "):\n", + " \"\"\"Print top-k most important features from SHAP attributions.\"\"\"\n", + " for feature_key, attr in attributions.items():\n", + " attr_cpu = attr.detach().cpu()\n", + " if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0:\n", + " continue\n", + "\n", + " feature_input = batch[feature_key]\n", + " if isinstance(feature_input, tuple):\n", + " feature_input = feature_input[1]\n", + " feature_input = feature_input.detach().cpu()\n", + "\n", + " flattened = attr_cpu[0].flatten()\n", + " if flattened.numel() == 0:\n", + " continue\n", + "\n", + " print(f\"\\nFeature: {feature_key}\")\n", + " print(f\" Shape: {attr_cpu[0].shape}\")\n", + " print(f\" Total attribution sum: 
{flattened.sum().item():+.6f}\")\n", + " print(f\" Mean attribution: {flattened.mean().item():+.6f}\")\n", + "\n", + " k = min(top_k, flattened.numel())\n", + " top_values, top_indices = torch.topk(flattened.abs(), k=k)\n", + " processor = processors.get(feature_key) if processors else None\n", + " is_continuous = torch.is_floating_point(feature_input)\n", + "\n", + " print(f\"\\n Top {k} most important features:\")\n", + " for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1):\n", + " attribution_value = flattened[flat_idx].item()\n", + " coords = unravel(flat_idx.item(), attr_cpu[0].shape)\n", + "\n", + " if is_continuous:\n", + " actual_value = feature_input[0][tuple(coords)].item()\n", + " label = \"\"\n", + " if feature_key == \"labs\" and len(coords) >= 1:\n", + " lab_idx = coords[-1]\n", + " if lab_idx < len(LAB_CATEGORY_NAMES):\n", + " label = f\"{LAB_CATEGORY_NAMES[lab_idx]} \"\n", + " print(\n", + " f\" {rank:2d}. idx={coords} {label}value={actual_value:.4f} \"\n", + " f\"SHAP={attribution_value:+.6f}\"\n", + " )\n", + " else:\n", + " token_idx = int(feature_input[0][tuple(coords)].item())\n", + " token = decode_token(token_idx, processor, feature_key)\n", + " print(\n", + " f\" {rank:2d}. idx={coords} token='{token}' \"\n", + " f\"SHAP={attribution_value:+.6f}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "21ec480e", + "metadata": { + "id": "21ec480e" + }, + "source": [ + "## 9. Initialize SHAP Explainer\n", + "\n", + "Initialize the SHAP explainer with Kernel SHAP configuration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "5ae57044", + "metadata": { + "id": "5ae57044", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "0a6bd6e9-8168-467a-c5d1-6134b2be9c3b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "================================================================================\n", + "Initializing SHAP Explainer\n", + "================================================================================\n", + "\n", + "SHAP Configuration:\n", + " Use embeddings: True\n", + " Background samples: 100\n", + " Max coalitions: 1000\n", + " Regularization: 1e-06\n", + " Device: cuda:0\n" + ] + } + ], + "source": [ + "print(\"=\"*80)\n", + "print(\"Initializing SHAP Explainer\")\n", + "print(\"=\"*80)\n", + "\n", + "# Initialize SHAP explainer (Kernel SHAP)\n", + "shap_explainer = ShapExplainer(model)\n", + "\n", + "print(\"\\nSHAP Configuration:\")\n", + "print(f\" Use embeddings: {shap_explainer.use_embeddings}\")\n", + "print(f\" Background samples: {shap_explainer.n_background_samples}\")\n", + "print(f\" Max coalitions: {shap_explainer.max_coalitions}\")\n", + "print(f\" Regularization: {shap_explainer.regularization}\")\n", + "print(f\" Device: {next(shap_explainer.model.parameters()).device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a95e8951", + "metadata": { + "id": "a95e8951" + }, + "source": [ + "## 10. 
Get Model Prediction on Test Sample" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3cb63b98", + "metadata": { + "id": "3cb63b98", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "695af029-b984-4957-e6de-4ceb92974ed4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "icd_codes: device=cuda:0\n", + "labs: device=cuda:0\n", + "mortality: device=cuda:0\n", + "\n", + "================================================================================\n", + "Model Prediction for Sampled Patient\n", + "================================================================================\n", + " True label: 0 (Survived)\n", + " Predicted class: 0 (Survived)\n", + " Probabilities: [Survive=0.8156, Death=0.1844]\n" + ] + } + ], + "source": [ + "# Get a sample from test set\n", + "sample_batch = next(iter(test_loader))\n", + "sample_batch_device = move_batch_to_device(sample_batch, device)\n", + "\n", + "# Verify data is on GPU\n", + "for key, val in sample_batch_device.items():\n", + " if isinstance(val, torch.Tensor):\n", + " print(f\"{key}: device={val.device}\")\n", + " elif isinstance(val, tuple) and len(val) > 0 and isinstance(val[0], torch.Tensor):\n", + " print(f\"{key}: device={val[0].device}\")\n", + "\n", + "# Get model prediction\n", + "with torch.no_grad():\n", + " output = model(**sample_batch_device)\n", + " probs = output[\"y_prob\"]\n", + " label_key = model.label_key\n", + " true_label = sample_batch_device[label_key]\n", + "\n", + " # Handle binary classification (single probability output)\n", + " if probs.shape[-1] == 1:\n", + " prob_death = probs[0].item()\n", + " prob_survive = 1 - prob_death\n", + " preds = (probs > 0.5).long()\n", + " else:\n", + " # Multi-class classification\n", + " preds = torch.argmax(probs, dim=-1)\n", + " prob_survive = probs[0][0].item()\n", + " prob_death = probs[0][1].item()\n", + "\n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"Model Prediction 
for Sampled Patient\")\n", + " print(\"=\"*80)\n", + " print(f\" True label: {int(true_label.item())} {'(Deceased)' if true_label.item() == 1 else '(Survived)'}\")\n", + " print(f\" Predicted class: {int(preds.item())} {'(Deceased)' if preds.item() == 1 else '(Survived)'}\")\n", + " print(f\" Probabilities: [Survive={prob_survive:.4f}, Death={prob_death:.4f}]\")" + ] + }, + { + "cell_type": "markdown", + "id": "ff2eb9c3", + "metadata": { + "id": "ff2eb9c3" + }, + "source": [ + "## 11. Compute SHAP Attributions (GPU-Accelerated)\n", + "\n", + "This step computes SHAP values using Kernel SHAP, running on GPU." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c65de0c3", + "metadata": { + "id": "c65de0c3", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "2225e28a-0653-4441-a8c0-611f36e8bd73" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "Computing SHAP Attributions on GPU\n", + "================================================================================\n", + "\n", + "✓ Computation completed in 2.24 seconds\n", + "\n", + "Attribution tensor devices:\n", + " icd_codes: device=cuda:0, shape=torch.Size([1, 2, 79])\n", + " labs: device=cuda:0, shape=torch.Size([1, 7, 10])\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "print(\"\\n\" + \"=\"*80)\n", + "print(\"Computing SHAP Attributions on GPU\")\n", + "print(\"=\"*80)\n", + "\n", + "# Time the computation\n", + "start_time = time.time()\n", + "\n", + "attributions = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)\n", + "\n", + "elapsed = time.time() - start_time\n", + "print(f\"\\n✓ Computation completed in {elapsed:.2f} seconds\")\n", + "\n", + "# Verify attributions are on GPU\n", + "print(\"\\nAttribution tensor devices:\")\n", + "for key, val in attributions.items():\n", + " print(f\" {key}: 
device={val.device}, shape={val.shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "481c4c31", + "metadata": { + "id": "481c4c31" + }, + "source": [ + "## 12. Analyze SHAP Results" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "93490ab1", + "metadata": { + "id": "93490ab1", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "4bff2e31-e7d5-42ad-d63c-ea7e1c02c20b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "SHAP Attribution Results\n", + "================================================================================\n", + "\n", + "SHAP values explain the contribution of each feature to the model's\n", + "prediction of MORTALITY (class 1). Positive values increase the\n", + "mortality prediction, negative values decrease it.\n", + "\n", + "Feature: icd_codes\n", + " Shape: torch.Size([2, 79])\n", + " Total attribution sum: +44.044949\n", + " Mean attribution: +0.278765\n", + "\n", + " Top 15 most important features:\n", + " 1. idx=[0, 50] token='' SHAP=+3.113653\n", + " 2. idx=[0, 59] token='' SHAP=+3.113653\n", + " 3. idx=[0, 58] token='' SHAP=+3.113653\n", + " 4. idx=[0, 57] token='' SHAP=+3.113653\n", + " 5. idx=[0, 56] token='' SHAP=+3.113653\n", + " 6. idx=[0, 55] token='' SHAP=+3.113653\n", + " 7. idx=[0, 54] token='' SHAP=+3.113653\n", + " 8. idx=[0, 53] token='' SHAP=+3.113653\n", + " 9. idx=[0, 52] token='' SHAP=+3.113653\n", + " 10. idx=[0, 51] token='' SHAP=+3.113653\n", + " 11. idx=[0, 42] token='' SHAP=+3.113653\n", + " 12. idx=[0, 44] token='' SHAP=+3.113653\n", + " 13. idx=[0, 43] token='' SHAP=+3.113653\n", + " 14. idx=[0, 45] token='' SHAP=+3.113653\n", + " 15. 
idx=[0, 41] token='' SHAP=+3.113653\n", + "\n", + "Feature: labs\n", + " Shape: torch.Size([7, 10])\n", + " Total attribution sum: +0.576018\n", + " Mean attribution: +0.008229\n", + "\n", + " Top 15 most important features:\n", + " 1. idx=[6, 9] Phosphate value=6.4000 SHAP=+0.047629\n", + " 2. idx=[6, 0] Sodium value=139.0000 SHAP=+0.047629\n", + " 3. idx=[6, 1] Potassium value=5.5000 SHAP=+0.047629\n", + " 4. idx=[6, 2] Chloride value=95.0000 SHAP=+0.047629\n", + " 5. idx=[6, 3] Bicarbonate value=0.0000 SHAP=+0.047629\n", + " 6. idx=[6, 4] Glucose value=90.0000 SHAP=+0.047629\n", + " 7. idx=[6, 6] Magnesium value=2.6000 SHAP=+0.047629\n", + " 8. idx=[6, 5] Calcium value=0.0000 SHAP=+0.047629\n", + " 9. idx=[6, 8] Osmolality value=0.0000 SHAP=+0.047629\n", + " 10. idx=[6, 7] Anion Gap value=20.0000 SHAP=+0.047629\n", + " 11. idx=[2, 3] Bicarbonate value=0.0000 SHAP=-0.036567\n", + " 12. idx=[2, 5] Calcium value=0.0000 SHAP=-0.036567\n", + " 13. idx=[2, 4] Glucose value=208.0000 SHAP=-0.036567\n", + " 14. idx=[2, 2] Chloride value=96.0000 SHAP=-0.036567\n", + " 15. idx=[2, 1] Potassium value=4.5000 SHAP=-0.036567\n" + ] + } + ], + "source": [ + "print(\"\\n\" + \"=\"*80)\n", + "print(\"SHAP Attribution Results\")\n", + "print(\"=\"*80)\n", + "print(\"\\nSHAP values explain the contribution of each feature to the model's\")\n", + "print(\"prediction of MORTALITY (class 1). Positive values increase the\")\n", + "print(\"mortality prediction, negative values decrease it.\")\n", + "\n", + "print_top_attributions(attributions, sample_batch_device, input_processors, top_k=15)" + ] + }, + { + "cell_type": "markdown", + "id": "7d5b8e9c", + "metadata": { + "id": "7d5b8e9c" + }, + "source": [ + "## 13. 
Test Different Target Classes" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c7b02451", + "metadata": { + "id": "c7b02451", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5f2b6f99-5432-4db4-f2fa-4227bd335858" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "Comparing SHAP Attributions for Different Target Classes\n", + "================================================================================\n", + "\n", + "Computing attributions for SURVIVAL (class 0)...\n", + "Computing attributions for MORTALITY (class 1)...\n", + "\n", + "--- Features promoting SURVIVAL ---\n", + "\n", + "Feature: icd_codes\n", + " Shape: torch.Size([2, 79])\n", + " Total attribution sum: +34.955112\n", + " Mean attribution: +0.221235\n", + "\n", + " Top 5 most important features:\n", + " 1. idx=[0, 52] token='' SHAP=+9.780861\n", + " 2. idx=[0, 54] token='' SHAP=+9.780861\n", + " 3. idx=[0, 53] token='' SHAP=+9.780861\n", + " 4. idx=[0, 55] token='' SHAP=+9.780861\n", + " 5. idx=[0, 51] token='' SHAP=+9.780861\n", + "\n", + "Feature: labs\n", + " Shape: torch.Size([7, 10])\n", + " Total attribution sum: +9.426432\n", + " Mean attribution: +0.134663\n", + "\n", + " Top 5 most important features:\n", + " 1. idx=[6, 2] Chloride value=95.0000 SHAP=+0.391763\n", + " 2. idx=[6, 4] Glucose value=90.0000 SHAP=+0.391763\n", + " 3. idx=[6, 3] Bicarbonate value=0.0000 SHAP=+0.391763\n", + " 4. idx=[6, 9] Phosphate value=6.4000 SHAP=+0.391763\n", + " 5. idx=[6, 1] Potassium value=5.5000 SHAP=+0.391763\n", + "\n", + "--- Features promoting MORTALITY ---\n", + "\n", + "Feature: icd_codes\n", + " Shape: torch.Size([2, 79])\n", + " Total attribution sum: +44.044949\n", + " Mean attribution: +0.278765\n", + "\n", + " Top 5 most important features:\n", + " 1. idx=[0, 52] token='' SHAP=+3.113653\n", + " 2. 
idx=[0, 54] token='' SHAP=+3.113653\n", + " 3. idx=[0, 53] token='' SHAP=+3.113653\n", + " 4. idx=[0, 55] token='' SHAP=+3.113653\n", + " 5. idx=[0, 51] token='' SHAP=+3.113653\n", + "\n", + "Feature: labs\n", + " Shape: torch.Size([7, 10])\n", + " Total attribution sum: +0.576018\n", + " Mean attribution: +0.008229\n", + "\n", + " Top 5 most important features:\n", + " 1. idx=[6, 2] Chloride value=95.0000 SHAP=+0.047629\n", + " 2. idx=[6, 4] Glucose value=90.0000 SHAP=+0.047629\n", + " 3. idx=[6, 3] Bicarbonate value=0.0000 SHAP=+0.047629\n", + " 4. idx=[6, 9] Phosphate value=6.4000 SHAP=+0.047629\n", + " 5. idx=[6, 1] Potassium value=5.5000 SHAP=+0.047629\n" + ] + } + ], + "source": [ + "print(\"\\n\" + \"=\"*80)\n", + "print(\"Comparing SHAP Attributions for Different Target Classes\")\n", + "print(\"=\"*80)\n", + "\n", + "# Compute for survival (class 0)\n", + "print(\"\\nComputing attributions for SURVIVAL (class 0)...\")\n", + "attr_survive = shap_explainer.attribute(**sample_batch_device, target_class_idx=0)\n", + "\n", + "# Compute for mortality (class 1)\n", + "print(\"Computing attributions for MORTALITY (class 1)...\")\n", + "attr_death = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)\n", + "\n", + "print(\"\\n--- Features promoting SURVIVAL ---\")\n", + "print_top_attributions(attr_survive, sample_batch_device, input_processors, top_k=5)\n", + "\n", + "print(\"\\n--- Features promoting MORTALITY ---\")\n", + "print_top_attributions(attr_death, sample_batch_device, input_processors, top_k=5)" + ] + }, + { + "cell_type": "markdown", + "id": "12cc5987", + "metadata": { + "id": "12cc5987" + }, + "source": [ + "## 14. 
Verify GPU Memory Usage" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "4a5c098c", + "metadata": { + "id": "4a5c098c", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c1d93796-238e-42c1-ac0c-7ae9013030b4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "GPU Memory Usage\n", + "================================================================================\n", + " Currently allocated: 0.08 GB\n", + " Reserved: 0.17 GB\n", + " Peak allocated: 0.14 GB\n" + ] + } + ], + "source": [ + "if torch.cuda.is_available():\n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"GPU Memory Usage\")\n", + " print(\"=\"*80)\n", + "\n", + " allocated = torch.cuda.memory_allocated(0) / 1e9\n", + " reserved = torch.cuda.memory_reserved(0) / 1e9\n", + " max_allocated = torch.cuda.max_memory_allocated(0) / 1e9\n", + "\n", + " print(f\" Currently allocated: {allocated:.2f} GB\")\n", + " print(f\" Reserved: {reserved:.2f} GB\")\n", + " print(f\" Peak allocated: {max_allocated:.2f} GB\")\n", + "\n", + " # Reset peak stats\n", + " torch.cuda.reset_peak_memory_stats(0)\n", + "else:\n", + " print(\"GPU not available\")" + ] + }, + { + "cell_type": "markdown", + "id": "483d95cd", + "metadata": { + "id": "483d95cd" + }, + "source": [ + "## 15. 
Test Callable Interface" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "69867127", + "metadata": { + "id": "69867127", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d366b86c-e7f4-40d7-881f-cf7710acc812" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "================================================================================\n", + "Testing Callable Interface\n", + "================================================================================\n", + "\n", + "Verifying that explainer(**data) and explainer.attribute(**data) produce\n", + "identical results...\n", + " ✓ icd_codes: Results match\n", + " ✓ labs: Results match\n", + "\n", + "✓ All attributions match! Callable interface works correctly.\n" + ] + } + ], + "source": [ + "print(\"\\n\" + \"=\"*80)\n", + "print(\"Testing Callable Interface\")\n", + "print(\"=\"*80)\n", + "\n", + "# Both methods should produce identical results\n", + "attr_from_attribute = shap_explainer.attribute(**sample_batch_device, target_class_idx=1)\n", + "attr_from_call = shap_explainer(**sample_batch_device, target_class_idx=1)\n", + "\n", + "print(\"\\nVerifying that explainer(**data) and explainer.attribute(**data) produce\")\n", + "print(\"identical results...\")\n", + "\n", + "all_close = True\n", + "for key in attr_from_attribute.keys():\n", + " if not torch.allclose(attr_from_attribute[key], attr_from_call[key], atol=1e-6):\n", + " all_close = False\n", + " print(f\" ❌ {key}: Results differ!\")\n", + " else:\n", + " print(f\" ✓ {key}: Results match\")\n", + "\n", + "if all_close:\n", + " print(\"\\n✓ All attributions match! Callable interface works correctly.\")\n", + "else:\n", + " print(\"\\n❌ Some attributions differ.\")" + ] + }, + { + "cell_type": "markdown", + "id": "0a9d0d8e", + "metadata": { + "id": "0a9d0d8e" + }, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. 
✅ **GPU Setup**: Verified GPU availability and configured PyTorch to use CUDA\n", + "2. ✅ **Model Loading**: Loaded pre-trained StageNet model on GPU\n", + "3. ✅ **SHAP Computation**: Computed SHAP attributions on GPU for discrete features (ICD codes) and continuous features (lab values)\n", + "4. ✅ **Feature Interpretation**: Identified which diagnosis/procedure codes and lab values most influenced mortality predictions\n", + "5. ✅ **Multi-class Analysis**: Compared attributions for different target classes (survival vs. mortality)\n", + "6. ✅ **GPU Optimization**: Verified all tensors and computations run on GPU\n", + "\n", + "**Key Takeaways:**\n", + "- SHAP provides interpretable, theoretically-grounded feature attributions\n", + "- GPU acceleration significantly speeds up coalition sampling and model evaluations\n", + "- The method works seamlessly with discrete healthcare features like ICD codes\n", + "- Positive SHAP values indicate features that increase the prediction, negative values decrease it" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU", + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/examples/shap_stagenet_mimic4.py b/examples/shap_stagenet_mimic4.py new file mode 100644 index 000000000..8a948762b --- /dev/null +++ b/examples/shap_stagenet_mimic4.py @@ -0,0 +1,298 @@ +# %% Loading MIMIC-IV dataset +from pathlib import Path + +import polars as pl +import torch + +from pyhealth.datasets import ( + MIMIC4EHRDataset, + get_dataloader, + load_processors, + split_by_patient, +) +from pyhealth.interpret.methods import ShapExplainer +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + +# Configure dataset location and load cached processors +dataset = MIMIC4EHRDataset( + 
#root="/home/naveen-baskaran/physionet.org/files/mimic-iv-demo/2.2/", + #root="/Users/naveenbaskaran/data/physionet.org/files/mimic-iv-demo/2.2/", + root="~/data/physionet.org/files/mimic-iv-demo/2.2/", + tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], +) + +# %% Setting StageNet Mortality Prediction Task +input_processors, output_processors = load_processors("../resources/") + +sample_dataset = dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", + input_processors=input_processors, + output_processors=output_processors, +) +print(f"Total samples: {len(sample_dataset)}") + + +def load_icd_description_map(dataset_root: str) -> dict: + """Load ICD code → long title mappings from MIMIC-IV reference tables.""" + mapping = {} + root_path = Path(dataset_root).expanduser() + diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz" + proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz" + + icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8} + + if diag_path.exists(): + diag_df = pl.read_csv( + diag_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list()) + ) + + if proc_path.exists(): + proc_df = pl.read_csv( + proc_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list()) + ) + + return mapping + + +ICD_CODE_TO_DESC = load_icd_description_map(dataset.root) + +# %% Loading Pretrained StageNet Model +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, +) + +state_dict = torch.load("../resources/best.ckpt", map_location=device) +model.load_state_dict(state_dict) +model = model.to(device) +model.eval() +print(model) + +# %% 
Preparing dataloaders +_, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42) +test_loader = get_dataloader(test_data, batch_size=1, shuffle=False) + + +def move_batch_to_device(batch, target_device): + """Move all tensors in batch to target device.""" + moved = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + moved[key] = value.to(target_device) + elif isinstance(value, tuple): + moved[key] = tuple(v.to(target_device) for v in value) + else: + moved[key] = value + return moved + + +LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES + + +def decode_token(idx: int, processor, feature_key: str): + """Decode token index to human-readable string.""" + if processor is None or not hasattr(processor, "code_vocab"): + return str(idx) + reverse_vocab = {index: token for token, index in processor.code_vocab.items()} + token = reverse_vocab.get(idx, f"") + + if feature_key == "icd_codes" and token not in {"", ""}: + desc = ICD_CODE_TO_DESC.get(token) + if desc: + return f"{token}: {desc}" + + return token + + +def unravel(flat_index: int, shape: torch.Size): + """Convert flat index to multi-dimensional coordinates.""" + coords = [] + remaining = flat_index + for dim in reversed(shape): + coords.append(remaining % dim) + remaining //= dim + return list(reversed(coords)) + + +def print_top_attributions( + attributions, + batch, + processors, + top_k: int = 10, +): + """Print top-k most important features from SHAP attributions.""" + for feature_key, attr in attributions.items(): + attr_cpu = attr.detach().cpu() + if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: + continue + + feature_input = batch[feature_key] + if isinstance(feature_input, tuple): + feature_input = feature_input[1] + feature_input = feature_input.detach().cpu() + + flattened = attr_cpu[0].flatten() + if flattened.numel() == 0: + continue + + print(f"\nFeature: {feature_key}") + print(f" Shape: {attr_cpu[0].shape}") + print(f" Total 
attribution sum: {flattened.sum().item():+.6f}") + print(f" Mean attribution: {flattened.mean().item():+.6f}") + + k = min(top_k, flattened.numel()) + top_values, top_indices = torch.topk(flattened.abs(), k=k) + processor = processors.get(feature_key) if processors else None + is_continuous = torch.is_floating_point(feature_input) + + print(f"\n Top {k} most important features:") + for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1): + attribution_value = flattened[flat_idx].item() + coords = unravel(flat_idx.item(), attr_cpu[0].shape) + + if is_continuous: + actual_value = feature_input[0][tuple(coords)].item() + label = "" + if feature_key == "labs" and len(coords) >= 1: + lab_idx = coords[-1] + if lab_idx < len(LAB_CATEGORY_NAMES): + label = f"{LAB_CATEGORY_NAMES[lab_idx]} " + print( + f" {rank:2d}. idx={coords} {label}value={actual_value:.4f} " + f"SHAP={attribution_value:+.6f}" + ) + else: + token_idx = int(feature_input[0][tuple(coords)].item()) + token = decode_token(token_idx, processor, feature_key) + print( + f" {rank:2d}. 
idx={coords} token='{token}' " + f"SHAP={attribution_value:+.6f}" + ) + + +# %% Run SHAP on a held-out sample +print("\n" + "="*80) +print("Initializing SHAP Explainer") +print("="*80) + +# Initialize SHAP explainer (Kernel SHAP) +shap_explainer = ShapExplainer(model) + +print("\nSHAP Configuration:") +print(f" Use embeddings: {shap_explainer.use_embeddings}") +print(f" Background samples: {shap_explainer.n_background_samples}") +print(f" Max coalitions: {shap_explainer.max_coalitions}") +print(f" Regularization: {shap_explainer.regularization}") + +# Get a sample from test set +sample_batch = next(iter(test_loader)) +sample_batch_device = move_batch_to_device(sample_batch, device) + +# Get model prediction +with torch.no_grad(): + output = model(**sample_batch_device) + probs = output["y_prob"] + label_key = model.label_key + true_label = sample_batch_device[label_key] + + # Handle binary classification (single probability output) + if probs.shape[-1] == 1: + prob_death = probs[0].item() + prob_survive = 1 - prob_death + preds = (probs > 0.5).long() + else: + # Multi-class classification + preds = torch.argmax(probs, dim=-1) + prob_survive = probs[0][0].item() + prob_death = probs[0][1].item() + + print("\n" + "="*80) + print("Model Prediction for Sampled Patient") + print("="*80) + print(f" True label: {int(true_label.item())} {'(Deceased)' if true_label.item() == 1 else '(Survived)'}") + print(f" Predicted class: {int(preds.item())} {'(Deceased)' if preds.item() == 1 else '(Survived)'}") + print(f" Probabilities: [Survive={prob_survive:.4f}, Death={prob_death:.4f}]") + +# Compute SHAP values +print("\n" + "="*80) +print("Computing SHAP Attributions (this may take a minute...)") +print("="*80) + +attributions = shap_explainer.attribute(**sample_batch_device, target_class_idx=1) + +print("\n" + "="*80) +print("SHAP Attribution Results") +print("="*80) +print("\nSHAP values explain the contribution of each feature to the model's") +print("prediction of MORTALITY 
(class 1). Positive values increase the") +print("mortality prediction, negative values decrease it.") + +print_top_attributions(attributions, sample_batch_device, input_processors, top_k=15) + +# %% Compare different baseline strategies +print("\n\n" + "="*80) +print("Testing Different Baseline Strategies") +print("="*80) + +# 1. Automatic baseline (default) +print("\n1. Automatic baseline generation (recommended):") +attr_auto = shap_explainer.attribute(**sample_batch_device, target_class_idx=1) +print(f" Total attribution (icd_codes): {attr_auto['icd_codes'][0].sum().item():+.6f}") +print(f" Total attribution (labs): {attr_auto['labs'][0].sum().item():+.6f}") + +# Note: Custom baselines for discrete features (like ICD codes) require careful +# construction to avoid invalid sequences. The automatic baseline generation +# handles this by sampling from the observed data distribution. + +# %% Test callable interface +print("\n" + "="*80) +print("Testing Callable Interface") +print("="*80) + +# Both methods should produce identical results (due to random_seed) +attr_from_attribute = shap_explainer.attribute(**sample_batch_device, target_class_idx=1) +attr_from_call = shap_explainer(**sample_batch_device, target_class_idx=1) + +print("\nVerifying that explainer(**data) and explainer.attribute(**data) produce") +print("identical results when random_seed is set...") + +all_close = True +for key in attr_from_attribute.keys(): + if not torch.allclose(attr_from_attribute[key], attr_from_call[key], atol=1e-6): + all_close = False + print(f" ❌ {key}: Results differ!") + else: + print(f" ✓ {key}: Results match") + +if all_close: + print("\n✓ All attributions match! Callable interface works correctly.") +else: + print("\n❌ Some attributions differ. 
Check random seed configuration.") + +print("\n" + "="*80) +print("SHAP Analysis Complete") +print("="*80) + +# %% diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index c6b6e461d..52796ffd1 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -4,6 +4,7 @@ from pyhealth.interpret.methods.deeplift import DeepLift from pyhealth.interpret.methods.gim import GIM from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients +from pyhealth.interpret.methods.shap import ShapExplainer __all__ = [ "BaseInterpreter", @@ -11,5 +12,6 @@ "DeepLift", "GIM", "IntegratedGradients", + "BasicGradientSaliencyMaps", + "ShapExplainer" ] -__all__ = ["BaseInterpreter", "BasicGradientSaliencyMaps", "CheferRelevance", "DeepLift", "IntegratedGradients"] diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py new file mode 100644 index 000000000..9a371f012 --- /dev/null +++ b/pyhealth/interpret/methods/shap.py @@ -0,0 +1,913 @@ +from __future__ import annotations + +import math +from typing import Dict, Optional, Tuple + +import torch + +from pyhealth.models import BaseModel +from .base_interpreter import BaseInterpreter + + +class ShapExplainer(BaseInterpreter): + """SHAP (SHapley Additive exPlanations) attribution method for PyHealth models. + + This class implements the SHAP method for computing feature attributions in + neural networks. SHAP values represent each feature's contribution to the + prediction, based on coalitional game theory principles. + + The method is based on the paper: + A Unified Approach to Interpreting Model Predictions + Scott Lundberg, Su-In Lee + NeurIPS 2017 + https://arxiv.org/abs/1705.07874 + + Kernel SHAP Method: + This implementation uses Kernel SHAP, which combines ideas from LIME (Local + Interpretable Model-agnostic Explanations) with Shapley values from game theory. + The key steps are: + 1. 
Generate background samples to establish baseline predictions + 2. Create feature coalitions (subsets of features) using weighted sampling + 3. Compute model predictions for each coalition + 4. Solve a weighted least squares problem to estimate Shapley values + + Mathematical Foundation: + The Shapley value for feature i is computed as: + φᵢ = Σ (|S|!(n-|S|-1)!/n!) * [f₀(S ∪ {i}) - f₀(S)] + where: + - S is a subset of features excluding i + - n is the total number of features + - f₀(S) is the model prediction with only features in S + + SHAP provides several desirable properties: + 1. Local Accuracy: The sum of feature attributions equals the difference between + the model output and the expected output + 2. Missingness: Features with zero impact get zero attribution + 3. Consistency: Changing a model to increase a feature's impact increases its attribution + + Args: + model: A trained PyHealth model to interpret. Can be any model that + inherits from BaseModel (e.g., MLP, StageNet, Transformer, RNN). + use_embeddings: If True, compute SHAP values with respect to + embeddings rather than discrete input tokens. This is crucial + for models with discrete inputs (like ICD codes). The model + must support returning embeddings via an 'embed' parameter. + Default is True. + n_background_samples: Number of background samples to use for + estimating feature contributions. More samples give better + estimates but increase computation time. Default is 100. + max_coalitions: Maximum number of feature coalitions to sample for + Kernel SHAP approximation. Default is 1000. + regularization: L2 regularization strength for the weighted least + squares problem. Default is 1e-6. 
+ + Examples: + >>> import torch + >>> from pyhealth.datasets import SampleDataset, get_dataloader + >>> from pyhealth.models import MLP + >>> from pyhealth.interpret.methods import ShapExplainer + >>> from pyhealth.trainer import Trainer + >>> + >>> # Create dataset and model + >>> dataset = SampleDataset(...) + >>> model = MLP(...) + >>> trainer = Trainer(model=model, device="cuda:0") + >>> trainer.train(...) + >>> test_batch = next(iter(test_loader)) + >>> + >>> # Initialize SHAP explainer + >>> explainer = ShapExplainer(model, use_embeddings=True) + >>> shap_values = explainer.attribute(**test_batch) + >>> + >>> # With custom baseline + >>> baseline = { + ... 'conditions': torch.zeros_like(test_batch['conditions']), + ... 'procedures': torch.full_like(test_batch['procedures'], + ... test_batch['procedures'].mean()) + ... } + >>> shap_values = explainer.attribute(baseline=baseline, **test_batch) + >>> + >>> print(shap_values) + {'conditions': tensor([[0.1234, 0.5678, 0.9012]], device='cuda:0'), + 'procedures': tensor([[0.2345, 0.6789, 0.0123, 0.4567]])} + """ + + def __init__( + self, + model: BaseModel, + use_embeddings: bool = True, + n_background_samples: int = 100, + max_coalitions: int = 1000, + regularization: float = 1e-6, + random_seed: Optional[int] = 42, + ): + """Initialize SHAP explainer. + + Args: + model: A trained PyHealth model to interpret. + use_embeddings: If True, compute SHAP values with respect to + embeddings rather than discrete input tokens. + n_background_samples: Number of background samples to use for + estimating feature contributions. + max_coalitions: Maximum number of feature coalitions to sample. + regularization: L2 regularization strength for weighted least squares. + random_seed: Optional random seed for reproducibility. If provided, + this seed will be used to initialize the random number generator + before each attribution computation, ensuring deterministic results. 
+ + Raises: + AssertionError: If use_embeddings=True but model does not + implement forward_from_embedding() method. + """ + super().__init__(model) + self.use_embeddings = use_embeddings + self.n_background_samples = n_background_samples + self.max_coalitions = max_coalitions + self.regularization = regularization + self.random_seed = random_seed + + # Validate model requirements + if use_embeddings: + assert hasattr(model, "forward_from_embedding"), ( + f"Model {type(model).__name__} must implement " + "forward_from_embedding() method to support embedding-level " + "SHAP values. Set use_embeddings=False to use " + "input-level attributions (only for continuous features)." + ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def attribute( + self, + baseline: Optional[Dict[str, torch.Tensor]] = None, + target_class_idx: Optional[int] = None, + **data, + ) -> Dict[str, torch.Tensor]: + """Compute SHAP attributions for input features. + + This is the main interface for computing feature attributions. It handles: + 1. Input preparation and validation + 2. Background sample generation or validation + 3. Feature attribution computation using Kernel SHAP + + Args: + baseline: Optional dictionary mapping feature names to background + samples. If None, generates samples automatically using + _generate_background_samples(). Shape of each tensor should + be (n_background_samples, ..., feature_dim). + target_class_idx: For multi-class models, specifies which class's + prediction to explain. If None, explains the model's + maximum prediction across all classes. + **data: Input data dictionary from dataloader batch. Should contain: + - Feature tensors with shape (batch_size, ..., feature_dim) + - Optional time information for temporal models + - Optional label data for supervised models + + Returns: + Dictionary mapping feature names to their SHAP values. 
Each value + tensor has the same shape as its corresponding input and contains + the feature's contribution to the prediction relative to the baseline. + Positive values indicate features that increased the prediction, + negative values indicate features that decreased it. + + Example: + >>> shap_values = explainer.attribute( + ... x_continuous=torch.tensor([[1.0, 2.0, 3.0]]), + ... x_categorical=torch.tensor([[0, 1, 2]]), + ... target_class_idx=1 + ... ) + >>> print(shap_values['x_continuous']) # Shape: (1, 3) + """ + # Set random seed for reproducibility if specified + if self.random_seed is not None: + torch.manual_seed(self.random_seed) + + device = next(self.model.parameters()).device + + # Extract and prepare inputs + feature_inputs: Dict[str, torch.Tensor] = {} + time_info: Dict[str, torch.Tensor] = {} + label_data: Dict[str, torch.Tensor] = {} + + for key in self.model.feature_keys: + if key not in data: + continue + value = data[key] + + # Handle (time, value) tuples for temporal data + if isinstance(value, tuple): + time_tensor, feature_tensor = value + if time_tensor is not None: + time_info[key] = time_tensor.to(device) + value = feature_tensor + + if not isinstance(value, torch.Tensor): + value = torch.as_tensor(value) + feature_inputs[key] = value.to(device) + + # Store label data + for key in self.model.label_keys: + if key in data: + label_val = data[key] + if not isinstance(label_val, torch.Tensor): + label_val = torch.as_tensor(label_val) + label_data[key] = label_val.to(device) + + # Generate or validate background samples + if baseline is None: + background = self._generate_background_samples(feature_inputs) + else: + background = {k: v.to(device) for k, v in baseline.items()} + + # Compute SHAP values + if self.use_embeddings: + return self._shap_embeddings( + feature_inputs, + background=background, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + else: + return self._shap_continuous( + feature_inputs, 
+ background=background, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + + # ------------------------------------------------------------------ + # Embedding-based SHAP (discrete features) + # ------------------------------------------------------------------ + def _shap_embeddings( + self, + inputs: Dict[str, torch.Tensor], + background: Dict[str, torch.Tensor], + target_class_idx: Optional[int], + time_info: Dict[str, torch.Tensor], + label_data: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Compute SHAP values for discrete inputs in embedding space. + + Args: + inputs: Dictionary of input tensors. + background: Dictionary of background samples. + target_class_idx: Target class index for attribution. + time_info: Temporal information for time-series models. + label_data: Label information for supervised models. + + Returns: + Dictionary of SHAP values mapped back to input shapes. + """ + # Embed inputs and background + input_embs = self.model.embedding_model(inputs) + background_embs = self.model.embedding_model(background) + + # Store original input shapes for mapping back + input_shapes = {key: val.shape for key, val in inputs.items()} + + # Compute SHAP values for each feature + shap_values = {} + for key in inputs: + n_features = self._determine_n_features(key, inputs, input_embs) + + shap_matrix = self._compute_kernel_shap( + key=key, + input_emb=input_embs, + background_emb=background_embs, + n_features=n_features, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + + shap_values[key] = shap_matrix + + # Map embedding-space attributions back to input shapes + return self._map_to_input_shapes(shap_values, input_shapes) + + # ------------------------------------------------------------------ + # Continuous SHAP (for tensor inputs) + # ------------------------------------------------------------------ + def _shap_continuous( + self, + inputs: Dict[str, 
torch.Tensor], + background: Dict[str, torch.Tensor], + target_class_idx: Optional[int], + time_info: Dict[str, torch.Tensor], + label_data: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Compute SHAP values for continuous tensor inputs. + + Args: + inputs: Dictionary of input tensors. + background: Dictionary of background samples. + target_class_idx: Target class index for attribution. + time_info: Temporal information for time-series models. + label_data: Label information for supervised models. + + Returns: + Dictionary of SHAP values with same shapes as inputs. + """ + shap_values = {} + + for key in inputs: + n_features = self._determine_n_features(key, inputs, inputs) + + shap_matrix = self._compute_kernel_shap( + key=key, + input_emb=inputs, + background_emb=background, + n_features=n_features, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + + shap_values[key] = shap_matrix + + return shap_values + + # ------------------------------------------------------------------ + # Core Kernel SHAP computation + # ------------------------------------------------------------------ + def _compute_kernel_shap( + self, + key: str, + input_emb: Dict[str, torch.Tensor], + background_emb: Dict[str, torch.Tensor], + n_features: int, + target_class_idx: Optional[int], + time_info: Optional[Dict[str, torch.Tensor]], + label_data: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Compute SHAP values using the Kernel SHAP approximation method. + + This implements the Kernel SHAP algorithm that approximates Shapley values + through a weighted least squares regression. The key steps are: + + 1. Feature Coalitions: Generate random subsets of features + 2. Model Evaluation: Evaluate mixed samples (background + coalition) + 3. Weighted Least Squares: Solve for SHAP values using kernel weights + + Args: + key: Feature key being explained. + input_emb: Dictionary of input embeddings/tensors. 
+ background_emb: Dictionary of background embeddings/tensors. + n_features: Number of features to explain. + target_class_idx: Target class index for multi-class models. + time_info: Optional temporal information. + label_data: Optional label information. + + Returns: + torch.Tensor: SHAP values with shape (batch_size, n_features). + """ + device = input_emb[key].device + batch_size = input_emb[key].shape[0] if input_emb[key].dim() >= 2 else 1 + n_coalitions = min(2 ** n_features, self.max_coalitions) + + # Storage for coalition sampling + coalition_vectors = [] + coalition_weights = [] + coalition_preds = [] + + # Add edge case coalitions explicitly (empty and full) + # These are crucial for the local accuracy property of SHAP + edge_coalitions = [ + torch.zeros(n_features, device=device), # Empty coalition (baseline) + torch.ones(n_features, device=device), # Full coalition (actual input) + ] + + for coalition in edge_coalitions: + per_input_preds = [] + for b_idx in range(batch_size): + mixed_emb = self._create_mixed_sample( + key, coalition, input_emb, background_emb, b_idx + ) + + pred = self._evaluate_coalition( + key, mixed_emb, background_emb, + target_class_idx, time_info, label_data + ) + per_input_preds.append(pred) + + coalition_vectors.append(coalition.float()) + coalition_preds.append(torch.stack(per_input_preds, dim=0)) + coalition_weights.append( + self._compute_kernel_weight(coalition.sum().item(), n_features) + ) + + # Sample remaining coalitions randomly (excluding edge cases already added) + n_random_coalitions = max(0, n_coalitions - 2) + for _ in range(n_random_coalitions): + coalition = torch.randint(2, (n_features,), device=device) + + # Evaluate model for each input sample with this coalition + per_input_preds = [] + for b_idx in range(batch_size): + mixed_emb = self._create_mixed_sample( + key, coalition, input_emb, background_emb, b_idx + ) + + pred = self._evaluate_coalition( + key, mixed_emb, background_emb, + target_class_idx, 
time_info, label_data + ) + per_input_preds.append(pred) + + # Store coalition information + coalition_vectors.append(coalition.float()) + coalition_preds.append(torch.stack(per_input_preds, dim=0)) + coalition_weights.append( + self._compute_kernel_weight(coalition.sum().item(), n_features) + ) + + # Solve weighted least squares + return self._solve_weighted_least_squares( + coalition_vectors, coalition_preds, coalition_weights, device + ) + + def _create_mixed_sample( + self, + key: str, + coalition: torch.Tensor, + input_emb: Dict[str, torch.Tensor], + background_emb: Dict[str, torch.Tensor], + batch_idx: int, + ) -> torch.Tensor: + """Create a mixed sample by combining background and input based on coalition. + + Args: + key: Feature key. + coalition: Binary vector indicating which features to use from input. + input_emb: Input embeddings. + background_emb: Background embeddings. + batch_idx: Index of the sample in the batch. + + Returns: + Mixed sample tensor. + """ + mixed = background_emb[key].clone() + + for i, use_input in enumerate(coalition): + if not use_input: + continue + + # Handle various embedding shapes + dim = input_emb[key].dim() + if dim == 4: # (batch, seq_len, inner_len, emb) + mixed[:, i, :, :] = input_emb[key][batch_idx, i, :, :] + elif dim == 3: # (batch, seq_len, emb) + mixed[:, i, :] = input_emb[key][batch_idx, i, :] + else: # 2D or other + mixed[:, i] = input_emb[key][batch_idx, i] + + return mixed + + def _evaluate_coalition( + self, + key: str, + mixed_emb: torch.Tensor, + background_emb: Dict[str, torch.Tensor], + target_class_idx: Optional[int], + time_info: Optional[Dict[str, torch.Tensor]], + label_data: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Evaluate model prediction for a coalition. + + Args: + key: Feature key being explained. + mixed_emb: Mixed embedding tensor. + background_emb: Background embeddings for other features. + target_class_idx: Target class index. + time_info: Optional temporal information. 
def _forward_from_embeddings(
    self,
    key: str,
    mixed_emb: torch.Tensor,
    background_emb: Dict[str, torch.Tensor],
    time_info: Optional[Dict[str, torch.Tensor]],
    label_data: Optional[Dict[str, torch.Tensor]],
) -> torch.Tensor:
    """Run the model on embeddings, with ``mixed_emb`` standing in for ``key``.

    Every other key in ``self.model.feature_keys`` receives a copy of its
    background embedding, or zeros shaped like ``mixed_emb`` when no
    background embedding exists for it.

    Args:
        key: Feature key being explained.
        mixed_emb: Mixed embedding tensor for ``key``. Its first dimension
            is used as the batch size of the forward pass.
        background_emb: Background embeddings keyed by feature name.
        time_info: Optional temporal information keyed by feature name.
        label_data: Optional label information. Not read here; the model
            is handed an all-zero label stub instead.

    Returns:
        Model logits.
    """
    feature_embeddings = {key: mixed_emb}
    for fk in self.model.feature_keys:
        if fk in feature_embeddings:
            continue
        if fk in background_emb:
            feature_embeddings[fk] = background_emb[fk].clone()
        else:
            # Zero fallback for features without a background embedding.
            feature_embeddings[fk] = torch.zeros_like(mixed_emb)

    batch_size = mixed_emb.shape[0]

    # Time vectors are broadcast to the batch size of the mixed embeddings.
    forward_kwargs = {
        "time_info": self._prepare_time_info(
            time_info, feature_embeddings, batch_size
        ),
    }

    # The label is passed under the model's own (first) label key.
    if len(self.model.label_keys) > 0:
        label_key = self.model.label_keys[0]
        forward_kwargs[label_key] = torch.zeros(
            (batch_size, 1), device=self.model.device
        )

    with torch.no_grad():
        model_output = self.model.forward_from_embedding(
            feature_embeddings,
            **forward_kwargs
        )

    return self._extract_logits(model_output)


def _forward_from_inputs(
    self,
    key: str,
    mixed_inputs: torch.Tensor,
    background_inputs: Dict[str, torch.Tensor],
    time_info: Optional[Dict[str, torch.Tensor]],
    label_data: Optional[Dict[str, torch.Tensor]],
) -> torch.Tensor:
    """Run the model on raw inputs (continuous features).

    The forward pass runs under ``torch.no_grad()``, matching
    ``_forward_from_embeddings``: attributions are estimated from
    predictions only, so no autograd graph is needed.

    Args:
        key: Feature key being explained.
        mixed_inputs: Mixed input tensor for ``key``.
        background_inputs: Background inputs keyed by feature name.
        time_info: Optional temporal information. Not read on this path.
        label_data: Optional label information. Not read here; the model
            is handed an all-zero label stub instead.

    Returns:
        Model logits.
    """
    model_inputs = {}
    for fk in self.model.feature_keys:
        if fk == key:
            model_inputs[fk] = mixed_inputs
        elif fk in background_inputs:
            model_inputs[fk] = background_inputs[fk].clone()
        else:
            # Zero fallback for features without background inputs.
            model_inputs[fk] = torch.zeros_like(mixed_inputs)

    # Label stub, passed under the model's own (first) label key.
    if len(self.model.label_keys) > 0:
        label_key = self.model.label_keys[0]
        model_inputs[label_key] = torch.zeros(
            (mixed_inputs.shape[0], 1), device=mixed_inputs.device
        )

    with torch.no_grad():
        output = self.model(**model_inputs)
    return self._extract_logits(output)


def _prepare_time_info(
    self,
    time_info: Optional[Dict[str, torch.Tensor]],
    feature_embeddings: Dict[str, torch.Tensor],
    n_background: int,
) -> Optional[Dict[str, torch.Tensor]]:
    """Prepare time information to match the background batch size.

    For each feature that has time information, the time tensor is reduced
    to a single 1D vector, padded or truncated to the embedding's sequence
    length (``emb.shape[1]``), and expanded to ``n_background`` rows.

    Args:
        time_info: Original time information keyed by feature name.
        feature_embeddings: Feature embeddings whose sequence lengths the
            time vectors must match.
        n_background: Number of rows to expand each time vector to.

    Returns:
        Adjusted time information, or None when ``time_info`` is None or
        no feature has usable time information.
    """
    if time_info is None:
        return None

    time_info_bg = {}
    for fk, emb in feature_embeddings.items():
        if fk not in time_info or time_info[fk] is None:
            continue

        t_vec = self._normalize_time_vector(
            time_info[fk].to(self.model.device)
        )
        t_adj = self._adjust_time_length(t_vec, emb.shape[1])

        # expand() creates a broadcast view; no per-row copy is made.
        time_info_bg[fk] = t_adj.unsqueeze(0).expand(n_background, -1)

    return time_info_bg if time_info_bg else None


# ------------------------------------------------------------------
# Weighted least squares solver
# ------------------------------------------------------------------
def _solve_weighted_least_squares(
    self,
    coalition_vectors: list,
    coalition_preds: list,
    coalition_weights: list,
    device: torch.device,
) -> torch.Tensor:
    """Solve weighted least squares to estimate SHAP values.

    Uses Tikhonov regularization (``self.regularization``) for numerical
    stability. All operands are cast to one floating dtype (that of the
    predictions) because ``torch.linalg.lstsq`` requires both arguments
    to share a dtype.

    Args:
        coalition_vectors: List of coalition binary vectors.
        coalition_preds: List of prediction tensors per coalition.
        coalition_weights: List of kernel weights per coalition.
        device: Device for computation.

    Returns:
        SHAP values with shape (batch_size, n_features).
    """
    Y = torch.stack(coalition_preds, dim=0).to(device)  # (n_coalitions, batch_size)
    work_dtype = Y.dtype if Y.is_floating_point() else torch.get_default_dtype()
    Y = Y.to(work_dtype)
    X = torch.stack(coalition_vectors, dim=0).to(
        device=device, dtype=work_dtype
    )  # (n_coalitions, n_features)
    W = torch.stack(coalition_weights, dim=0).to(
        device=device, dtype=work_dtype
    )  # (n_coalitions,)

    # Scaling rows by sqrt(w) turns weighted LS into ordinary LS.
    sqrtW = torch.sqrt(W).unsqueeze(1)  # (n_coalitions, 1)
    Xw = sqrtW * X
    Yw = sqrtW * Y

    # Augment for regularized least squares: [Xw; sqrt(reg) * I] phi = [Yw; 0]
    n_features = X.shape[1]
    reg_mat = (self.regularization ** 0.5) * torch.eye(
        n_features, device=device, dtype=work_dtype
    )
    Xw_aug = torch.cat([Xw, reg_mat], dim=0)
    Yw_aug = torch.cat([Yw, Yw.new_zeros((n_features, Y.shape[1]))], dim=0)

    phi_sol = torch.linalg.lstsq(Xw_aug, Yw_aug).solution  # (n_features, batch_size)

    # Return per-sample attributions: (batch_size, n_features)
    return phi_sol.transpose(0, 1)


# ------------------------------------------------------------------
# Background sample generation
# ------------------------------------------------------------------
def _generate_background_samples(
    self, inputs: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Generate background samples for SHAP computation.

    Creates ``self.n_background_samples`` reference samples per feature.
    The sampling strategy adapts to the feature type:
    - Floating-point features: sample uniformly from the range [min, max)
      of the input, in the input's own dtype.
    - All other dtypes (integer, bool): sample uniformly from the unique
      values observed in the input.

    Args:
        inputs: Dictionary mapping feature names to input tensors.

    Returns:
        Dictionary mapping feature names to background sample tensors of
        shape ``(n_background_samples,) + x.shape[1:]`` on the input's
        device.
    """
    background_samples = {}

    for key, x in inputs.items():
        sample_shape = (self.n_background_samples,) + tuple(x.shape[1:])

        if x.is_floating_point():
            min_val = torch.min(x)
            max_val = torch.max(x)
            samples = torch.rand(
                sample_shape, dtype=x.dtype, device=x.device
            ) * (max_val - min_val) + min_val
        else:
            unique_vals = torch.unique(x)
            samples = unique_vals[
                torch.randint(len(unique_vals), sample_shape)
            ]

        background_samples[key] = samples.to(x.device)

    return background_samples


# ------------------------------------------------------------------
# Utility helpers
# ------------------------------------------------------------------
@staticmethod
def _determine_n_features(
    key: str,
    inputs: Dict[str, torch.Tensor],
    embeddings: Dict[str, torch.Tensor],
) -> int:
    """Determine the number of features to explain for a given key.

    Args:
        key: Feature key.
        inputs: Original input tensors.
        embeddings: Embedding tensors.

    Returns:
        Size of dimension 1 of the original input when it has at least two
        dimensions; otherwise dimension 1 of the embedding, or its last
        dimension when the embedding has fewer than two dimensions.
    """
    # Prefer the original input shape.
    if key in inputs and inputs[key].dim() >= 2:
        return inputs[key].shape[1]

    # Fall back to the embedding shape.
    emb = embeddings[key]
    return emb.shape[1] if emb.dim() >= 2 else emb.shape[-1]


@staticmethod
def _compute_kernel_weight(coalition_size: int, n_features: int) -> torch.Tensor:
    """Compute the Kernel SHAP weight for a coalition.

    Formula from Lundberg & Lee (2017):
        weight = (M - 1) / (binom(M, |z|) * |z| * (M - |z|))

    Args:
        coalition_size: Number of present features (|z|).
        n_features: Total number of features (M).

    Returns:
        Scalar float32 tensor with the kernel weight.

    Raises:
        ValueError: If ``coalition_size`` is outside ``[0, n_features]``.
    """
    M = n_features
    z = coalition_size

    if not 0 <= z <= M:
        # Otherwise math.comb(M, z) is 0 for z > M and the division below
        # fails with an unhelpful ZeroDivisionError.
        raise ValueError(f"coalition_size must be in [0, {M}], got {z}")

    # Empty and full coalitions have infinite weight in theory; a large
    # finite value stands in for it.
    if z == 0 or z == M:
        return torch.tensor(1000, dtype=torch.float32)

    weight = (M - 1) / (math.comb(M, z) * z * (M - z))
    return torch.tensor(weight, dtype=torch.float32)


@staticmethod
def _extract_logits(model_output) -> torch.Tensor:
    """Extract logits from model output.

    Args:
        model_output: Model output (dict or tensor).

    Returns:
        ``model_output["logit"]`` when the output is a dict containing
        that key; otherwise the output itself, unchanged.
    """
    if isinstance(model_output, dict) and "logit" in model_output:
        return model_output["logit"]
    return model_output


@staticmethod
def _extract_target_prediction(
    logits: torch.Tensor, target_class_idx: Optional[int]
) -> torch.Tensor:
    """Extract the target class prediction from logits.

    Args:
        logits: Model logits.
        target_class_idx: Target class index (None for max prediction).

    Returns:
        One prediction per sample:
        - ``target_class_idx`` is None: the maximum logit over the last
          dimension; logits with at most one dimension are returned as is.
        - Multi-column logits: the logit column at ``target_class_idx``.
        - Single logit: the sigmoid probability when ``target_class_idx``
          is 1, otherwise its complement.
    """
    if target_class_idx is None:
        if logits.dim() <= 1:
            # The only axis is the sample axis; reducing over it would
            # collapse the whole batch into a single value.
            return logits
        return torch.max(logits, dim=-1)[0]

    if logits.dim() > 1 and logits.shape[-1] > 1:
        return logits[..., target_class_idx]

    # Binary classification with a single logit.
    sig = torch.sigmoid(logits.squeeze(-1))
    return sig if target_class_idx == 1 else 1.0 - sig


@staticmethod
def _normalize_time_vector(time_tensor: torch.Tensor) -> torch.Tensor:
    """Normalize a time tensor to a detached 1D vector.

    Args:
        time_tensor: Time information tensor.

    Returns:
        The first row of a non-empty 2D tensor, a 1D tensor unchanged, or
        any other tensor flattened to 1D.
    """
    if time_tensor.dim() == 2 and time_tensor.shape[0] > 0:
        return time_tensor[0].detach()
    if time_tensor.dim() == 1:
        return time_tensor.detach()
    return time_tensor.reshape(-1).detach()


@staticmethod
def _adjust_time_length(time_vec: torch.Tensor, target_len: int) -> torch.Tensor:
    """Adjust a time vector's length to match the target length.

    Args:
        time_vec: 1D time vector.
        target_len: Target sequence length.

    Returns:
        ``time_vec`` unchanged when it already has ``target_len`` entries;
        truncated when longer; padded by repeating its last value when
        shorter. An empty vector yields zeros in the dtype and device of
        ``time_vec``.
    """
    current_len = time_vec.numel()

    if current_len == target_len:
        return time_vec
    if current_len > target_len:
        return time_vec[:target_len]
    if current_len == 0:
        # No last value to repeat.
        return time_vec.new_zeros(target_len)

    pad = time_vec[-1].expand(target_len - current_len)
    return torch.cat([time_vec, pad], dim=0)


@staticmethod
def _map_to_input_shapes(
    shap_values: Dict[str, torch.Tensor],
    input_shapes: Dict[str, tuple],
) -> Dict[str, torch.Tensor]:
    """Map SHAP values back to the original input shapes.

    Values whose key has no recorded input shape, or whose shape already
    matches, pass through unchanged. Otherwise trailing singleton
    dimensions are appended until the rank matches and the result is
    broadcast (``expand``) to the input shape.

    Args:
        shap_values: Dictionary of SHAP values.
        input_shapes: Dictionary of original input shapes.

    Returns:
        Dictionary of SHAP values reshaped to match inputs.
    """
    mapped = {}
    for key, values in shap_values.items():
        if key not in input_shapes or values.shape == input_shapes[key]:
            mapped[key] = values
            continue

        target_shape = input_shapes[key]
        reshaped = values
        while len(reshaped.shape) < len(target_shape):
            reshaped = reshaped.unsqueeze(-1)

        if reshaped.shape != target_shape:
            reshaped = reshaped.expand(target_shape)

        mapped[key] = reshaped

    return mapped
import unittest
from typing import Dict

import torch
import torch.nn as nn

from pyhealth.datasets import SampleDataset, get_dataloader
from pyhealth.models import MLP, StageNet, BaseModel
from pyhealth.interpret.methods import ShapExplainer
from pyhealth.interpret.methods.base_interpreter import BaseInterpreter


class _SimpleShapModel(BaseModel):
    """Two-layer perceptron over a single continuous feature ``x``."""

    def __init__(self):
        super().__init__(dataset=None)
        self.feature_keys = ["x"]
        self.label_keys = ["y"]
        self.mode = "binary"

        self.linear1 = nn.Linear(3, 4, bias=True)
        self.linear2 = nn.Linear(4, 1, bias=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> dict:
        logit = self.linear2(torch.relu(self.linear1(x)))
        y_prob = torch.sigmoid(logit)
        device = y_prob.device

        return {
            "logit": logit,
            "y_prob": y_prob,
            "y_true": y.to(device),
            "loss": torch.zeros((), device=device),
        }


class _SimpleEmbeddingModel(nn.Module):
    """Embedding lookup applied to every integer-token feature it is given."""

    def __init__(self, vocab_size: int = 20, embedding_dim: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        embedded = {}
        for name, tokens in inputs.items():
            embedded[name] = self.embedding(tokens.long())
        return embedded


class _EmbeddingForwardModel(BaseModel):
    """Toy discrete-feature model that exposes ``forward_from_embedding``."""

    def __init__(self):
        super().__init__(dataset=None)
        self.feature_keys = ["seq"]
        self.label_keys = ["label"]
        self.mode = "binary"

        self.embedding_model = _SimpleEmbeddingModel()
        self.linear = nn.Linear(4, 1, bias=True)

    def forward_from_embedding(
        self,
        feature_embeddings: Dict[str, torch.Tensor],
        time_info: Dict[str, torch.Tensor] = None,
        label: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # Mean-pool over the sequence: (batch, seq_len, emb_dim) -> (batch, emb_dim)
        logits = self.linear(feature_embeddings["seq"].mean(dim=1))

        return {
            "logit": logits,
            "y_prob": torch.sigmoid(logits),
            "loss": torch.zeros((), device=logits.device),
        }


class _MultiFeatureModel(BaseModel):
    """Model with two continuous feature inputs, ``x1`` and ``x2``."""

    def __init__(self):
        super().__init__(dataset=None)
        self.feature_keys = ["x1", "x2"]
        self.label_keys = ["y"]
        self.mode = "binary"

        self.linear1 = nn.Linear(2, 3, bias=True)
        self.linear2 = nn.Linear(2, 3, bias=True)
        self.linear_out = nn.Linear(6, 1, bias=True)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, y: torch.Tensor) -> dict:
        branches = [torch.relu(self.linear1(x1)), torch.relu(self.linear2(x2))]
        logit = self.linear_out(torch.cat(branches, dim=-1))
        y_prob = torch.sigmoid(logit)
        device = y_prob.device

        return {
            "logit": logit,
            "y_prob": y_prob,
            "y_true": y.to(device),
            "loss": torch.zeros((), device=device),
        }


class TestShapExplainerBasic(unittest.TestCase):
    """Core ShapExplainer behaviour on a small continuous-input model."""

    def setUp(self):
        self.model = _SimpleShapModel()
        self.model.eval()

        # Fixed weights keep the model itself deterministic.
        with torch.no_grad():
            self.model.linear1.weight.copy_(
                torch.tensor([
                    [0.5, -0.3, 0.2],
                    [0.1, 0.4, -0.1],
                    [-0.2, 0.3, 0.5],
                    [0.3, -0.1, 0.2],
                ])
            )
            self.model.linear1.bias.copy_(torch.tensor([0.1, -0.1, 0.2, 0.0]))
            self.model.linear2.weight.copy_(torch.tensor([[0.4, -0.3, 0.2, 0.1]]))
            self.model.linear2.bias.copy_(torch.tensor([0.05]))

        self.labels = torch.zeros((1, 1))
        self.explainer = ShapExplainer(
            self.model,
            use_embeddings=False,
            n_background_samples=50,
            max_coalitions=100,
            random_seed=42,
        )

    def test_inheritance(self):
        """The explainer is a BaseInterpreter."""
        self.assertIsInstance(self.explainer, BaseInterpreter)

    def test_shap_initialization(self):
        """Constructor stores the model, the mode flag and the default sizes."""
        fresh = ShapExplainer(self.model, use_embeddings=False)
        self.assertIsInstance(fresh, ShapExplainer)
        self.assertEqual(fresh.model, self.model)
        self.assertFalse(fresh.use_embeddings)
        self.assertEqual(fresh.n_background_samples, 100)
        self.assertEqual(fresh.max_coalitions, 1000)

    def test_attribute_returns_dict(self):
        """attribute() returns a dict keyed by feature, shaped like the input."""
        x = torch.tensor([[1.0, 0.5, -0.3]])

        result = self.explainer.attribute(
            x=x,
            y=self.labels,
        )

        self.assertIsInstance(result, dict)
        self.assertIn("x", result)
        self.assertEqual(result["x"].shape, x.shape)

    def test_shap_values_are_tensors(self):
        """Attributions are tensors that do not require grad."""
        x = torch.tensor([[0.8, -0.2, 0.5]])

        result = self.explainer.attribute(
            x=x,
            y=self.labels,
        )

        self.assertIsInstance(result["x"], torch.Tensor)
        self.assertFalse(result["x"].requires_grad)

    def test_baseline_generation(self):
        """A baseline is generated automatically when none is supplied."""
        x = torch.tensor([[1.0, 0.5, -0.3], [0.5, 1.0, 0.2]])

        result = self.explainer.attribute(
            x=x,
            y=torch.zeros((2, 1)),
        )

        self.assertEqual(result["x"].shape, x.shape)

    def test_custom_baseline(self):
        """A caller-supplied baseline dictionary is accepted."""
        x = torch.tensor([[1.0, 0.5, -0.3]])
        custom = {"x": torch.zeros((50, 3))}

        result = self.explainer.attribute(
            baseline=custom,
            x=x,
            y=self.labels,
        )

        self.assertEqual(result["x"].shape, x.shape)

    def test_zero_input_produces_small_attributions(self):
        """Zero input against a zero baseline gives near-zero attributions."""
        x = torch.zeros((1, 3))
        custom = {"x": torch.zeros((50, 3))}

        result = self.explainer.attribute(
            baseline=custom,
            x=x,
            y=self.labels,
        )

        # Sampling means the values are small rather than exactly zero.
        self.assertTrue(torch.all(torch.abs(result["x"]) < 0.1))

    def test_target_class_idx_none(self):
        """target_class_idx=None (max prediction) is handled."""
        x = torch.tensor([[1.0, 0.5, -0.3]])

        result = self.explainer.attribute(
            x=x,
            y=self.labels,
            target_class_idx=None,
        )

        self.assertIn("x", result)
        self.assertEqual(result["x"].shape, x.shape)

    def test_target_class_idx_specified(self):
        """Different target classes give different attributions."""
        x = torch.tensor([[1.0, 0.5, -0.3]])

        for_class_0 = self.explainer.attribute(
            x=x,
            y=self.labels,
            target_class_idx=0,
        )

        for_class_1 = self.explainer.attribute(
            x=x,
            y=self.labels,
            target_class_idx=1,
        )

        self.assertFalse(torch.allclose(for_class_0["x"], for_class_1["x"], atol=0.01))

    def test_attribution_values_are_finite(self):
        """Attributions contain no NaN or Inf."""
        x = torch.tensor([[1.0, 0.5, -0.3]])

        result = self.explainer.attribute(
            x=x,
            y=self.labels,
        )

        self.assertTrue(torch.isfinite(result["x"]).all())

    def test_multiple_samples(self):
        """A batch of several samples keeps its batch dimension and shape."""
        x = torch.tensor([[1.0, 0.5, -0.3], [0.5, 1.0, 0.2], [-0.5, 0.3, 0.8]])

        result = self.explainer.attribute(
            x=x,
            y=torch.zeros((3, 1)),
        )

        self.assertEqual(result["x"].shape[0], 3)
        self.assertEqual(result["x"].shape, x.shape)

    def test_callable_interface(self):
        """Calling the explainer matches calling attribute() directly."""
        x = torch.tensor([[0.3, -0.4, 0.5]])
        call_kwargs = {"x": x, "y": self.labels}

        via_attribute = self.explainer.attribute(**call_kwargs)
        via_call = self.explainer(**call_kwargs)

        # SHAP is a stochastic approximation, so tolerances are relaxed to
        # absorb minor variation across Python/PyTorch versions.
        torch.testing.assert_close(
            via_call["x"],
            via_attribute["x"],
            rtol=1e-3,  # 0.1% relative tolerance
            atol=1e-4  # absolute tolerance
        )

    def test_different_n_background_samples(self):
        """Small and large background sets both give valid output."""
        x = torch.tensor([[1.0, 0.5, -0.3]])

        small_bg = ShapExplainer(
            self.model,
            use_embeddings=False,
            n_background_samples=20,
            max_coalitions=50,
        )
        attr_small = small_bg.attribute(x=x, y=self.labels)

        large_bg = ShapExplainer(
            self.model,
            use_embeddings=False,
            n_background_samples=100,
            max_coalitions=50,
        )
        attr_large = large_bg.attribute(x=x, y=self.labels)

        self.assertEqual(attr_small["x"].shape, x.shape)
        self.assertEqual(attr_large["x"].shape, x.shape)
        self.assertTrue(torch.isfinite(attr_small["x"]).all())
        self.assertTrue(torch.isfinite(attr_large["x"]).all())


class TestShapExplainerEmbedding(unittest.TestCase):
    """ShapExplainer in embedding mode on a toy discrete-feature model."""

    def setUp(self):
        self.model = _EmbeddingForwardModel()
        self.model.eval()

        # Fixed weights for the output layer.
        with torch.no_grad():
            self.model.linear.weight.copy_(torch.tensor([[0.4, -0.3, 0.2, 0.1]]))
            self.model.linear.bias.copy_(torch.tensor([0.05]))

        self.labels = torch.zeros((1, 1))
        self.explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=30,
            max_coalitions=50,
        )

    def test_embedding_initialization(self):
        """Embedding mode is enabled and the model exposes the required hook."""
        self.assertTrue(self.explainer.use_embeddings)
        self.assertTrue(hasattr(self.model, "forward_from_embedding"))

    def test_attribute_with_embeddings(self):
        """Attributions in embedding mode are shaped like the token input."""
        tokens = torch.tensor([[1, 2, 3]])

        result = self.explainer.attribute(
            seq=tokens,
            label=self.labels,
        )

        self.assertIn("seq", result)
        self.assertEqual(result["seq"].shape, tokens.shape)

    def test_embedding_attributions_are_finite(self):
        """Embedding-mode attributions contain no NaN or Inf."""
        tokens = torch.tensor([[5, 10, 15]])

        result = self.explainer.attribute(
            seq=tokens,
            label=self.labels,
        )

        self.assertTrue(torch.isfinite(result["seq"]).all())

    def test_embedding_with_time_info(self):
        """A (time, values) tuple input is accepted for temporal data."""
        times = torch.tensor([[0.0, 1.5, 3.0]])
        tokens = torch.tensor([[1, 2, 3]])

        result = self.explainer.attribute(
            seq=(times, tokens),
            label=self.labels,
        )

        self.assertIn("seq", result)
        self.assertEqual(result["seq"].shape, tokens.shape)

    def test_embedding_with_custom_baseline(self):
        """A custom embedding-space baseline is accepted."""
        tokens = torch.tensor([[1, 2, 3]])
        zero_baseline = torch.zeros((30, 3, 4))  # (n_background, seq_len, emb_dim)

        result = self.explainer.attribute(
            baseline={"seq": zero_baseline},
            seq=tokens,
            label=self.labels,
        )

        self.assertEqual(result["seq"].shape, tokens.shape)

    def test_embedding_model_without_forward_from_embedding_fails(self):
        """Embedding mode on a model without the hook raises AssertionError."""
        plain_model = _SimpleShapModel()

        with self.assertRaises(AssertionError):
            ShapExplainer(plain_model, use_embeddings=True)


class TestShapExplainerMultiFeature(unittest.TestCase):
    """ShapExplainer on a model with two feature inputs."""

    def setUp(self):
        self.model = _MultiFeatureModel()
        self.model.eval()

        # Fixed weights (biases keep their random initialisation).
        with torch.no_grad():
            self.model.linear1.weight.copy_(
                torch.tensor([[0.5, -0.3], [0.1, 0.4], [-0.2, 0.3]])
            )
            self.model.linear2.weight.copy_(
                torch.tensor([[0.3, -0.1], [0.2, 0.5], [0.4, -0.2]])
            )
            self.model.linear_out.weight.copy_(
                torch.tensor([[0.1, 0.2, -0.1, 0.3, -0.2, 0.15]])
            )

        self.labels = torch.zeros((1, 1))
        self.explainer = ShapExplainer(
            self.model,
            use_embeddings=False,
            n_background_samples=40,
            max_coalitions=60,
        )

    def test_multi_feature_attribution(self):
        """Each feature gets attributions shaped like its own input."""
        first = torch.tensor([[1.0, 0.5]])
        second = torch.tensor([[-0.3, 0.8]])

        result = self.explainer.attribute(
            x1=first,
            x2=second,
            y=self.labels,
        )

        self.assertIn("x1", result)
        self.assertIn("x2", result)
        self.assertEqual(result["x1"].shape, first.shape)
        self.assertEqual(result["x2"].shape, second.shape)

    def test_multi_feature_with_custom_baselines(self):
        """Per-feature custom baselines are accepted."""
        first = torch.tensor([[1.0, 0.5]])
        second = torch.tensor([[-0.3, 0.8]])
        custom = {
            "x1": torch.zeros((40, 2)),
            "x2": torch.ones((40, 2)) * 0.5,
        }

        result = self.explainer.attribute(
            baseline=custom,
            x1=first,
            x2=second,
            y=self.labels,
        )

        self.assertEqual(result["x1"].shape, first.shape)
        self.assertEqual(result["x2"].shape, second.shape)

    def test_multi_feature_finite_values(self):
        """Multi-feature attributions contain no NaN or Inf."""
        first = torch.tensor([[1.0, 0.5], [0.3, -0.2]])
        second = torch.tensor([[-0.3, 0.8], [0.5, 0.1]])

        result = self.explainer.attribute(
            x1=first,
            x2=second,
            y=torch.zeros((2, 1)),
        )

        self.assertTrue(torch.isfinite(result["x1"]).all())
        self.assertTrue(torch.isfinite(result["x2"]).all())


class TestShapExplainerMLP(unittest.TestCase):
    """ShapExplainer with the MLP model on a small SampleDataset."""

    def setUp(self):
        """Build the samples, dataset, model and a batch-size-1 loader."""
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"],
                "procedures": [1.0, 2.0, 3.5, 4],
                "label": 0,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-1",
                "conditions": ["cond-33", "cond-86", "cond-80"],
                "procedures": [5.0, 2.0, 3.5, 4],
                "label": 1,
            },
            {
                "patient_id": "patient-2",
                "visit_id": "visit-2",
                "conditions": ["cond-55", "cond-12"],
                "procedures": [2.0, 3.0, 1.5, 5],
                "label": 1,
            },
        ]

        # Schemas
        self.input_schema = {
            "conditions": "sequence",
            "procedures": "tensor",
        }
        self.output_schema = {"label": "binary"}

        # Dataset
        self.dataset = SampleDataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test_shap",
        )

        # Model
        self.model = MLP(
            dataset=self.dataset,
            embedding_dim=32,
            hidden_dim=32,
            n_layers=2,
        )
        self.model.eval()

        # Loader
        self.test_loader = get_dataloader(self.dataset, batch_size=1, shuffle=False)

    def test_shap_mlp_basic_attribution(self):
        """Both features get tensor attributions shaped like their inputs."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=20,
            max_coalitions=50,
        )
        batch = next(iter(self.test_loader))

        result = explainer.attribute(**batch)

        # Output structure
        self.assertIn("conditions", result)
        self.assertIn("procedures", result)

        # Shapes match the inputs
        self.assertEqual(
            result["conditions"].shape, batch["conditions"].shape
        )
        self.assertEqual(
            result["procedures"].shape, batch["procedures"].shape
        )

        # Attributions are tensors
        self.assertIsInstance(result["conditions"], torch.Tensor)
        self.assertIsInstance(result["procedures"], torch.Tensor)

    def test_shap_mlp_with_target_class(self):
        """Attributions differ between target class 0 and target class 1."""
        explainer = ShapExplainer(self.model)
        batch = next(iter(self.test_loader))

        for_class_0 = explainer.attribute(**batch, target_class_idx=0)
        for_class_1 = explainer.attribute(**batch, target_class_idx=1)

        self.assertFalse(
            torch.allclose(for_class_0["conditions"], for_class_1["conditions"], atol=0.01)
        )

    def test_shap_mlp_values_finite(self):
        """MLP attributions contain no NaN or Inf."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=20,
            max_coalitions=50,
        )
        batch = next(iter(self.test_loader))

        result = explainer.attribute(**batch)

        self.assertTrue(torch.isfinite(result["conditions"]).all())
        self.assertTrue(torch.isfinite(result["procedures"]).all())

    def test_shap_mlp_multiple_samples(self):
        """A batch of two samples keeps its batch dimension."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=20,
            max_coalitions=50,
        )

        # Batch size > 1
        pair_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        batch = next(iter(pair_loader))

        result = explainer.attribute(**batch)

        self.assertEqual(result["conditions"].shape[0], 2)
        self.assertEqual(result["procedures"].shape[0], 2)

    def test_shap_mlp_different_coalitions(self):
        """Few and many coalitions give identically shaped output."""
        batch = next(iter(self.test_loader))

        few_coalitions = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=10,
            max_coalitions=20,
        )
        attr_few = few_coalitions.attribute(**batch)

        many_coalitions = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=10,
            max_coalitions=100,
        )
        attr_many = many_coalitions.attribute(**batch)

        self.assertIn("conditions", attr_few)
        self.assertIn("conditions", attr_many)
        self.assertEqual(attr_few["conditions"].shape, attr_many["conditions"].shape)


class TestShapExplainerStageNet(unittest.TestCase):
    """ShapExplainer with the StageNet model.

    Note: these tests exercise SHAP on temporal/sequential data.
    """

    def setUp(self):
        """Build the samples, dataset, StageNet model and a batch-size-1 loader."""
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "codes": ([0.0, 2.0, 1.3], ["505800458", "50580045810", "50580045811"]),
                "procedures": (
                    [0.0, 1.5],
                    [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
                ),
                "lab_values": (None, [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]]),
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-1",
                "codes": (
                    [0.0, 2.0, 1.3, 1.0, 2.0],
                    [
                        "55154191800",
                        "551541928",
                        "55154192800",
                        "705182798",
                        "70518279800",
                    ],
                ),
                "procedures": ([0.0], [["A04A", "B035", "C129"]]),
                "lab_values": (
                    None,
                    [
                        [1.4, 3.2, 3.5],
                        [4.1, 5.9, 1.7],
                        [4.5, 5.9, 1.7],
                    ],
                ),
                "label": 0,
            },
        ]

        # Schemas
        self.input_schema = {
            "codes": "stagenet",
            "procedures": "stagenet",
            "lab_values": "stagenet_tensor",
        }
        self.output_schema = {"label": "binary"}

        # Dataset
        self.dataset = SampleDataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test_stagenet_shap",
        )

        # Model
        self.model = StageNet(
            dataset=self.dataset,
            embedding_dim=32,
            chunk_size=16,
            levels=2,
        )
        self.model.eval()

        # Loader
        self.test_loader = get_dataloader(self.dataset, batch_size=1, shuffle=False)

    def test_shap_initialization_stagenet(self):
        """The explainer can be constructed around a StageNet model."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=10,
            max_coalitions=30,
        )
        self.assertIsInstance(explainer, ShapExplainer)
        self.assertEqual(explainer.model, self.model)

    @unittest.skip("StageNet with discrete codes requires special handling for SHAP")
    def test_shap_basic_attribution_stagenet(self):
        """All three StageNet features get tensor attributions."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=10,
            max_coalitions=30,
        )
        batch = next(iter(self.test_loader))

        result = explainer.attribute(**batch)

        # Output structure
        self.assertIn("codes", result)
        self.assertIn("procedures", result)
        self.assertIn("lab_values", result)

        # Attributions are tensors
        self.assertIsInstance(result["codes"], torch.Tensor)
        self.assertIsInstance(result["procedures"], torch.Tensor)
        self.assertIsInstance(result["lab_values"], torch.Tensor)

    @unittest.skip("StageNet with discrete codes requires special handling for SHAP")
    def test_shap_attribution_shapes_stagenet(self):
        """Attribution shapes match the values part of each StageNet input."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=10,
            max_coalitions=30,
        )
        batch = next(iter(self.test_loader))

        result = explainer.attribute(**batch)

        # StageNet inputs are (time, values) tuples; attributions should
        # line up with the values element.
        _, codes_values = batch["codes"]
        _, procedures_values = batch["procedures"]
        _, lab_values = batch["lab_values"]

        self.assertEqual(result["codes"].shape, codes_values.shape)
        self.assertEqual(result["procedures"].shape, procedures_values.shape)
        self.assertEqual(result["lab_values"].shape, lab_values.shape)

    @unittest.skip("StageNet with discrete codes requires special handling for SHAP")
    def test_shap_values_finite_stagenet(self):
        """StageNet attributions contain no NaN or Inf."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=True,
            n_background_samples=10,
            max_coalitions=30,
        )
        batch = next(iter(self.test_loader))

        result = explainer.attribute(**batch)

        self.assertTrue(torch.isfinite(result["codes"]).all())
        self.assertTrue(torch.isfinite(result["procedures"]).all())
        self.assertTrue(torch.isfinite(result["lab_values"]).all())


class TestShapExplainerEdgeCases(unittest.TestCase):
    """Edge cases and helper utilities of ShapExplainer."""

    def setUp(self):
        self.model = _SimpleShapModel()
        self.model.eval()
        self.labels = torch.zeros((1, 1))

    def test_discrete_feature_background_generation(self):
        """Integer inputs yield an integer background of the requested size."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=False,
            n_background_samples=30,
        )

        # Integer inputs
        tokens = torch.tensor([[1, 2, 3]], dtype=torch.long)

        generated = explainer._generate_background_samples({"x": tokens})

        self.assertIn("x", generated)
        self.assertEqual(generated["x"].shape[0], 30)  # n_background_samples
        self.assertEqual(generated["x"].dtype, torch.long)

    def test_continuous_feature_background_generation(self):
        """Continuous inputs yield a float background within the input range."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=False,
            n_background_samples=40,
        )

        # Continuous inputs
        values = torch.tensor([[1.5, -0.3, 0.8]])

        generated = explainer._generate_background_samples({"x": values})

        self.assertIn("x", generated)
        self.assertEqual(generated["x"].shape[0], 40)
        self.assertIn(generated["x"].dtype, [torch.float32, torch.float64])

        # Samples stay inside the observed range
        self.assertTrue(torch.all(generated["x"] >= values.min()))
        self.assertTrue(torch.all(generated["x"] <= values.max()))

    def test_empty_feature_dict(self):
        """An empty feature dictionary yields an empty background."""
        explainer = ShapExplainer(
            self.model,
            use_embeddings=False,
        )

        # Must not crash
        generated = explainer._generate_background_samples({})
        self.assertEqual(len(generated), 0)

    def test_kernel_weight_computation_edge_cases(self):
        """Empty/full coalitions get weight 1000; partial ones a finite positive weight."""
        # Empty coalition (size = 0)
        empty_weight = ShapExplainer._compute_kernel_weight(0, 5)
        self.assertEqual(empty_weight.item(), 1000.0)

        # Full coalition (size = n_features)
        full_weight = ShapExplainer._compute_kernel_weight(5, 5)
        self.assertEqual(full_weight.item(), 1000.0)

        # Partial coalition
        partial_weight = ShapExplainer._compute_kernel_weight(2, 5)
        self.assertGreater(partial_weight.item(), 0)
        self.assertTrue(torch.isfinite(partial_weight))

    def test_time_vector_adjustment(self):
        """Time vectors are padded, truncated or kept to reach the target length."""
        # Padding
        short_vec = torch.tensor([0.0, 1.0, 2.0])
        padded = ShapExplainer._adjust_time_length(short_vec, 5)
        self.assertEqual(padded.shape[0], 5)
        self.assertEqual(padded[-1].item(), 2.0)  # Last value repeated

        # Truncation
        long_vec = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])
        truncated = ShapExplainer._adjust_time_length(long_vec, 3)
        self.assertEqual(truncated.shape[0], 3)

        # Exact match
        exact_vec = torch.tensor([0.0, 1.0, 2.0])
        unchanged = ShapExplainer._adjust_time_length(exact_vec, 3)
        self.assertEqual(unchanged.shape[0], 3)
torch.testing.assert_close(adjusted_exact, time_vec_exact) + + # Test empty vector + time_vec_empty = torch.tensor([]) + adjusted_empty = ShapExplainer._adjust_time_length(time_vec_empty, 3) + self.assertEqual(adjusted_empty.shape[0], 3) + self.assertTrue(torch.all(adjusted_empty == 0)) + + def test_time_vector_normalization(self): + """Test time vector normalization to 1D.""" + # 2D time tensor + time_2d = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + normalized = ShapExplainer._normalize_time_vector(time_2d) + self.assertEqual(normalized.dim(), 1) + self.assertEqual(normalized.shape[0], 3) + + # 1D time tensor + time_1d = torch.tensor([0.0, 1.0, 2.0]) + normalized = ShapExplainer._normalize_time_vector(time_1d) + self.assertEqual(normalized.dim(), 1) + torch.testing.assert_close(normalized, time_1d) + + # Single row 2D + time_single = torch.tensor([[0.0, 1.0, 2.0]]) + normalized = ShapExplainer._normalize_time_vector(time_single) + self.assertEqual(normalized.dim(), 1) + + def test_target_prediction_extraction_binary(self): + """Test target prediction extraction for binary classification.""" + # Single logit (binary classification) + logits_binary = torch.tensor([[0.5], [1.0], [-0.3]]) + + # Class 1 + pred_1 = ShapExplainer._extract_target_prediction(logits_binary, 1) + self.assertEqual(pred_1.shape, (3,)) + self.assertTrue(torch.all((pred_1 >= 0) & (pred_1 <= 1))) + + # Class 0 + pred_0 = ShapExplainer._extract_target_prediction(logits_binary, 0) + self.assertEqual(pred_0.shape, (3,)) + torch.testing.assert_close(pred_0, 1.0 - pred_1) + + # None (max) + pred_max = ShapExplainer._extract_target_prediction(logits_binary, None) + self.assertEqual(pred_max.shape, (3,)) + + def test_target_prediction_extraction_multiclass(self): + """Test target prediction extraction for multi-class classification.""" + logits_multi = torch.tensor([[0.5, 1.0, -0.3], [0.2, 0.8, 0.1]]) + + # Specific class + pred_class_1 = ShapExplainer._extract_target_prediction(logits_multi, 1) 
+ self.assertEqual(pred_class_1.shape, (2,)) + torch.testing.assert_close(pred_class_1, logits_multi[:, 1]) + + # None (max) + pred_max = ShapExplainer._extract_target_prediction(logits_multi, None) + self.assertEqual(pred_max.shape, (2,)) + + def test_logit_extraction_from_dict(self): + """Test logit extraction from model output dictionary.""" + output_dict = {"logit": torch.tensor([[0.5]]), "y_prob": torch.tensor([[0.62]])} + logits = ShapExplainer._extract_logits(output_dict) + torch.testing.assert_close(logits, torch.tensor([[0.5]])) + + def test_logit_extraction_from_tensor(self): + """Test logit extraction from tensor output.""" + output_tensor = torch.tensor([[0.5]]) + logits = ShapExplainer._extract_logits(output_tensor) + torch.testing.assert_close(logits, output_tensor) + + def test_shape_mapping_simple(self): + """Test mapping SHAP values back to input shapes.""" + shap_values = {"x": torch.randn(2, 3)} + input_shapes = {"x": (2, 3)} + + mapped = ShapExplainer._map_to_input_shapes(shap_values, input_shapes) + self.assertEqual(mapped["x"].shape, (2, 3)) + + def test_shape_mapping_expansion(self): + """Test shape expansion when needed.""" + shap_values = {"x": torch.randn(2, 3)} + input_shapes = {"x": (2, 3, 1)} + + mapped = ShapExplainer._map_to_input_shapes(shap_values, input_shapes) + self.assertEqual(mapped["x"].shape, (2, 3, 1)) + + def test_n_features_determination_2d(self): + """Test feature count determination for 2D tensors.""" + inputs = {"x": torch.randn(4, 5)} + embeddings = {"x": torch.randn(4, 5, 8)} + + n_features = ShapExplainer._determine_n_features("x", inputs, embeddings) + self.assertEqual(n_features, 5) + + def test_n_features_determination_3d(self): + """Test feature count determination for 3D tensors.""" + inputs = {"x": torch.randn(2, 6, 4)} + embeddings = {"x": torch.randn(2, 6, 4, 16)} + + n_features = ShapExplainer._determine_n_features("x", inputs, embeddings) + self.assertEqual(n_features, 6) + + def 
test_regularization_parameter(self): + """Test different regularization parameters.""" + explainer_small_reg = ShapExplainer( + self.model, + use_embeddings=False, + regularization=1e-8, + ) + self.assertEqual(explainer_small_reg.regularization, 1e-8) + + explainer_large_reg = ShapExplainer( + self.model, + use_embeddings=False, + regularization=1e-4, + ) + self.assertEqual(explainer_large_reg.regularization, 1e-4) + + def test_max_coalitions_capping(self): + """Test that coalition count is properly capped.""" + explainer = ShapExplainer( + self.model, + use_embeddings=False, + max_coalitions=50, + ) + self.assertEqual(explainer.max_coalitions, 50) + + # For 3 features, 2^3 = 8 < 50, so it should use 8 + # For 10 features, 2^10 = 1024 > 50, so it should use 50 + + +class TestShapExplainerStateManagement(unittest.TestCase): + """Test state management and repeated calls.""" + + def setUp(self): + self.model = _SimpleShapModel() + self.model.eval() + self.labels = torch.zeros((1, 1)) + self.explainer = ShapExplainer( + self.model, + use_embeddings=False, + n_background_samples=20, + max_coalitions=30, + ) + + def test_repeated_calls_consistency(self): + """Test that repeated calls with same input produce similar results.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + # Set random seed for reproducibility + torch.manual_seed(42) + attr_1 = self.explainer.attribute(x=inputs, y=self.labels) + + torch.manual_seed(42) + attr_2 = self.explainer.attribute(x=inputs, y=self.labels) + + # Results should be very similar (allowing for minor numerical differences) + torch.testing.assert_close(attr_1["x"], attr_2["x"], atol=1e-4, rtol=1e-3) + + def test_different_inputs_different_outputs(self): + """Test that different inputs produce different attributions.""" + input_1 = torch.tensor([[1.0, 0.5, -0.3]]) + input_2 = torch.tensor([[0.5, 1.0, 0.2]]) + + attr_1 = self.explainer.attribute(x=input_1, y=self.labels) + attr_2 = self.explainer.attribute(x=input_2, y=self.labels) + + # 
Attributions should be different + self.assertFalse(torch.allclose(attr_1["x"], attr_2["x"], atol=0.01)) + + def test_model_eval_mode_preserved(self): + """Test that model stays in eval mode after attribution.""" + self.model.eval() + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + self.explainer.attribute(x=inputs, y=self.labels) + + # Model should still be in eval mode + self.assertFalse(self.model.training) + + def test_gradient_cleanup(self): + """Test that gradients are properly cleaned up.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + # Ensure inputs don't require gradients + self.assertFalse(inputs.requires_grad) + + attributions = self.explainer.attribute(x=inputs, y=self.labels) + + # Attributions should not require gradients + self.assertFalse(attributions["x"].requires_grad) + + +class TestShapExplainerDeviceHandling(unittest.TestCase): + """Test device handling (CPU/CUDA compatibility).""" + + def setUp(self): + self.model = _SimpleShapModel() + self.model.eval() + self.labels = torch.zeros((1, 1)) + + def test_cpu_device(self): + """Test SHAP computation on CPU.""" + self.model.to("cpu") + explainer = ShapExplainer( + self.model, + use_embeddings=False, + n_background_samples=10, + max_coalitions=20, + ) + + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + attributions = explainer.attribute(x=inputs, y=self.labels) + + self.assertEqual(attributions["x"].device.type, "cpu") + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_cuda_device(self): + """Test SHAP computation on CUDA.""" + self.model.to("cuda") + explainer = ShapExplainer( + self.model, + use_embeddings=False, + n_background_samples=10, + max_coalitions=20, + ) + + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + attributions = explainer.attribute(x=inputs, y=self.labels) + + self.assertEqual(attributions["x"].device.type, "cuda") + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_mixed_device_handling(self): + """Test that inputs 
are moved to model device.""" + self.model.to("cuda") + explainer = ShapExplainer( + self.model, + use_embeddings=False, + n_background_samples=10, + max_coalitions=20, + ) + + # Inputs on CPU + inputs = torch.tensor([[1.0, 0.5, -0.3]]) # CPU + self.assertEqual(inputs.device.type, "cpu") + + # Should still work (inputs moved to CUDA internally) + attributions = explainer.attribute(x=inputs, y=self.labels) + + # Output should be on CUDA + self.assertEqual(attributions["x"].device.type, "cuda") + + +class TestShapExplainerDocumentation(unittest.TestCase): + """Test that docstrings and examples are accurate.""" + + def test_docstring_exists(self): + """Test that main class has docstring.""" + self.assertIsNotNone(ShapExplainer.__doc__) + self.assertGreater(len(ShapExplainer.__doc__), 100) + + def test_init_docstring_exists(self): + """Test that __init__ has docstring.""" + self.assertIsNotNone(ShapExplainer.__init__.__doc__) + + def test_attribute_docstring_exists(self): + """Test that attribute method has docstring.""" + self.assertIsNotNone(ShapExplainer.attribute.__doc__) + + def test_public_methods_have_docstrings(self): + """Test that all public methods have docstrings.""" + public_methods = [ + method for method in dir(ShapExplainer) + if not method.startswith('_') and callable(getattr(ShapExplainer, method)) + ] + + for method_name in public_methods: + method = getattr(ShapExplainer, method_name) + if method_name not in ['train', 'eval', 'parameters']: # Inherited methods + self.assertIsNotNone( + method.__doc__, + f"Method {method_name} missing docstring" + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file