-
Notifications
You must be signed in to change notification settings - Fork 472
[Memory] Better Memory Utilization for PyHealth #622
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 50 commits
b20e639
0921380
61d9c08
05eb5a1
4ce8613
b69bcbe
e7c7964
c7b8092
01a0048
eaa6ba5
2593680
006c3ae
6d38147
f41c4d5
9195229
cf23a92
5d440c7
d834460
2a0d7d9
c769a18
90dee13
d2faab9
c1d9117
9905ddb
91143b6
4091f7f
5a9424f
1a0b6a0
f7ea645
a71218e
4c6aec8
741f9a6
525c121
791bbda
b51cc26
54e2018
4424359
c14af09
86d04c6
c82c6be
23084c9
1f1e2cc
cc42dd3
cd9a79f
d7a1fcf
2d47cc3
d4f3216
a93d1a6
4dbc4f4
66e9e64
790cb42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,32 +1,105 @@ | ||
| # %% | ||
| import psutil, os, time, threading | ||
| PEAK_MEM_USAGE = 0 | ||
| SELF_PROC = psutil.Process(os.getpid()) | ||
| import torch | ||
| from pyhealth.data.data import Patient | ||
| from pyhealth.datasets import ( | ||
| MIMIC4Dataset, | ||
| get_dataloader, | ||
| split_by_patient, | ||
| ) | ||
| from pyhealth.models import StageNet | ||
| from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 | ||
| from pyhealth.trainer import Trainer | ||
| import dask.config | ||
|
|
||
| def track_mem(): | ||
| global PEAK_MEM_USAGE | ||
| while True: | ||
| m = SELF_PROC.memory_info().rss | ||
| if m > PEAK_MEM_USAGE: | ||
| PEAK_MEM_USAGE = m | ||
| time.sleep(0.1) | ||
| import warnings | ||
| warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API") | ||
|
|
||
| threading.Thread(target=track_mem, daemon=True).start() | ||
| print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB") | ||
| if __name__ == "__main__": | ||
| dask.config.set({"temporary-directory": "/mnt/tmpfs/"}) | ||
|
|
||
| base_dataset = MIMIC4Dataset( | ||
| ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", | ||
| ehr_tables=[ | ||
| "patients", | ||
| "admissions", | ||
| "diagnoses_icd", | ||
| "procedures_icd", | ||
| "labevents", | ||
| ], | ||
| num_workers=2, | ||
| mem_per_worker="8GB", | ||
| dev=True | ||
| ) | ||
|
|
||
| # %% | ||
| from pyhealth.datasets import MIMIC4Dataset | ||
| DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1" | ||
| dataset = MIMIC4Dataset( | ||
| ehr_root=DATASET_DIR, | ||
| ehr_tables=[ | ||
| "patients", | ||
| "admissions", | ||
| "diagnoses_icd", | ||
| "procedures_icd", | ||
| "prescriptions", | ||
| "labevents", | ||
| ], | ||
| ) | ||
| print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB") | ||
| # %% | ||
| print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") | ||
|
|
||
| # STEP 2: Apply StageNet mortality prediction task | ||
| sample_dataset = base_dataset.set_task( | ||
| MortalityPredictionStageNetMIMIC4(), | ||
| num_workers=2, | ||
| ) | ||
| print(f"Total samples: {len(sample_dataset)}") | ||
| print(f"Input schema: {sample_dataset.input_schema}") | ||
| print(f"Output schema: {sample_dataset.output_schema}") | ||
|
|
||
| # Inspect a sample | ||
| sample = next(iter(sample_dataset)) | ||
| print("\nSample structure:") | ||
| print(f" Patient ID: {sample['patient_id']}") | ||
| print(f"ICD Codes: {sample['icd_codes']}") | ||
| print(f" Labs shape: {len(sample['labs'][0])} timesteps") | ||
| print(f" Mortality: {sample['mortality']}") | ||
|
|
||
| # STEP 3: Split dataset | ||
| train_dataset, val_dataset, test_dataset = split_by_patient( | ||
| sample_dataset, [0.8, 0.1, 0.1] | ||
| ) | ||
|
|
||
| # Create dataloaders | ||
| train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True) | ||
| val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False) | ||
| test_loader = get_dataloader(test_dataset, batch_size=32, shuffle=False) | ||
|
|
||
| # STEP 4: Initialize StageNet model | ||
| model = StageNet( | ||
| dataset=sample_dataset, | ||
| embedding_dim=128, | ||
| chunk_size=128, | ||
| levels=3, | ||
| dropout=0.3, | ||
| ) | ||
|
|
||
| num_params = sum(p.numel() for p in model.parameters()) | ||
| print(f"\nModel initialized with {num_params} parameters") | ||
|
|
||
| # STEP 5: Train the model | ||
| trainer = Trainer( | ||
| model=model, | ||
| device="cpu", # or "cpu" | ||
| metrics=["pr_auc", "roc_auc", "accuracy", "f1"], | ||
| enable_logging=False, | ||
| ) | ||
| print("\nStarting training...") | ||
|
|
||
| trainer.train( | ||
| train_dataloader=train_loader, | ||
| val_dataloader=val_loader, | ||
| epochs=5, | ||
| monitor="roc_auc", | ||
| optimizer_params={"lr": 1e-5}, | ||
| ) | ||
|
|
||
| # STEP 6: Evaluate on test set | ||
| results = trainer.evaluate(test_loader) | ||
| print("\nTest Results:") | ||
| for metric, value in results.items(): | ||
| print(f" {metric}: {value:.4f}") | ||
|
|
||
| # STEP 7: Inspect model predictions | ||
| sample_batch = next(iter(test_loader)) | ||
| with torch.no_grad(): | ||
| output = model(**sample_batch) | ||
|
|
||
| print("\nSample predictions:") | ||
| print(f" Predicted probabilities: {output['y_prob'][:5]}") | ||
| print(f" True labels: {output['y_true'][:5]}") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,10 +2,10 @@ | |
| from dataclasses import dataclass, field | ||
| from datetime import datetime | ||
| from functools import reduce | ||
| from typing import Dict, List, Mapping, Optional, Union | ||
| from typing import Dict, List, Mapping, Optional, Union, Any | ||
|
|
||
| import numpy as np | ||
| import polars as pl | ||
| import pandas as pd | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
|
|
@@ -20,9 +20,9 @@ class Event: | |
|
|
||
| event_type: str | ||
| timestamp: datetime | ||
| attr_dict: Mapping[str, any] = field(default_factory=dict) | ||
| attr_dict: Mapping[str, Any] = field(default_factory=dict) | ||
|
|
||
| def __init__(self, event_type: str, timestamp: datetime = None, **kwargs): | ||
| def __init__(self, event_type: str, timestamp: datetime | None = None, **kwargs): | ||
| """Initialize an Event instance. | ||
|
|
||
| Args: | ||
|
|
@@ -50,7 +50,7 @@ def __init__(self, event_type: str, timestamp: datetime = None, **kwargs): | |
| object.__setattr__(self, "attr_dict", attr_dict) | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, d: Dict[str, any]) -> "Event": | ||
| def from_dict(cls, d: Dict[str, Any]) -> "Event": | ||
| """Create an Event instance from a dictionary. | ||
|
|
||
| Args: | ||
|
|
@@ -61,12 +61,12 @@ def from_dict(cls, d: Dict[str, any]) -> "Event": | |
| """ | ||
| timestamp: datetime = d["timestamp"] | ||
| event_type: str = d["event_type"] | ||
| attr_dict: Dict[str, any] = { | ||
| attr_dict: Dict[str, Any] = { | ||
| k.split("/", 1)[1]: v for k, v in d.items() if k.split("/")[0] == event_type | ||
| } | ||
| return cls(event_type=event_type, timestamp=timestamp, attr_dict=attr_dict) | ||
|
|
||
| def __getitem__(self, key: str) -> any: | ||
| def __getitem__(self, key: str) -> Any: | ||
| """Get an attribute by key. | ||
|
|
||
| Args: | ||
|
|
@@ -95,7 +95,7 @@ def __contains__(self, key: str) -> bool: | |
| return True | ||
| return key in self.attr_dict | ||
|
|
||
| def __getattr__(self, key: str) -> any: | ||
| def __getattr__(self, key: str) -> Any: | ||
| """Get an attribute using dot notation. | ||
|
|
||
| Args: | ||
|
|
@@ -119,65 +119,76 @@ class Patient: | |
|
|
||
| Attributes: | ||
| patient_id (str): Unique patient identifier. | ||
| data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp. | ||
| event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. | ||
| data_source (pd.DataFrame): DataFrame containing all events, sorted by timestamp. | ||
| event_type_partitions (Dict[str, pd.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. | ||
| """ | ||
|
|
||
| def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None: | ||
| def __init__(self, patient_id: str, data_source: pd.DataFrame) -> None: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Polars will spawn processes. Given that this will be run inside litdata.optimize in a multi-process environment, nested multiprocessing will cause a hang. Thus, we change this to pandas. |
||
| """ | ||
| Initialize a Patient instance. | ||
|
|
||
| Args: | ||
| patient_id (str): Unique patient identifier. | ||
| data_source (pl.DataFrame): DataFrame containing all events. | ||
| data_source (pd.DataFrame): DataFrame containing all events. | ||
| """ | ||
| self.patient_id = patient_id | ||
| self.data_source = data_source.sort("timestamp") | ||
| self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True) | ||
| self.data_source = data_source.sort_values("timestamp") | ||
| self.event_type_partitions = { | ||
| (event_type,): group.copy() | ||
| for event_type, group in self.data_source.groupby("event_type", sort=False) | ||
| } | ||
|
|
||
| def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: | ||
| def _filter_by_time_range_regular(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame: | ||
| """Regular filtering by time. Time complexity: O(n).""" | ||
| if start is not None: | ||
| df = df.filter(pl.col("timestamp") >= start) | ||
| df = df[df["timestamp"] >= start] | ||
| if end is not None: | ||
| df = df.filter(pl.col("timestamp") <= end) | ||
| df = df[df["timestamp"] <= end] | ||
| return df | ||
|
|
||
| def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: | ||
| def _filter_by_time_range_fast(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame: | ||
| """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n).""" | ||
| if start is None and end is None: | ||
| return df | ||
| df = df.filter(pl.col("timestamp").is_not_null()) | ||
| ts_col = df["timestamp"].to_numpy() | ||
| df = df[df["timestamp"].notna()].sort_values("timestamp") | ||
| ts_col = pd.to_datetime(df["timestamp"]).to_numpy(dtype="datetime64[ns]") | ||
| start_idx = 0 | ||
| end_idx = len(ts_col) | ||
| if start is not None: | ||
| start_idx = np.searchsorted(ts_col, start, side="left") | ||
| start_idx = np.searchsorted(ts_col, np.datetime64(start, "ns"), side="left") | ||
| if end is not None: | ||
| end_idx = np.searchsorted(ts_col, end, side="right") | ||
| return df.slice(start_idx, end_idx - start_idx) | ||
| end_idx = np.searchsorted(ts_col, np.datetime64(end, "ns"), side="right") | ||
| return df.iloc[start_idx:end_idx] | ||
|
|
||
| def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: | ||
| def _filter_by_event_type_regular(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame: | ||
| """Regular filtering by event type. Time complexity: O(n).""" | ||
| if event_type: | ||
| df = df.filter(pl.col("event_type") == event_type) | ||
| df = df[df["event_type"] == event_type] | ||
| return df | ||
|
|
||
| def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: | ||
| def _filter_by_event_type_fast(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame: | ||
| """Fast filtering by event type using pre-built event type index. Time complexity: O(1).""" | ||
| if event_type: | ||
| return self.event_type_partitions.get((event_type,), df[:0]) | ||
| return self.event_type_partitions.get((event_type,), df.iloc[0:0]) | ||
| else: | ||
| return df | ||
|
|
||
| def get_events_py(self, **kawargs) -> List[Event]: | ||
| """Type-safe wrapper for get_events.""" | ||
| return self.get_events(**kawargs, return_df=False) # type: ignore | ||
|
|
||
| def get_events_df(self, **kawargs) -> pd.DataFrame: | ||
| """DataFrame wrapper for get_events.""" | ||
| return self.get_events(**kawargs, return_df=True) # type: ignore | ||
|
|
||
|
Comment on lines
172
to
+183
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added for type-hinting purposes. |
||
| def get_events( | ||
| self, | ||
| event_type: Optional[str] = None, | ||
| start: Optional[datetime] = None, | ||
| end: Optional[datetime] = None, | ||
| filters: Optional[List[tuple]] = None, | ||
| return_df: bool = False, | ||
| ) -> Union[pl.DataFrame, List[Event]]: | ||
| ) -> Union[pd.DataFrame, List[Event]]: | ||
| """Get events with optional type and time filters. | ||
|
|
||
| Args: | ||
|
|
@@ -191,7 +202,7 @@ def get_events( | |
| and time filters. The logic is "AND" between different filters. | ||
|
|
||
| Returns: | ||
| Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame | ||
| Union[pd.DataFrame, List[Event]]: Filtered events as a DataFrame | ||
| or a list of Event objects. | ||
| """ | ||
| # faster filtering (by default) | ||
|
|
@@ -213,7 +224,7 @@ def get_events( | |
| f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" | ||
| ) | ||
| attr, op, val = filt | ||
| col_expr = pl.col(f"{event_type}/{attr}") | ||
| col_expr = df[f"{event_type}/{attr}"] | ||
| # Build operator expression | ||
| if op == "==": | ||
| exprs.append(col_expr == val) | ||
|
|
@@ -230,7 +241,7 @@ def get_events( | |
| else: | ||
| raise ValueError(f"Unsupported operator: {op} in filter {filt}") | ||
| if exprs: | ||
| df = df.filter(reduce(operator.and_, exprs)) | ||
| df = df.loc[reduce(operator.and_, exprs)] | ||
| if return_df: | ||
| return df | ||
| return [Event.from_dict(d) for d in df.to_dicts()] | ||
| return [Event.from_dict(d) for d in df.to_dict(orient="records")] # type: ignore | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small type hint fix here:
`any` is a built-in function for iterables, and `Any` is the correct type hint.