Skip to content
Open
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
b20e639
Add dask dependency for low memory data processing
LogicFan Nov 21, 2025
0921380
Add dataset cache_dir
LogicFan Nov 21, 2025
61d9c08
Fix typeing
LogicFan Nov 21, 2025
05eb5a1
Convert table csv file to parquet file
LogicFan Nov 21, 2025
4ce8613
Add TODO
LogicFan Nov 21, 2025
b69bcbe
Change load_data to dd.DataFrame
LogicFan Nov 21, 2025
e7c7964
Fix mimic4 for dask
LogicFan Nov 21, 2025
c7b8092
enable collected_global_event_df for Dask
LogicFan Nov 21, 2025
01a0048
Fix unique_patient_ids, stats for Dask
LogicFan Nov 21, 2025
eaa6ba5
Initial Attempt for Patient with Dask dataframe
LogicFan Nov 21, 2025
2593680
Fix patient
LogicFan Nov 21, 2025
006c3ae
Support get_patient for dask
LogicFan Nov 21, 2025
6d38147
Support iter_patients for Dask
LogicFan Nov 21, 2025
f41c4d5
Add overload type hint for Patient
LogicFan Nov 21, 2025
9195229
Update pre_filter signature to Dask
LogicFan Nov 21, 2025
cf23a92
Fix type hint
LogicFan Nov 21, 2025
5d440c7
Chage lab_df to be dask compitable.
LogicFan Nov 21, 2025
d834460
Fix schema inference on csv reader
LogicFan Nov 21, 2025
2a0d7d9
Fix incorrect Dask transform
LogicFan Nov 21, 2025
c769a18
Optimize code
LogicFan Nov 21, 2025
90dee13
revert data back to polars as it it faster
LogicFan Nov 21, 2025
d2faab9
Because patient has reverted
LogicFan Nov 21, 2025
c1d9117
Revert task to use polars, as it's faster
LogicFan Nov 21, 2025
9905ddb
use pl.DataFrame in patient.
LogicFan Nov 21, 2025
91143b6
Fix type conversion issues
LogicFan Nov 21, 2025
4091f7f
Add litdata
LogicFan Nov 21, 2025
5a9424f
Fix Mimic4
LogicFan Nov 21, 2025
1a0b6a0
Works for single worker
LogicFan Nov 21, 2025
f7ea645
Change SampleDataset to IterableDataset
LogicFan Nov 21, 2025
a71218e
Distributed Progress Bar, Bucekt partition
LogicFan Nov 22, 2025
4c6aec8
Better cache system
LogicFan Nov 22, 2025
741f9a6
Better apply task
LogicFan Nov 22, 2025
525c121
Fix bug
LogicFan Nov 22, 2025
791bbda
Fixup
LogicFan Nov 22, 2025
b51cc26
Fixup
LogicFan Nov 22, 2025
54e2018
Fixup
LogicFan Nov 22, 2025
4424359
Fixup
LogicFan Nov 22, 2025
c14af09
Move actual compute ctor
LogicFan Nov 23, 2025
86d04c6
Update Patient to use pandas, because polars will have nested mp issues.
LogicFan Nov 23, 2025
c82c6be
Fix typing
LogicFan Nov 23, 2025
23084c9
Add type safe method for Patient
LogicFan Nov 23, 2025
1f1e2cc
Fix MortalityPredictionStageNetMIMIC4 to pandas
LogicFan Nov 23, 2025
cc42dd3
Fix litdata.optimize hang
LogicFan Nov 23, 2025
cd9a79f
Fix sample dataset
LogicFan Nov 23, 2025
d7a1fcf
add SampleSubset
LogicFan Nov 23, 2025
2d47cc3
Add more test cases
LogicFan Nov 23, 2025
d4f3216
Refactor SampleSubset __init__
LogicFan Nov 23, 2025
a93d1a6
support set_shuffle, add testcase
LogicFan Nov 23, 2025
4dbc4f4
Fix incorrect typehint
LogicFan Nov 23, 2025
66e9e64
Fix splitter
LogicFan Nov 23, 2025
790cb42
Fix task.pre_filter
LogicFan Nov 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 101 additions & 28 deletions examples/memtest.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,105 @@
# %%
import psutil, os, time, threading
PEAK_MEM_USAGE = 0
SELF_PROC = psutil.Process(os.getpid())
import torch
from pyhealth.data.data import Patient
from pyhealth.datasets import (
MIMIC4Dataset,
get_dataloader,
split_by_patient,
)
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
from pyhealth.trainer import Trainer
import dask.config

def track_mem():
global PEAK_MEM_USAGE
while True:
m = SELF_PROC.memory_info().rss
if m > PEAK_MEM_USAGE:
PEAK_MEM_USAGE = m
time.sleep(0.1)
import warnings
warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API")

threading.Thread(target=track_mem, daemon=True).start()
print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB")
if __name__ == "__main__":
dask.config.set({"temporary-directory": "/mnt/tmpfs/"})

base_dataset = MIMIC4Dataset(
ehr_root="/home/logic/physionet.org/files/mimiciv/3.1",
ehr_tables=[
"patients",
"admissions",
"diagnoses_icd",
"procedures_icd",
"labevents",
],
num_workers=2,
mem_per_worker="8GB",
dev=True
)

# %%
from pyhealth.datasets import MIMIC4Dataset
DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1"
dataset = MIMIC4Dataset(
ehr_root=DATASET_DIR,
ehr_tables=[
"patients",
"admissions",
"diagnoses_icd",
"procedures_icd",
"prescriptions",
"labevents",
],
)
print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB")
# %%
print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...")

# STEP 2: Apply StageNet mortality prediction task
sample_dataset = base_dataset.set_task(
MortalityPredictionStageNetMIMIC4(),
num_workers=2,
)
print(f"Total samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Inspect a sample
sample = next(iter(sample_dataset))
print("\nSample structure:")
print(f" Patient ID: {sample['patient_id']}")
print(f"ICD Codes: {sample['icd_codes']}")
print(f" Labs shape: {len(sample['labs'][0])} timesteps")
print(f" Mortality: {sample['mortality']}")

# STEP 3: Split dataset
train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
)

# Create dataloaders
train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

# STEP 4: Initialize StageNet model
model = StageNet(
dataset=sample_dataset,
embedding_dim=128,
chunk_size=128,
levels=3,
dropout=0.3,
)

num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel initialized with {num_params} parameters")

# STEP 5: Train the model
trainer = Trainer(
model=model,
device="cpu", # or "cpu"
metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
enable_logging=False,
)
print("\nStarting training...")

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=5,
monitor="roc_auc",
optimizer_params={"lr": 1e-5},
)

# STEP 6: Evaluate on test set
results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
print(f" {metric}: {value:.4f}")

# STEP 7: Inspect model predictions
sample_batch = next(iter(test_loader))
with torch.no_grad():
output = model(**sample_batch)

print("\nSample predictions:")
print(f" Predicted probabilities: {output['y_prob'][:5]}")
print(f" True labels: {output['y_true'][:5]}")
75 changes: 43 additions & 32 deletions pyhealth/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from dataclasses import dataclass, field
from datetime import datetime
from functools import reduce
from typing import Dict, List, Mapping, Optional, Union
from typing import Dict, List, Mapping, Optional, Union, Any

import numpy as np
import polars as pl
import pandas as pd


@dataclass(frozen=True)
Expand All @@ -20,9 +20,9 @@ class Event:

event_type: str
timestamp: datetime
attr_dict: Mapping[str, any] = field(default_factory=dict)
attr_dict: Mapping[str, Any] = field(default_factory=dict)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small type hint fix here, any is a function for iterables, and Any is the correct type hint.


def __init__(self, event_type: str, timestamp: datetime = None, **kwargs):
def __init__(self, event_type: str, timestamp: datetime | None = None, **kwargs):
"""Initialize an Event instance.

Args:
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, event_type: str, timestamp: datetime = None, **kwargs):
object.__setattr__(self, "attr_dict", attr_dict)

@classmethod
def from_dict(cls, d: Dict[str, any]) -> "Event":
def from_dict(cls, d: Dict[str, Any]) -> "Event":
"""Create an Event instance from a dictionary.

Args:
Expand All @@ -61,12 +61,12 @@ def from_dict(cls, d: Dict[str, any]) -> "Event":
"""
timestamp: datetime = d["timestamp"]
event_type: str = d["event_type"]
attr_dict: Dict[str, any] = {
attr_dict: Dict[str, Any] = {
k.split("/", 1)[1]: v for k, v in d.items() if k.split("/")[0] == event_type
}
return cls(event_type=event_type, timestamp=timestamp, attr_dict=attr_dict)

def __getitem__(self, key: str) -> any:
def __getitem__(self, key: str) -> Any:
"""Get an attribute by key.

Args:
Expand Down Expand Up @@ -95,7 +95,7 @@ def __contains__(self, key: str) -> bool:
return True
return key in self.attr_dict

def __getattr__(self, key: str) -> any:
def __getattr__(self, key: str) -> Any:
"""Get an attribute using dot notation.

Args:
Expand All @@ -119,65 +119,76 @@ class Patient:

Attributes:
patient_id (str): Unique patient identifier.
data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp.
event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions.
data_source (pd.DataFrame): DataFrame containing all events, sorted by timestamp.
event_type_partitions (Dict[str, pd.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions.
"""

def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None:
def __init__(self, patient_id: str, data_source: pd.DataFrame) -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The polars will spawn processes. Given this will be ran in litdata.optimize in a multi-processes environment, nested multi-process will cause a hang. Thus, we can this to pandas.

"""
Initialize a Patient instance.

Args:
patient_id (str): Unique patient identifier.
data_source (pl.DataFrame): DataFrame containing all events.
data_source (pd.DataFrame): DataFrame containing all events.
"""
self.patient_id = patient_id
self.data_source = data_source.sort("timestamp")
self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True)
self.data_source = data_source.sort_values("timestamp")
self.event_type_partitions = {
(event_type,): group.copy()
for event_type, group in self.data_source.groupby("event_type", sort=False)
}

def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame:
def _filter_by_time_range_regular(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame:
"""Regular filtering by time. Time complexity: O(n)."""
if start is not None:
df = df.filter(pl.col("timestamp") >= start)
df = df[df["timestamp"] >= start]
if end is not None:
df = df.filter(pl.col("timestamp") <= end)
df = df[df["timestamp"] <= end]
return df

def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame:
def _filter_by_time_range_fast(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame:
"""Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n)."""
if start is None and end is None:
return df
df = df.filter(pl.col("timestamp").is_not_null())
ts_col = df["timestamp"].to_numpy()
df = df[df["timestamp"].notna()].sort_values("timestamp")
ts_col = pd.to_datetime(df["timestamp"]).to_numpy(dtype="datetime64[ns]")
start_idx = 0
end_idx = len(ts_col)
if start is not None:
start_idx = np.searchsorted(ts_col, start, side="left")
start_idx = np.searchsorted(ts_col, np.datetime64(start, "ns"), side="left")
if end is not None:
end_idx = np.searchsorted(ts_col, end, side="right")
return df.slice(start_idx, end_idx - start_idx)
end_idx = np.searchsorted(ts_col, np.datetime64(end, "ns"), side="right")
return df.iloc[start_idx:end_idx]

def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame:
def _filter_by_event_type_regular(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame:
"""Regular filtering by event type. Time complexity: O(n)."""
if event_type:
df = df.filter(pl.col("event_type") == event_type)
df = df[df["event_type"] == event_type]
return df

def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame:
def _filter_by_event_type_fast(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame:
"""Fast filtering by event type using pre-built event type index. Time complexity: O(1)."""
if event_type:
return self.event_type_partitions.get((event_type,), df[:0])
return self.event_type_partitions.get((event_type,), df.iloc[0:0])
else:
return df

def get_events_py(self, **kawargs) -> List[Event]:
"""Type-safe wrapper for get_events."""
return self.get_events(**kawargs, return_df=False) # type: ignore

def get_events_df(self, **kawargs) -> pd.DataFrame:
"""DataFrame wrapper for get_events."""
return self.get_events(**kawargs, return_df=True) # type: ignore

Comment on lines 172 to +183
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added for type hint purpose.

def get_events(
self,
event_type: Optional[str] = None,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
filters: Optional[List[tuple]] = None,
return_df: bool = False,
) -> Union[pl.DataFrame, List[Event]]:
) -> Union[pd.DataFrame, List[Event]]:
"""Get events with optional type and time filters.

Args:
Expand All @@ -191,7 +202,7 @@ def get_events(
and time filters. The logic is "AND" between different filters.

Returns:
Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame
Union[pd.DataFrame, List[Event]]: Filtered events as a DataFrame
or a list of Event objects.
"""
# faster filtering (by default)
Expand All @@ -213,7 +224,7 @@ def get_events(
f"Invalid filter format: {filt} (must be tuple of (attr, op, value))"
)
attr, op, val = filt
col_expr = pl.col(f"{event_type}/{attr}")
col_expr = df[f"{event_type}/{attr}"]
# Build operator expression
if op == "==":
exprs.append(col_expr == val)
Expand All @@ -230,7 +241,7 @@ def get_events(
else:
raise ValueError(f"Unsupported operator: {op} in filter {filt}")
if exprs:
df = df.filter(reduce(operator.and_, exprs))
df = df.loc[reduce(operator.and_, exprs)]
if return_df:
return df
return [Event.from_dict(d) for d in df.to_dicts()]
return [Event.from_dict(d) for d in df.to_dict(orient="records")] # type: ignore
Loading
Loading