Skip to content

Commit 7dfe6e4

Browse files
authored
Add multimodal RNN support (#797)
* Add multimodal RNN support
* Add @todos for later
1 parent 6053b55 commit 7dfe6e4

5 files changed

Lines changed: 692 additions & 3 deletions

File tree

docs/api/models/pyhealth.models.RNN.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ The separate callable RNNLayer and the complete RNN model.
1010
:show-inheritance:
1111

1212
.. autoclass:: pyhealth.models.RNN
13+
:members:
14+
:undoc-members:
15+
:show-inheritance:
16+
17+
.. autoclass:: pyhealth.models.MultimodalRNN
1318
:members:
1419
:undoc-members:
1520
:show-inheritance:
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
"""
Mortality Prediction on MIMIC-IV with MultimodalRNN

This example demonstrates how to use the MultimodalRNN model with mixed
input modalities for in-hospital mortality prediction on MIMIC-IV.

The MultimodalRNN model can handle:
- Sequential features (diagnoses, procedures, lab timeseries) → RNN processing
- Non-sequential features (demographics, static measurements) → Direct embedding

This example shows:
1. Loading MIMIC-IV data with mixed feature types
2. Applying a mortality prediction task
3. Training a MultimodalRNN model with both sequential and non-sequential inputs
4. Evaluating the model performance

The script trains on the first CUDA device when one is available and falls
back to the CPU otherwise.
"""

import torch

from pyhealth.datasets import MIMIC4Dataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import MultimodalRNN
from pyhealth.tasks import InHospitalMortalityMIMIC4
from pyhealth.trainer import Trainer


def _print_banner(title, leading_blank_line=True):
    """Print ``title`` framed by two 60-character ``=`` rules.

    Args:
        title: Text printed between the two rules.
        leading_blank_line: When True, emit an empty line before the first
            rule so consecutive sections are visually separated.
    """
    rule = "=" * 60
    print("\n" + rule if leading_blank_line else rule)
    print(title)
    print(rule)


def main():
    """Run the full example: load data, train MultimodalRNN, and evaluate."""
    # STEP 1: Load MIMIC-IV base dataset
    _print_banner("STEP 1: Loading MIMIC-IV Dataset", leading_blank_line=False)

    # NOTE: ehr_root is a site-specific path; point it at your local copy
    # of MIMIC-IV before running.
    base_dataset = MIMIC4Dataset(
        ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
        ehr_tables=["diagnoses_icd", "procedures_icd", "labevents"],
        dev=True,  # Use development mode for faster testing
        num_workers=4,
    )
    base_dataset.stats()

    # STEP 2: Apply mortality prediction task with multimodal features
    _print_banner("STEP 2: Setting Mortality Prediction Task")

    # Use the InHospitalMortalityMIMIC4 task
    # This task will create sequential features from diagnoses, procedures, and labs
    task = InHospitalMortalityMIMIC4()
    sample_dataset = base_dataset.set_task(
        task,
        num_workers=4,
    )

    print(f"\nTotal samples: {len(sample_dataset)}")
    print(f"Input schema: {sample_dataset.input_schema}")
    print(f"Output schema: {sample_dataset.output_schema}")

    # Inspect a sample (skipped when the task produced no samples)
    if len(sample_dataset) > 0:
        sample = sample_dataset[0]
        print("\nSample structure:")
        print(f" Patient ID: {sample['patient_id']}")
        for key in sample_dataset.input_schema:
            if key in sample:
                if isinstance(sample[key], (list, tuple)):
                    print(f" {key}: length {len(sample[key])}")
                else:
                    print(f" {key}: {type(sample[key])}")
        print(f" Mortality: {sample.get('mortality', 'N/A')}")

    # STEP 3: Split dataset
    _print_banner("STEP 3: Splitting Dataset")

    train_dataset, val_dataset, test_dataset = split_by_patient(
        sample_dataset, [0.8, 0.1, 0.1]
    )

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Create dataloaders; only the training loader is shuffled
    train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)
    val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)
    test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)

    # STEP 4: Initialize MultimodalRNN model
    _print_banner("STEP 4: Initializing MultimodalRNN Model")

    model = MultimodalRNN(
        dataset=sample_dataset,
        embedding_dim=128,
        hidden_dim=128,
        rnn_type="GRU",
        num_layers=2,
        dropout=0.3,
        bidirectional=False,
    )

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model initialized with {num_params:,} parameters")

    # Print feature classification
    print(f"\nSequential features (RNN processing): {model.sequential_features}")
    print(f"Non-sequential features (direct embedding): {model.non_sequential_features}")

    # Calculate expected embedding dimensions
    # NOTE(review): assumes each sequential feature contributes hidden_dim and
    # each non-sequential feature contributes embedding_dim -- confirm against
    # the MultimodalRNN implementation.
    seq_dim = len(model.sequential_features) * model.hidden_dim
    non_seq_dim = len(model.non_sequential_features) * model.embedding_dim
    total_dim = seq_dim + non_seq_dim
    print("\nPatient representation dimension:")
    print(f" Sequential contribution: {seq_dim}")
    print(f" Non-sequential contribution: {non_seq_dim}")
    print(f" Total: {total_dim}")

    # STEP 5: Train the model
    _print_banner("STEP 5: Training Model")

    # Use the first GPU when one is present; otherwise run on the CPU so the
    # example does not crash on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    trainer = Trainer(
        model=model,
        device=device,
        metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
    )

    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=10,
        monitor="roc_auc",
        optimizer_params={"lr": 1e-3},
    )

    # STEP 6: Evaluate on test set
    _print_banner("STEP 6: Evaluating on Test Set")

    results = trainer.evaluate(test_loader)
    print("\nTest Results:")
    for metric, value in results.items():
        print(f" {metric}: {value:.4f}")

    # STEP 7: Demonstrate model predictions
    _print_banner("STEP 7: Sample Predictions")

    # Make sure dropout is disabled for this ad-hoc forward pass, independent
    # of whatever mode the trainer left the model in.
    model.eval()
    sample_batch = next(iter(test_loader))
    with torch.no_grad():
        output = model(**sample_batch)

    print(f"\nBatch size: {output['y_prob'].shape[0]}")
    print("First 10 predicted probabilities:")
    for i, (prob, true_label) in enumerate(
        zip(output["y_prob"][:10], output["y_true"][:10]), start=1
    ):
        print(f" Sample {i}: prob={prob.item():.4f}, true={int(true_label.item())}")

    # Summary
    _print_banner("SUMMARY: MultimodalRNN Training Complete")
    print("Model: MultimodalRNN")
    print("Dataset: MIMIC-IV")
    print("Task: In-Hospital Mortality Prediction")
    print(f"Sequential features: {len(model.sequential_features)}")
    print(f"Non-sequential features: {len(model.non_sequential_features)}")
    # `results` comes from trainer.evaluate(test_loader), so this is the
    # test-set score, not a validation score.
    print(f"Test ROC-AUC: {results.get('roc_auc', 0):.4f}")
    print("=" * 60)


if __name__ == "__main__":
    main()

pyhealth/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .mlp import MLP
1818
from .molerec import MoleRec, MoleRecLayer
1919
from .retain import RETAIN, RETAINLayer
20-
from .rnn import RNN, RNNLayer
20+
from .rnn import MultimodalRNN, RNN, RNNLayer
2121
from .safedrug import SafeDrug, SafeDrugLayer
2222
from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer
2323
from .stagenet import StageNet, StageNetLayer

0 commit comments

Comments
 (0)