From b20e639140625a1893869012bd22b00fdbc7c6be Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 02:12:29 -0500 Subject: [PATCH 01/51] Add dask dependency for low memory data processing Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ceedcdd0b..b0501bc4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", + "dask[dataframe,distributed]~=2025.11.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From 0921380fc0a1ba5d28c4308d17e6555af3e200a0 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 03:11:56 -0500 Subject: [PATCH 02/51] Add dataset cache_dir Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 45 ++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 3390453ff..927372ffd 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -6,10 +6,13 @@ from pathlib import Path from typing import Dict, Iterator, List, Optional from urllib.parse import urlparse, urlunparse +import uuid +import json import polars as pl import requests from tqdm import tqdm +import platformdirs from ..data import Patient from ..tasks import BaseTask @@ -108,30 +111,45 @@ class BaseDataset(ABC): def __init__( self, - root: str, + root: str | Path, tables: List[str], - dataset_name: Optional[str] = None, - config_path: Optional[str] = None, + dataset_name: str | None = None, + config_path: str | None = None, + cache_dir: str | Path | None = None, dev: bool = False, ): """Initializes the BaseDataset. Args: - root (str): The root directory where dataset files are stored. + root (str | Path): The root directory where dataset files are stored. tables (List[str]): List of table names to load. - dataset_name (Optional[str]): Name of the dataset. Defaults to class name. - config_path (Optional[str]): Path to the configuration YAML file. + dataset_name (str | None): Name of the dataset. Defaults to class name. + config_path (str | None): Path to the configuration YAML file. + cache_dir (str | Path | None): Directory to cache processed data. If None, a default + cache directory will be created under the platform's cache directory. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ + if config_path is None: + raise ValueError("config_path must be provided") + if len(set(tables)) != len(tables): logger.warning("Duplicate table names in tables list. Removing duplicates.") tables = list(set(tables)) + self.root = root self.tables = tables self.dataset_name = dataset_name or self.__class__.__name__ self.config = load_yaml_config(config_path) self.dev = dev + if cache_dir is None: + cache_dir = platformdirs.user_cache_dir(appname='pyhealth') + logger.info(f"No cache_dir provided. Using default cache for PyHealth: {cache_dir}") + cache_dir = Path(cache_dir) + self.cache_dir = cache_dir / self.uuid() + self.cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}") + logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" ) @@ -142,6 +160,21 @@ def __init__( self._collected_global_event_df = None self._unique_patient_ids = None + def uuid(self) -> str: + """Generates a unique identifier for the dataset instance. This is used for creating + cache directories. The UUID is based on the root path, tables, dataset name, and dev mode. + + Returns: + str: A unique identifier string. + """ + id_str = json.dumps({ + "root": str(self.root), + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, sort_keys=True) + return str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + @property def collected_global_event_df(self) -> pl.DataFrame: """Collects and returns the global event data frame. From 61d9c084045e6909ec83dabc232d9ab19b2b3fc7 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 03:15:45 -0500 Subject: [PATCH 03/51] Fix typeing Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 927372ffd..8a041f394 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -111,7 +111,7 @@ class BaseDataset(ABC): def __init__( self, - root: str | Path, + root: str, tables: List[str], dataset_name: str | None = None, config_path: str | None = None, @@ -121,7 +121,7 @@ def __init__( """Initializes the BaseDataset. Args: - root (str | Path): The root directory where dataset files are stored. + root (str): The root directory where dataset files are stored. tables (List[str]): List of table names to load. dataset_name (str | None): Name of the dataset. Defaults to class name. config_path (str | None): Path to the configuration YAML file. @@ -168,7 +168,7 @@ def uuid(self) -> str: str: A unique identifier string. """ id_str = json.dumps({ - "root": str(self.root), + "root": self.root, "tables": sorted(self.tables), "dataset_name": self.dataset_name, "dev": self.dev, From 05eb5a1339917b3712d553330b09424e7799a8d4 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 03:40:05 -0500 Subject: [PATCH 04/51] Convert table csv file to parquet file Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 70 ++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 8a041f394..d145838a6 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -10,6 +10,9 @@ import json import polars as pl +import dask.dataframe as dd +import pyarrow.csv as pv +import pyarrow.parquet as pq import requests from tqdm import tqdm import platformdirs @@ -58,6 +61,26 @@ def path_exists(path: str) -> bool: else: return Path(path).exists() +def alt_path(path: str) -> str: + """ + Get the alternative path by switching between .csv.gz, .csv, .tsv.gz, and .tsv extensions. + + Args: + path (str): Original file path. + + Returns: + str: Alternative file path. + """ + if path.endswith(".csv.gz"): + return path[:-3] # Remove .gz -> try .csv + elif path.endswith(".csv"): + return f"{path}.gz" # Add .gz -> try .csv.gz + elif path.endswith(".tsv.gz"): + return path[:-3] # Remove .gz -> try .tsv + elif path.endswith(".tsv"): + return f"{path}.gz" # Add .gz -> try .tsv.gz + else: + raise ValueError(f"Path does not have expected extension: {path}") def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: """ @@ -147,7 +170,7 @@ def __init__( logger.info(f"No cache_dir provided. Using default cache for PyHealth: {cache_dir}") cache_dir = Path(cache_dir) self.cache_dir = cache_dir / self.uuid() - self.cache_dir.mkdir(parents=True, exist_ok=True) + self.setup_cache_dir() logger.info(f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}") logger.info( @@ -175,6 +198,14 @@ def uuid(self) -> str: }, sort_keys=True) return str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) + def setup_cache_dir(self) -> None: + """Creates the cache directory structure. + """ + self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Create tables subdirectory to store cached table files + (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) + @property def collected_global_event_df(self) -> pl.DataFrame: """Collects and returns the global event data frame. @@ -252,6 +283,7 @@ def _to_lower(col_name: str) -> str: csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") + _ = self.load_csv_or_tsv(table_name, csv_path) df = scan_csv_gz_or_csv_tsv(csv_path) # Convert column names to lowercase before calling preprocess_func @@ -270,6 +302,7 @@ def _to_lower(col_name: str) -> str: other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") + _ = self.load_csv_or_tsv(table_name, csv_path) join_df = scan_csv_gz_or_csv_tsv(other_csv_path) join_df = join_df.rename(_to_lower) join_key = join_cfg.on @@ -321,6 +354,41 @@ def _to_lower(col_name: str) -> str: return event_frame + def load_csv_or_tsv(self, table_name: str, path: str) -> dd.DataFrame: + """Loads a CSV.gz, CSV, TSV.gz, or TSV file into a Dask DataFrame. + + Args: + table_name (str): The name of the table. + path (str): The URL or local path to the .csv, .csv.gz, .tsv, or .tsv.gz file. + Returns: + dd.DataFrame: The loaded Dask DataFrame. + """ + parquet_path = self.cache_dir / "tables" / f"{table_name}.parquet" + + if not path_exists(str(parquet_path)): + # convert .gz file to .parquet file since Dask cannot split on gz files directly + if not path_exists(path): + if not path_exists(alt_path(path)): + raise FileNotFoundError(f"Neither path exists: {path} or {alt_path(path)}") + path = alt_path(path) + + delimiter = '\t' if path.endswith(".tsv") or path.endswith(".tsv.gz") else ',' + csv_reader = pv.open_csv( + path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter) + ) + with pq.ParquetWriter(parquet_path, csv_reader.schema) as writer: + for batch in csv_reader: + writer.write_batch(batch) + + pass + return dd.read_parquet( + self.cache_dir / "tables" / f"{table_name}.parquet", + split_row_groups=True, # type: ignore + blocksize="64MB", + ) + @property def unique_patient_ids(self) -> List[str]: """Returns a list of unique patient IDs. From 4ce86138534f1031f2f88152184c42c2e7e46463 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 04:08:02 -0500 Subject: [PATCH 05/51] Add TODO Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d145838a6..610fd2564 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -373,6 +373,8 @@ def load_csv_or_tsv(self, table_name: str, path: str) -> dd.DataFrame: path = alt_path(path) delimiter = '\t' if path.endswith(".tsv") or path.endswith(".tsv.gz") else ',' + # TODO: this may give incorrect type inference for some columns + # if the first block is not representative csv_reader = pv.open_csv( path, read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB From b69bcbe51a60a12aa7c3bbd36e720912fbb36d7c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 04:18:01 -0500 Subject: [PATCH 06/51] Change load_data to dd.DataFrame --- pyhealth/datasets/base_dataset.py | 86 ++++++++++++++----------------- 1 file changed, 38 insertions(+), 48 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 610fd2564..5017baf98 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -8,8 +8,11 @@ from urllib.parse import urlparse, urlunparse import uuid import json +import functools +import operator import polars as pl +import pandas as pd import dask.dataframe as dd import pyarrow.csv as pv import pyarrow.parquet as pq @@ -247,23 +250,23 @@ def collected_global_event_df(self) -> pl.DataFrame: return self._collected_global_event_df - def load_data(self) -> pl.LazyFrame: + def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. Returns: - pl.LazyFrame: A concatenated lazy frame of all tables. + dd.DataFrame: A concatenated lazy frame of all tables. """ frames = [self.load_table(table.lower()) for table in self.tables] - return pl.concat(frames, how="diagonal") + return dd.concat(frames, axis=0, join="outer") - def load_table(self, table_name: str) -> pl.LazyFrame: + def load_table(self, table_name: str) -> dd.DataFrame: """Loads a table and processes joins if specified. Args: table_name (str): The name of the table to load. Returns: - pl.LazyFrame: The processed lazy frame for the table. + dd.DataFrame: The processed lazy frame for the table. Raises: ValueError: If the table is not found in the config. @@ -272,44 +275,34 @@ def load_table(self, table_name: str) -> pl.LazyFrame: if table_name not in self.config.tables: raise ValueError(f"Table {table_name} not found in config") - def _to_lower(col_name: str) -> str: - lower_name = col_name.lower() - if lower_name != col_name: - logger.warning("Renaming column %s to lowercase %s", col_name, lower_name) - return lower_name - table_cfg = self.config.tables[table_name] csv_path = f"{self.root}/{table_cfg.file_path}" csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") - _ = self.load_csv_or_tsv(table_name, csv_path) - df = scan_csv_gz_or_csv_tsv(csv_path) - - # Convert column names to lowercase before calling preprocess_func - df = df.rename(_to_lower) + df: dd.DataFrame = self.load_csv_or_tsv(table_name, csv_path) # Check if there is a preprocessing function for this table + # TODO: we need to update the preprocess function to work with Dask DataFrame + # for all datasets. Only care about MIMIC4 for now. preprocess_func = getattr(self, f"preprocess_{table_name}", None) if preprocess_func is not None: logger.info( f"Preprocessing table: {table_name} with {preprocess_func.__name__}" ) - df = preprocess_func(df) + df: dd.DataFrame = preprocess_func(df) # Handle joins - for join_cfg in table_cfg.join: + for i, join_cfg in enumerate(table_cfg.join): other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - _ = self.load_csv_or_tsv(table_name, csv_path) - join_df = scan_csv_gz_or_csv_tsv(other_csv_path) - join_df = join_df.rename(_to_lower) + join_df: dd.DataFrame = self.load_csv_or_tsv(f"{table_name}_join_{i}", other_csv_path) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how - df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) + df: dd.DataFrame = df.merge(join_df[[join_key] + columns], on=join_key, how=how) patient_id_col = table_cfg.patient_id timestamp_col = table_cfg.timestamp @@ -320,37 +313,34 @@ def _to_lower(col_name: str) -> str: if timestamp_col: if isinstance(timestamp_col, list): # Concatenate all timestamp parts in order with no separator - combined_timestamp = pl.concat_str( - [pl.col(col) for col in timestamp_col] - ).str.strptime(pl.Datetime, format=timestamp_format, strict=True) - timestamp_expr = combined_timestamp + timestamp_series: dd.Series = functools.reduce(operator.add, (df[col].astype(str) for col in timestamp_col)) else: # Single timestamp column - timestamp_expr = pl.col(timestamp_col).str.strptime( - pl.Datetime, format=timestamp_format, strict=True - ) + timestamp_series: dd.Series = df[timestamp_col].astype(str) + timestamp_series: dd.Series = dd.to_datetime( + timestamp_series, + format=timestamp_format, + errors="raise", + ) + df: dd.DataFrame = df.assign(timestamp=timestamp_series.astype("datetime64[ms]")) else: - timestamp_expr = pl.lit(None, dtype=pl.Datetime) + df: dd.DataFrame = df.assign(timestamp=pd.NaT) # If patient_id_col is None, use row index as patient_id - patient_id_expr = ( - pl.col(patient_id_col).cast(pl.Utf8) - if patient_id_col - else pl.int_range(0, pl.count()).cast(pl.Utf8) - ) - base_columns = [ - patient_id_expr.alias("patient_id"), - pl.lit(table_name).cast(pl.Utf8).alias("event_type"), - # ms should be sufficient for most cases - timestamp_expr.cast(pl.Datetime(time_unit="ms")).alias("timestamp"), - ] - - # Flatten attribute columns with event_type prefix - attribute_columns = [ - pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols - ] - - event_frame = df.select(base_columns + attribute_columns) + if patient_id_col: + df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype(str)) + else: + df: dd.DataFrame = df.reset_index(drop=True) + df: dd.DataFrame = df.assign(patient_id=df.index.astype(str)) + + df: dd.DataFrame = df.assign(event_type=table_name) + + rename_attr = {attr: f"{table_name}/{attr}" for attr in attribute_cols} + df: dd.DataFrame = df.rename(columns=rename_attr) + + attr_cols = [rename_attr[attr] for attr in attribute_cols] + final_cols = ["patient_id", "event_type", "timestamp"] + attr_cols + event_frame = df[final_cols] return event_frame From e7c7964201eee121b5a2ff6cf53abf5f2ac74226 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 04:24:33 -0500 Subject: [PATCH 07/51] Fix mimic4 for dask Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/mimic4.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 05321dedb..00b27de98 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -1,10 +1,10 @@ import logging import os import warnings -from typing import Dict, List, Optional +from typing import List, Optional import pandas as pd -import polars as pl +import dask.dataframe as dd try: import psutil @@ -274,7 +274,7 @@ def __init__( log_memory_usage("Completed MIMIC4Dataset init") - def _combine_data(self) -> pl.LazyFrame: + def _combine_data(self) -> dd.DataFrame: """ Combines data from all initialized sub-datasets into a unified global event dataframe. @@ -293,4 +293,4 @@ def _combine_data(self) -> pl.LazyFrame: if len(frames) == 1: return frames[0] else: - return pl.concat(frames, how="diagonal") + return dd.concat(frames, axis=0, join="outer") From c7b8092a620963c209e1be414958c2810f1482c7 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 05:06:45 -0500 Subject: [PATCH 08/51] enable collected_global_event_df for Dask Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 82 +++++++++++++------------------ pyhealth/datasets/mimic4.py | 13 +++++ 2 files changed, 48 insertions(+), 47 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5017baf98..4caa27aaa 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -168,13 +168,8 @@ def __init__( self.config = load_yaml_config(config_path) self.dev = dev - if cache_dir is None: - cache_dir = platformdirs.user_cache_dir(appname='pyhealth') - logger.info(f"No cache_dir provided. Using default cache for PyHealth: {cache_dir}") - cache_dir = Path(cache_dir) - self.cache_dir = cache_dir / self.uuid() - self.setup_cache_dir() - logger.info(f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}") + subfolder = self.cache_subfolder(self.root, self.tables, self.dataset_name, self.dev) + self.setup_cache_dir(cache_dir=cache_dir, subfolder=subfolder) logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" @@ -186,7 +181,8 @@ def __init__( self._collected_global_event_df = None self._unique_patient_ids = None - def uuid(self) -> str: + @staticmethod + def cache_subfolder(root: str, tables: List[str], dataset_name: str, dev: bool) -> str: """Generates a unique identifier for the dataset instance. This is used for creating cache directories. The UUID is based on the root path, tables, dataset name, and dev mode. @@ -194,61 +190,53 @@ def uuid(self) -> str: str: A unique identifier string. """ id_str = json.dumps({ - "root": self.root, - "tables": sorted(self.tables), - "dataset_name": self.dataset_name, - "dev": self.dev, + "root": root, + "tables": sorted(tables), + "dataset_name": dataset_name, + "dev": dev, }, sort_keys=True) return str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) - def setup_cache_dir(self) -> None: + def setup_cache_dir(self, cache_dir: str | Path | None = None, subfolder: str = str(uuid.uuid4())) -> None: """Creates the cache directory structure. + + Args: + cache_dir (str | Path | None): The base cache directory. If None, a default cache + directory will be created under the platform's cache directory. + subfolder (str): Subfolder name for this dataset instance's cache. """ - self.cache_dir.mkdir(parents=True, exist_ok=True) - + if cache_dir is None: + cache_dir = platformdirs.user_cache_dir(appname='pyhealth') + logger.info(f"No cache_dir provided. Using default cache for PyHealth: {cache_dir}") + cache_dir = Path(cache_dir) + self.cache_dir = cache_dir / subfolder + + self.cache_dir.mkdir(parents=True, exist_ok=True) # Create tables subdirectory to store cached table files (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) + # Create global_event_df subdirectory to store cached global event dataframe + (self.cache_dir / "global_event_df").mkdir(parents=True, exist_ok=True) + logger.info(f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}") + @property - def collected_global_event_df(self) -> pl.DataFrame: + def collected_global_event_df(self) -> dd.DataFrame: """Collects and returns the global event data frame. Returns: - pl.DataFrame: The collected global event data frame. + dd.DataFrame: The collected global event data frame. """ - if self._collected_global_event_df is None: - logger.info("Collecting global event dataframe...") + path = self.cache_dir / "global_event_df" / "cached.parquet" - # Collect the dataframe - with dev mode limiting if applicable - df = self.global_event_df - # TODO: dev doesn't seem to improve the speed / memory usage + if not path_exists(str(path)): if self.dev: - # Limit the number of patients in dev mode - logger.info("Dev mode enabled: limiting to 1000 patients") - limited_patients = df.select(pl.col("patient_id")).unique().limit(1000) - df = df.join(limited_patients, on="patient_id", how="inner") - - self._collected_global_event_df = df.collect() - - # Profile the Polars collect() operation (commented out by default) - # self._collected_global_event_df, profile = df.profile() - # profile = profile.with_columns([ - # (pl.col("end") - pl.col("start")).alias("duration"), - # ]) - # profile = profile.with_columns([ - # (pl.col("duration") / profile["duration"].sum() * 100).alias("percentage") - # ]) - # profile = profile.sort("duration", descending=True) - # with pl.Config() as cfg: - # cfg.set_tbl_rows(-1) - # cfg.set_fmt_str_lengths(200) - # print(profile) - - logger.info( - f"Collected dataframe with shape: {self._collected_global_event_df.shape}" - ) + patients = self.global_event_df["patient_id"].unique().head(1000).tolist() + filter = self.global_event_df["patient_id"].isin(patients) + self.global_event_df[filter].to_parquet(path) + else: + self.global_event_df.to_parquet(path) - return self._collected_global_event_df + return dd.read_parquet(str(path)) def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 00b27de98..aca96f16c 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -2,6 +2,7 @@ import os import warnings from typing import List, Optional +from pathlib import Path import pandas as pd import dask.dataframe as dd @@ -209,6 +210,7 @@ def __init__( ehr_config_path: Optional[str] = None, note_config_path: Optional[str] = None, cxr_config_path: Optional[str] = None, + cache_dir: str | Path | None = None, dataset_name: str = "mimic4", dev: bool = False, ): @@ -240,6 +242,7 @@ def __init__( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, + cache_dir=cache_dir, ) log_memory_usage("After EHR dataset initialization") @@ -250,6 +253,7 @@ def __init__( root=note_root, tables=note_tables, config_path=note_config_path, + cache_dir=cache_dir, ) log_memory_usage("After Note dataset initialization") @@ -260,9 +264,18 @@ def __init__( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, + cache_dir=cache_dir, ) log_memory_usage("After CXR dataset initialization") + subfolder = BaseDataset.cache_subfolder( + str(ehr_root) + str(note_root) + str(cxr_root), + ehr_tables + note_tables + cxr_tables, + self.dataset_name, + self.dev + ) + self.setup_cache_dir(cache_dir=cache_dir, subfolder=subfolder) + # Combine data from all sub-datasets log_memory_usage("Before combining data") self.global_event_df = self._combine_data() From 01a0048b281c9070f7977d85df6dec92096a030b Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 05:15:50 -0500 Subject: [PATCH 09/51] Fix unique_patient_ids, stats for Dask Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 4caa27aaa..5176e8d28 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -378,10 +378,10 @@ def unique_patient_ids(self) -> List[str]: """ if self._unique_patient_ids is None: self._unique_patient_ids = ( - self.collected_global_event_df.select("patient_id") + self.collected_global_event_df["patient_id"] .unique() - .to_series() - .to_list() + .compute() + .tolist() ) logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs") return self._unique_patient_ids @@ -421,10 +421,12 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: def stats(self) -> None: """Prints statistics about the dataset.""" df = self.collected_global_event_df + n_patients = df["patient_id"].nunique().compute() + n_events = df.shape[0].compute() print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") - print(f"Number of patients: {df['patient_id'].n_unique()}") - print(f"Number of events: {df.height}") + print(f"Number of patients: {n_patients}") + print(f"Number of events: {n_events}") @property def default_task(self) -> Optional[BaseTask]: From eaa6ba576780fe55fbf043110cf50ef0c0197ef4 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 05:56:31 -0500 Subject: [PATCH 10/51] Initial Attempt for Patient with Dask dataframe Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/data/data.py | 149 ++++++++++++++++++------------------------ 1 file changed, 62 insertions(+), 87 deletions(-) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 2a6d3a45c..bad240bc6 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -1,12 +1,10 @@ import operator from dataclasses import dataclass, field from datetime import datetime -from functools import reduce -from typing import Dict, List, Mapping, Optional, Union - -import numpy as np -import polars as pl +from typing import Dict, List, Mapping, Optional, Union, Any +import dask.dataframe as dd +import pandas as pd @dataclass(frozen=True) class Event: @@ -20,9 +18,9 @@ class Event: event_type: str timestamp: datetime - attr_dict: Mapping[str, any] = field(default_factory=dict) + attr_dict: Mapping[str, Any] = field(default_factory=dict) - def __init__(self, event_type: str, timestamp: datetime = None, **kwargs): + def __init__(self, event_type: str, timestamp: datetime | None = None, **kwargs): """Initialize an Event instance. Args: @@ -50,23 +48,22 @@ def __init__(self, event_type: str, timestamp: datetime = None, **kwargs): object.__setattr__(self, "attr_dict", attr_dict) @classmethod - def from_dict(cls, d: Dict[str, any]) -> "Event": + def from_dict(cls, d: Dict[str, Any]) -> "Event": """Create an Event instance from a dictionary. Args: - d (Dict[str, any]): Dictionary containing event data. - + d (Dict[str, Any]): Dictionary containing event data. Returns: Event: An instance of the Event class. """ timestamp: datetime = d["timestamp"] event_type: str = d["event_type"] - attr_dict: Dict[str, any] = { + attr_dict: Dict[str, Any] = { k.split("/", 1)[1]: v for k, v in d.items() if k.split("/")[0] == event_type } return cls(event_type=event_type, timestamp=timestamp, attr_dict=attr_dict) - def __getitem__(self, key: str) -> any: + def __getitem__(self, key: str) -> Any: """Get an attribute by key. Args: @@ -95,7 +92,7 @@ def __contains__(self, key: str) -> bool: return True return key in self.attr_dict - def __getattr__(self, key: str) -> any: + def __getattr__(self, key: str) -> Any: """Get an attribute using dot notation. Args: @@ -119,56 +116,63 @@ class Patient: Attributes: patient_id (str): Unique patient identifier. - data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp. - event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. + data_source (dd.DataFrame): Dask DataFrame containing all events. """ - def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None: - """ - Initialize a Patient instance. + def __init__(self, patient_id: str, data_source: dd.DataFrame) -> None: + """Initialize a Patient instance. Args: patient_id (str): Unique patient identifier. - data_source (pl.DataFrame): DataFrame containing all events. + data_source (dd.DataFrame): DataFrame containing all events. """ self.patient_id = patient_id - self.data_source = data_source.sort("timestamp") - self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True) + self.data_source = data_source - def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: - """Regular filtering by time. Time complexity: O(n).""" + def _filter_by_time_range(self, df: dd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> dd.DataFrame: + """Filter events by time range using lazy Dask operations.""" if start is not None: - df = df.filter(pl.col("timestamp") >= start) + df = df[df["timestamp"] >= start] if end is not None: - df = df.filter(pl.col("timestamp") <= end) + df = df[df["timestamp"] <= end] return df - def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: - """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n).""" - if start is None and end is None: - return df - df = df.filter(pl.col("timestamp").is_not_null()) - ts_col = df["timestamp"].to_numpy() - start_idx = 0 - end_idx = len(ts_col) - if start is not None: - start_idx = np.searchsorted(ts_col, start, side="left") - if end is not None: - end_idx = np.searchsorted(ts_col, end, side="right") - return df.slice(start_idx, end_idx - start_idx) - - def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: - """Regular filtering by event type. Time complexity: O(n).""" + def _filter_by_event_type(self, df: dd.DataFrame, event_type: Optional[str]) -> dd.DataFrame: + """Filter by event type if provided.""" if event_type: - df = df.filter(pl.col("event_type") == event_type) + df = df[df["event_type"] == event_type] return df - def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: - """Fast filtering by event type using pre-built event type index. Time complexity: O(1).""" - if event_type: - return self.event_type_partitions.get((event_type,), df[:0]) - else: - return df + def _apply_attribute_filters( + self, df: dd.DataFrame, event_type: str, filters: List[tuple] + ) -> dd.DataFrame: + """Apply attribute-level filters to the DataFrame.""" + op_map = { + "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, + } + mask = None + for filt in filters: + if not (isinstance(filt, tuple) and len(filt) == 3): + raise ValueError( + f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" + ) + attr, op, val = filt + if op not in op_map: + raise ValueError(f"Unsupported operator: {op} in filter {filt}") + col_name = f"{event_type}/{attr}" + if col_name not in df.columns: + raise KeyError(f"Column '{col_name}' not found in dataset") + col = df[col_name] + condition = op_map[op](col, val) + mask = condition if mask is None else mask & condition + if mask is not None: + df = df[mask] + return df def get_events( self, @@ -177,60 +181,31 @@ def get_events( end: Optional[datetime] = None, filters: Optional[List[tuple]] = None, return_df: bool = False, - ) -> Union[pl.DataFrame, List[Event]]: + ) -> Union[dd.DataFrame, List[Event]]: """Get events with optional type and time filters. Args: event_type (Optional[str]): Type of events to filter. start (Optional[datetime]): Start time for filtering events. end (Optional[datetime]): End time for filtering events. - return_df (bool): Whether to return a DataFrame or a list of + return_df (bool): Whether to return a pandas DataFrame or a list of Event objects. filters (Optional[List[tuple]]): Additional filters as [(attr, op, value), ...], e.g.: [("attr1", "!=", "abnormal"), ("attr2", "!=", 1)]. Filters are applied after type and time filters. The logic is "AND" between different filters. Returns: - Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame + Union[dd.DataFrame, List[Event]]: Filtered events as a Dask DataFrame or a list of Event objects. """ - # faster filtering (by default) - df = self._filter_by_event_type_fast(self.data_source, event_type) - df = self._filter_by_time_range_fast(df, start, end) + df = self._filter_by_event_type(self.data_source, event_type) + df = self._filter_by_time_range(df, start, end) - # regular filtering (commented out by default) - # df = self._filter_by_event_type_regular(self.data_source, event_type) - # df = self._filter_by_time_range_regular(df, start, end) - - if filters: + active_filters = filters or [] + if active_filters: assert event_type is not None, "event_type must be provided if filters are provided" - else: - filters = [] - exprs = [] - for filt in filters: - if not (isinstance(filt, tuple) and len(filt) == 3): - raise ValueError( - f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" - ) - attr, op, val = filt - col_expr = pl.col(f"{event_type}/{attr}") - # Build operator expression - if op == "==": - exprs.append(col_expr == val) - elif op == "!=": - exprs.append(col_expr != val) - elif op == "<": - exprs.append(col_expr < val) - elif op == "<=": - exprs.append(col_expr <= val) - elif op == ">": - exprs.append(col_expr > val) - elif op == ">=": - exprs.append(col_expr >= val) - else: - raise ValueError(f"Unsupported operator: {op} in filter {filt}") - if exprs: - df = df.filter(reduce(operator.and_, exprs)) + df = self._apply_attribute_filters(df, event_type, active_filters) + if return_df: return df - return [Event.from_dict(d) for d in df.to_dicts()] + return [Event.from_dict(d) for d in df.to_dict("records")] \ No newline at end of file From 2593680fd6d66f9ea5c99f242383fc69c4ac802c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 06:01:51 -0500 Subject: [PATCH 11/51] Fix patient Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/data/data.py | 4 ++- tests/core/test_patient.py | 69 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 tests/core/test_patient.py diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index bad240bc6..6369c3720 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -208,4 +208,6 @@ def get_events( if return_df: return df - return [Event.from_dict(d) for d in df.to_dict("records")] \ No newline at end of file + # Dask DataFrames do not expose .to_dict on lazy expressions; compute to pandas first. + records = df.compute().to_dict("records") + return [Event.from_dict(d) for d in records] diff --git a/tests/core/test_patient.py b/tests/core/test_patient.py new file mode 100644 index 000000000..a6d0cdc8a --- /dev/null +++ b/tests/core/test_patient.py @@ -0,0 +1,69 @@ +import unittest +from datetime import datetime + +import dask.dataframe as dd +import pandas as pd + +from pyhealth.data import Patient + + +class TestPatientGetEvents(unittest.TestCase): + def setUp(self): + timestamps = [ + datetime(2021, 1, 1), + datetime(2021, 1, 5), + datetime(2021, 2, 1), + ] + pdf = pd.DataFrame( + { + "patient_id": ["p1", "p1", "p1"], + "event_type": ["labs", "labs", "visit"], + "timestamp": timestamps, + "labs/result": [1.0, 2.0, None], + "labs/unit": ["mg/dL", "mg/dL", None], + "visit/location": [None, None, "icu"], + } + ) + self.ddf = dd.from_pandas(pdf, npartitions=1) + self.patient = Patient(patient_id="p1", data_source=self.ddf) + + def test_returns_event_objects_by_default(self): + events = self.patient.get_events() + self.assertEqual(len(events), 3) + self.assertEqual( + sorted([e.event_type for e in events]), ["labs", "labs", "visit"] + ) + self.assertEqual(events[0].attr_dict["result"], 1.0) + + def test_return_df_flag(self): + labs_df = self.patient.get_events(event_type="labs", return_df=True) + labs_pdf = labs_df.compute() + self.assertEqual(len(labs_pdf), 2) + self.assertTrue((labs_pdf["event_type"] == "labs").all()) + + def test_time_range_filter(self): + start = datetime(2021, 1, 2) + end = datetime(2021, 1, 31) + events = self.patient.get_events(start=start, end=end) + self.assertEqual(len(events), 1) + self.assertEqual(events[0].timestamp, datetime(2021, 1, 5)) + + def test_event_type_and_attribute_filters(self): + filters = [("result", ">=", 2)] + events = self.patient.get_events(event_type="labs", filters=filters) + self.assertEqual(len(events), 1) + self.assertEqual(events[0].attr_dict["result"], 2.0) + + def test_filters_require_event_type(self): + with self.assertRaises(AssertionError): + self.patient.get_events(filters=[("result", "==", 1)]) + + def test_missing_column_in_filters_raises(self): + with self.assertRaises(KeyError): + self.patient.get_events( + event_type="labs", filters=[("does_not_exist", "==", 1)] + ) + + +if __name__ == "__main__": + unittest.main() From 006c3ae12bc55bcd4f4b46870579bf694bdfa278 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 06:09:56 -0500 Subject: [PATCH 12/51] Support get_patient for dask Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5176e8d28..a22ada840 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -401,8 +401,12 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id) - return Patient(patient_id=patient_id, data_source=df) + df = self.collected_global_event_df + if not isinstance(df, dd.DataFrame): + raise TypeError("collected_global_event_df must be a Dask DataFrame") + + patient_df = df[df["patient_id"] == patient_id] + return Patient(patient_id=patient_id, data_source=patient_df) def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: """Yields Patient objects for each unique patient in the dataset. From 6d38147084734f454946b9f86431bad1797087b1 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 06:21:49 -0500 Subject: [PATCH 13/51] Support iter_patients for Dask Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index a22ada840..ff926176e 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -408,19 +408,20 @@ def get_patient(self, patient_id: str) -> Patient: patient_df = df[df["patient_id"] == patient_id] return Patient(patient_id=patient_id, data_source=patient_df) - def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: - """Yields Patient objects for each unique patient in the dataset. + def iter_patients(self, df: Optional[dd.DataFrame] = None) -> Iterator[Patient]: + """Yields Patient objects for each unique patient in the dataset. + This method is inefficient, you should prefer to use + `self.colllected_global_event_df.groupby(("patient_id", )).apply(...)` directly + if possible. Yields: Iterator[Patient]: An iterator over Patient objects. """ if df is None: df = self.collected_global_event_df - grouped = df.group_by("patient_id") - for patient_id, patient_df in grouped: - patient_id = patient_id[0] - yield Patient(patient_id=patient_id, data_source=patient_df) + for patitent_id in self.unique_patient_ids: + yield self.get_patient(patitent_id) def stats(self) -> None: """Prints statistics about the dataset.""" @@ -482,6 +483,11 @@ def set_task( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) + if cache_dir is not None: + logger.warning(f"This argument cache_dir is deprecated. Use dataset cache_dir instead.") + if cache_format != "parquet": + logger.warning(f"Only 'parquet' cache_format is officially supported now.") + # Check for cached data if cache_dir is provided samples = None if cache_dir is not None: From f41c4d5763761900a9fab03f377bdd93a489e446 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 07:01:54 -0500 Subject: [PATCH 14/51] Add overload type hint for Patient Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/data/data.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 6369c3720..ea4225be0 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -1,7 +1,7 @@ import operator from dataclasses import dataclass, field from datetime import datetime -from typing import Dict, List, Mapping, Optional, Union, Any +from typing import Dict, List, Mapping, Optional, Union, Any, overload, Literal import dask.dataframe as dd import pandas as pd @@ -174,6 +174,28 @@ def _apply_attribute_filters( df = df[mask] return df + @overload + def get_events( + self, + *, + event_type: Optional[str] = None, + start: Optional[datetime] = None, + end: Optional[datetime] = None, + filters: Optional[List[tuple]] = None, + return_df: Literal[True] + ) -> dd.DataFrame: ... + + @overload + def get_events( + self, + *, + event_type: Optional[str] = None, + start: Optional[datetime] = None, + end: Optional[datetime] = None, + filters: Optional[List[tuple]] = None, + return_df: Literal[False] + ) -> List[Event]: ... + def get_events( self, event_type: Optional[str] = None, From 9195229a37f9e3015b242d5e23534887002f4c8f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 07:02:19 -0500 Subject: [PATCH 15/51] Update pre_filter signature to Dask Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/tasks/base_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/tasks/base_task.py b/pyhealth/tasks/base_task.py index 888c7e2e1..9026e1e24 100644 --- a/pyhealth/tasks/base_task.py +++ b/pyhealth/tasks/base_task.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union, Type -import polars as pl +import dask.dataframe as dd class BaseTask(ABC): @@ -9,7 +9,7 @@ class BaseTask(ABC): input_schema: Dict[str, Union[str, Type]] output_schema: Dict[str, Union[str, Type]] - def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + def pre_filter(self, df: dd.DataFrame) -> dd.DataFrame: return df @abstractmethod From cf23a92b3ce2f1b42a5588bd9b24e9d9b3fd9453 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 07:16:10 -0500 Subject: [PATCH 16/51] Fix type hint Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- .../mortality_prediction_stagenet_mimic4.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index fc9c58f7f..a4dcf60c1 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -1,9 +1,12 @@ from datetime import datetime -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar, Dict, List, Type +import dask.dataframe as dd +import pandas as pd import polars as pl from .base_task import BaseTask +from ..data.data import Patient, Event class MortalityPredictionStageNetMIMIC4(BaseTask): @@ -35,11 +38,11 @@ class MortalityPredictionStageNetMIMIC4(BaseTask): """ task_name: str = "MortalityPredictionStageNetMIMIC4" - input_schema: Dict[str, str] = { + input_schema: Dict[str, str | Type] = { "icd_codes": "stagenet", "labs": "stagenet_tensor", } - output_schema: Dict[str, str] = {"mortality": "binary"} + output_schema: Dict[str, str | Type] = {"mortality": "binary"} # Organize lab items by category # Each category will map to ONE dimension in the output vector @@ -75,7 +78,7 @@ class MortalityPredictionStageNetMIMIC4(BaseTask): item for itemids in LAB_CATEGORIES.values() for item in itemids ] - def __call__(self, patient: Any) -> List[Dict[str, Any]]: + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: """Process a patient to create mortality prediction samples. Creates ONE sample per patient with all admissions aggregated. @@ -89,13 +92,12 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: procedures, labs across visits, and final mortality label """ # Filter patients by age (>= 18) - demographics = patient.get_events(event_type="patients") + demographics: List[Event] = patient.get_events(event_type="patients", return_df=False) if not demographics: return [] - demographics = demographics[0] try: - anchor_age = int(demographics.anchor_age) + anchor_age = int(demographics[0].anchor_age) if anchor_age < 18: return [] except (ValueError, TypeError, AttributeError): @@ -103,7 +105,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return [] # Get all admissions - admissions = patient.get_events(event_type="admissions") + admissions = patient.get_events(event_type="admissions", return_df=False) if len(admissions) < 1: return [] @@ -159,6 +161,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: diagnoses_icd = patient.get_events( event_type="diagnoses_icd", filters=[("hadm_id", "==", admission.hadm_id)], + return_df=False, ) visit_diagnoses = [ event.icd_code @@ -170,6 +173,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: procedures_icd = patient.get_events( event_type="procedures_icd", filters=[("hadm_id", "==", admission.hadm_id)], + return_df=False, ) visit_procedures = [ event.icd_code From 5d440c764e773515ce0a60ccbb56a66b7d68be88 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 07:29:53 -0500 Subject: [PATCH 17/51] Chage lab_df to be dask compitable. Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- .../mortality_prediction_stagenet_mimic4.py | 94 +++++++------------ 1 file changed, 36 insertions(+), 58 deletions(-) diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index a4dcf60c1..9c9840b0b 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -3,6 +3,7 @@ import dask.dataframe as dd import pandas as pd +import numpy as np import polars as pl from .base_task import BaseTask @@ -189,7 +190,7 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: all_icd_times.append(time_from_previous) # Get lab events for this admission - labevents_df = patient.get_events( + labevents_dd: dd.DataFrame = patient.get_events( event_type="labevents", start=admission_time, end=admission_dischtime, @@ -197,63 +198,40 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: ) # Filter to relevant lab items - labevents_df = labevents_df.filter( - pl.col("labevents/itemid").is_in(self.LABITEMS) - ) - - # Parse storetime and filter - if labevents_df.height > 0: - labevents_df = labevents_df.with_columns( - pl.col("labevents/storetime").str.strptime( - pl.Datetime, "%Y-%m-%d %H:%M:%S" - ) - ) - labevents_df = labevents_df.filter( - pl.col("labevents/storetime") <= admission_dischtime - ) - - if labevents_df.height > 0: - # Select relevant columns - labevents_df = labevents_df.select( - pl.col("timestamp"), - pl.col("labevents/itemid"), - pl.col("labevents/valuenum").cast(pl.Float64), - ) - - # Group by timestamp and aggregate into 10D vectors - # For each timestamp, create vector of lab categories - unique_timestamps = sorted( - labevents_df["timestamp"].unique().to_list() - ) - - for lab_ts in unique_timestamps: - # Get all lab events at this timestamp - ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts) - - # Create 10-dimensional vector (one per category) - lab_vector = [] - for category_name in self.LAB_CATEGORY_NAMES: - category_itemids = self.LAB_CATEGORIES[category_name] - - # Find first matching value for this category - category_value = None - for itemid in category_itemids: - matching = ts_labs.filter( - pl.col("labevents/itemid") == itemid - ) - if matching.height > 0: - category_value = matching["labevents/valuenum"][0] - break - - lab_vector.append(category_value) - - # Calculate time from admission start (hours) - time_from_admission = ( - lab_ts - admission_time - ).total_seconds() / 3600.0 - - all_lab_values.append(lab_vector) - all_lab_times.append(time_from_admission) + labevents_dd: dd.DataFrame = labevents_dd[labevents_dd["labevents/itemid"].isin(self.LABITEMS)] + storetime_series: dd.Series = dd.to_datetime(labevents_dd["labevents/storetime"], format="%Y-%m-%d %H:%M:%S", errors="coerce") + labevents_dd: dd.DataFrame = labevents_dd[storetime_series <= np.datetime64(admission_dischtime)] + labevents_dd: dd.DataFrame = labevents_dd[["timestamp", "labevents/itemid", "labevents/valuenum"]].astype({"labevents/valuenum": "float64"}) + labevents_df: pd.DataFrame = labevents_dd.compute() + + if not labevents_df.empty: + unique_timestamps = sorted(labevents_df["timestamp"].unique().tolist()) + for lab_ts in unique_timestamps: + # Get all lab events at this timestamp + ts_labs: pd.DataFrame = labevents_df[labevents_df["timestamp"] == lab_ts] + + # Create 10-dimensional vector (one per category) + lab_vector = [] + for category_name in self.LAB_CATEGORY_NAMES: + category_itemids = self.LAB_CATEGORIES[category_name] + + # Find first matching value for this category + category_value = None + for itemid in category_itemids: + matching = ts_labs[ts_labs["labevents/itemid"] == itemid] + if not matching.empty: + category_value = matching["labevents/valuenum"].iloc[0] + break + + lab_vector.append(category_value) + + # Calculate time from admission start (hours) + time_from_admission = ( + lab_ts - admission_time + ).total_seconds() / 3600.0 + + all_lab_values.append(lab_vector) + all_lab_times.append(time_from_admission) # Skip if no lab events (required for this task) if len(all_lab_values) == 0: From d834460ab3bc23e48e66f8b98f63e9fe379bbf8a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 08:21:39 -0500 Subject: [PATCH 18/51] Fix schema inference on csv reader Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ff926176e..6a2061198 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -14,6 +14,7 @@ import polars as pl import pandas as pd import dask.dataframe as dd +import pyarrow as pa import pyarrow.csv as pv import pyarrow.parquet as pq import requests @@ -351,13 +352,20 @@ def load_csv_or_tsv(self, table_name: str, path: str) -> dd.DataFrame: path = alt_path(path) delimiter = '\t' if path.endswith(".tsv") or path.endswith(".tsv.gz") else ',' - # TODO: this may give incorrect type inference for some columns - # if the first block is not representative - csv_reader = pv.open_csv( + + # Always infer schema as string to avoid incorrect type inference + schema_reader = pv.open_csv( path, read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB parse_options=pv.ParseOptions(delimiter=delimiter) ) + schema = pa.schema([pa.field(name, pa.string()) for name in schema_reader.schema.names]) + csv_reader = pv.open_csv( + path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + convert_options=pv.ConvertOptions(column_types=schema) + ) with pq.ParquetWriter(parquet_path, csv_reader.schema) as writer: for batch in csv_reader: writer.write_batch(batch) From 2a0d7d9c23b6e7c37f25a74c594f56fba5603015 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 08:47:50 -0500 Subject: [PATCH 19/51] Fix incorrect Dask transform Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/data/data.py | 16 +++++++--------- .../mortality_prediction_stagenet_mimic4.py | 8 ++++++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index ea4225be0..7021eafaa 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -4,7 +4,7 @@ from typing import Dict, List, Mapping, Optional, Union, Any, overload, Literal import dask.dataframe as dd -import pandas as pd +import numpy as np @dataclass(frozen=True) class Event: @@ -132,9 +132,9 @@ def __init__(self, patient_id: str, data_source: dd.DataFrame) -> None: def _filter_by_time_range(self, df: dd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> dd.DataFrame: """Filter events by time range using lazy Dask operations.""" if start is not None: - df = df[df["timestamp"] >= start] + df = df[df["timestamp"] >= np.datetime64(start)] if end is not None: - df = df[df["timestamp"] <= end] + df = df[df["timestamp"] <= np.datetime64(end)] return df def _filter_by_event_type(self, df: dd.DataFrame, event_type: Optional[str]) -> dd.DataFrame: @@ -155,7 +155,7 @@ def _apply_attribute_filters( ">": operator.gt, ">=": operator.ge, } - mask = None + for filt in filters: if not (isinstance(filt, tuple) and len(filt) == 3): raise ValueError( @@ -167,11 +167,9 @@ def _apply_attribute_filters( col_name = f"{event_type}/{attr}" if col_name not in df.columns: raise KeyError(f"Column '{col_name}' not found in dataset") - col = df[col_name] - condition = op_map[op](col, val) - mask = condition if mask is None else mask & condition - if mask is not None: - df = df[mask] + + df = df[op_map[op](df[col_name], val)] + return df @overload diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index 9c9840b0b..fb36a06a7 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -199,8 +199,12 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: # Filter to relevant lab items labevents_dd: dd.DataFrame = labevents_dd[labevents_dd["labevents/itemid"].isin(self.LABITEMS)] - storetime_series: dd.Series = dd.to_datetime(labevents_dd["labevents/storetime"], format="%Y-%m-%d %H:%M:%S", errors="coerce") - labevents_dd: dd.DataFrame = labevents_dd[storetime_series <= np.datetime64(admission_dischtime)] + labevents_dd: dd.DataFrame = labevents_dd.assign(storetime_filter=dd.to_datetime( + labevents_dd["labevents/storetime"], + format="%Y-%m-%d %H:%M:%S", + errors="coerce" + )) + labevents_dd: dd.DataFrame = labevents_dd[labevents_dd["storetime_filter"] <= np.datetime64(admission_dischtime)] labevents_dd: dd.DataFrame = labevents_dd[["timestamp", "labevents/itemid", "labevents/valuenum"]].astype({"labevents/valuenum": "float64"}) labevents_df: pd.DataFrame = labevents_dd.compute() From c769a1867cec50a40b8ff8861ec0ad0c1a76e871 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 09:32:35 -0500 Subject: [PATCH 20/51] Optimize code Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 13 +++++++++---- .../tasks/mortality_prediction_stagenet_mimic4.py | 5 ++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 6a2061198..d25738ed1 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -237,7 +237,12 @@ def collected_global_event_df(self) -> dd.DataFrame: else: self.global_event_df.to_parquet(path) - return dd.read_parquet(str(path)) + # This is imporant for fast fetch by patient_id + df = dd.read_parquet(str(path)) + df["index"] = df["patient_id"] + df = df.set_index("index") + + return df def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. @@ -386,7 +391,7 @@ def unique_patient_ids(self) -> List[str]: """ if self._unique_patient_ids is None: self._unique_patient_ids = ( - self.collected_global_event_df["patient_id"] + self.collected_global_event_df.index .unique() .compute() .tolist() @@ -413,7 +418,7 @@ def get_patient(self, patient_id: str) -> Patient: if not isinstance(df, dd.DataFrame): raise TypeError("collected_global_event_df must be a Dask DataFrame") - patient_df = df[df["patient_id"] == patient_id] + patient_df = df.loc[patient_id] return Patient(patient_id=patient_id, data_source=patient_df) def iter_patients(self, df: Optional[dd.DataFrame] = None) -> Iterator[Patient]: @@ -434,7 +439,7 @@ def iter_patients(self, df: Optional[dd.DataFrame] = None) -> Iterator[Patient]: def stats(self) -> None: """Prints statistics about the dataset.""" df = self.collected_global_event_df - n_patients = df["patient_id"].nunique().compute() + n_patients = len(self.unique_patient_ids) n_events = df.shape[0].compute() print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index fb36a06a7..7b5c49269 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -205,7 +205,10 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: errors="coerce" )) labevents_dd: dd.DataFrame = labevents_dd[labevents_dd["storetime_filter"] <= np.datetime64(admission_dischtime)] - labevents_dd: dd.DataFrame = labevents_dd[["timestamp", "labevents/itemid", "labevents/valuenum"]].astype({"labevents/valuenum": "float64"}) + labevents_dd: dd.DataFrame = labevents_dd[["timestamp", "labevents/itemid", "labevents/valuenum"]] + labevents_dd["labevents/valuenum"] = dd.to_numeric( + labevents_dd["labevents/valuenum"], errors="coerce" + ) labevents_df: pd.DataFrame = labevents_dd.compute() if not labevents_df.empty: From 90dee139755a76ba67b4a08c491f0d9dec0e776a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 09:45:00 -0500 Subject: [PATCH 21/51] revert data back to polars as it it faster Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/data/data.py | 157 +++++++++++++++++++++--------------------- 1 file changed, 80 insertions(+), 77 deletions(-) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 7021eafaa..2d1113257 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -1,10 +1,12 @@ import operator from dataclasses import dataclass, field from datetime import datetime -from typing import Dict, List, Mapping, Optional, Union, Any, overload, Literal +from functools import reduce +from typing import Dict, List, Mapping, Optional, Union, Any -import dask.dataframe as dd import numpy as np +import polars as pl + @dataclass(frozen=True) class Event: @@ -52,7 +54,8 @@ def from_dict(cls, d: Dict[str, Any]) -> "Event": """Create an Event instance from a dictionary. Args: - d (Dict[str, Any]): Dictionary containing event data. + d (Dict[str, any]): Dictionary containing event data. + Returns: Event: An instance of the Event class. """ @@ -116,83 +119,56 @@ class Patient: Attributes: patient_id (str): Unique patient identifier. - data_source (dd.DataFrame): Dask DataFrame containing all events. + data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp. + event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. """ - def __init__(self, patient_id: str, data_source: dd.DataFrame) -> None: - """Initialize a Patient instance. + def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None: + """ + Initialize a Patient instance. Args: patient_id (str): Unique patient identifier. - data_source (dd.DataFrame): DataFrame containing all events. + data_source (pl.DataFrame): DataFrame containing all events. """ self.patient_id = patient_id - self.data_source = data_source + self.data_source = data_source.sort("timestamp") + self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True) - def _filter_by_time_range(self, df: dd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> dd.DataFrame: - """Filter events by time range using lazy Dask operations.""" + def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + """Regular filtering by time. Time complexity: O(n).""" if start is not None: - df = df[df["timestamp"] >= np.datetime64(start)] + df = df.filter(pl.col("timestamp") >= start) if end is not None: - df = df[df["timestamp"] <= np.datetime64(end)] - return df - - def _filter_by_event_type(self, df: dd.DataFrame, event_type: Optional[str]) -> dd.DataFrame: - """Filter by event type if provided.""" - if event_type: - df = df[df["event_type"] == event_type] + df = df.filter(pl.col("timestamp") <= end) return df - def _apply_attribute_filters( - self, df: dd.DataFrame, event_type: str, filters: List[tuple] - ) -> dd.DataFrame: - """Apply attribute-level filters to the DataFrame.""" - op_map = { - "==": operator.eq, - "!=": operator.ne, - "<": operator.lt, - "<=": operator.le, - ">": operator.gt, - ">=": operator.ge, - } - - for filt in filters: - if not (isinstance(filt, tuple) and len(filt) == 3): - raise ValueError( - f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" - ) - attr, op, val = filt - if op not in op_map: - raise ValueError(f"Unsupported operator: {op} in filter {filt}") - col_name = f"{event_type}/{attr}" - if col_name not in df.columns: - raise KeyError(f"Column '{col_name}' not found in dataset") - - df = df[op_map[op](df[col_name], val)] + def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n).""" + if start is None and end is None: + return df + df = df.filter(pl.col("timestamp").is_not_null()) + ts_col = df["timestamp"].to_numpy() + start_idx = 0 + end_idx = len(ts_col) + if start is not None: + start_idx = np.searchsorted(ts_col, start, side="left") + if end is not None: + end_idx = np.searchsorted(ts_col, end, side="right") + return df.slice(start_idx, end_idx - start_idx) + def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + """Regular filtering by event type. Time complexity: O(n).""" + if event_type: + df = df.filter(pl.col("event_type") == event_type) return df - @overload - def get_events( - self, - *, - event_type: Optional[str] = None, - start: Optional[datetime] = None, - end: Optional[datetime] = None, - filters: Optional[List[tuple]] = None, - return_df: Literal[True] - ) -> dd.DataFrame: ... - - @overload - def get_events( - self, - *, - event_type: Optional[str] = None, - start: Optional[datetime] = None, - end: Optional[datetime] = None, - filters: Optional[List[tuple]] = None, - return_df: Literal[False] - ) -> List[Event]: ... + def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + """Fast filtering by event type using pre-built event type index. Time complexity: O(1).""" + if event_type: + return self.event_type_partitions.get((event_type,), df[:0]) + else: + return df def get_events( self, @@ -201,33 +177,60 @@ def get_events( end: Optional[datetime] = None, filters: Optional[List[tuple]] = None, return_df: bool = False, - ) -> Union[dd.DataFrame, List[Event]]: + ) -> Union[pl.DataFrame, List[Event]]: """Get events with optional type and time filters. Args: event_type (Optional[str]): Type of events to filter. start (Optional[datetime]): Start time for filtering events. end (Optional[datetime]): End time for filtering events. - return_df (bool): Whether to return a pandas DataFrame or a list of + return_df (bool): Whether to return a DataFrame or a list of Event objects. filters (Optional[List[tuple]]): Additional filters as [(attr, op, value), ...], e.g.: [("attr1", "!=", "abnormal"), ("attr2", "!=", 1)]. Filters are applied after type and time filters. The logic is "AND" between different filters. Returns: - Union[dd.DataFrame, List[Event]]: Filtered events as a Dask DataFrame + Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame or a list of Event objects. """ - df = self._filter_by_event_type(self.data_source, event_type) - df = self._filter_by_time_range(df, start, end) + # faster filtering (by default) + df = self._filter_by_event_type_fast(self.data_source, event_type) + df = self._filter_by_time_range_fast(df, start, end) - active_filters = filters or [] - if active_filters: - assert event_type is not None, "event_type must be provided if filters are provided" - df = self._apply_attribute_filters(df, event_type, active_filters) + # regular filtering (commented out by default) + # df = self._filter_by_event_type_regular(self.data_source, event_type) + # df = self._filter_by_time_range_regular(df, start, end) + if filters: + assert event_type is not None, "event_type must be provided if filters are provided" + else: + filters = [] + exprs = [] + for filt in filters: + if not (isinstance(filt, tuple) and len(filt) == 3): + raise ValueError( + f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" + ) + attr, op, val = filt + col_expr = pl.col(f"{event_type}/{attr}") + # Build operator expression + if op == "==": + exprs.append(col_expr == val) + elif op == "!=": + exprs.append(col_expr != val) + elif op == "<": + exprs.append(col_expr < val) + elif op == "<=": + exprs.append(col_expr <= val) + elif op == ">": + exprs.append(col_expr > val) + elif op == ">=": + exprs.append(col_expr >= val) + else: + raise ValueError(f"Unsupported operator: {op} in filter {filt}") + if exprs: + df = df.filter(reduce(operator.and_, exprs)) if return_df: return df - # Dask DataFrames do not expose .to_dict on lazy expressions; compute to pandas first. - records = df.compute().to_dict("records") - return [Event.from_dict(d) for d in records] + return [Event.from_dict(d) for d in df.to_dicts()] \ No newline at end of file From d2faab989c10d5057cf51857f8c50e4b167040e1 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 09:46:25 -0500 Subject: [PATCH 22/51] Because patient has reverted Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- tests/core/test_patient.py | 69 -------------------------------------- 1 file changed, 69 deletions(-) delete mode 100644 tests/core/test_patient.py diff --git a/tests/core/test_patient.py b/tests/core/test_patient.py deleted file mode 100644 index a6d0cdc8a..000000000 --- a/tests/core/test_patient.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest -from datetime import datetime - -import dask.dataframe as dd -import pandas as pd - -from pyhealth.data import Patient - - -class TestPatientGetEvents(unittest.TestCase): - def setUp(self): - timestamps = [ - datetime(2021, 1, 1), - datetime(2021, 1, 5), - datetime(2021, 2, 1), - ] - pdf = pd.DataFrame( - { - "patient_id": ["p1", "p1", "p1"], - "event_type": ["labs", "labs", "visit"], - "timestamp": timestamps, - "labs/result": [1.0, 2.0, None], - "labs/unit": ["mg/dL", "mg/dL", None], - "visit/location": [None, None, "icu"], - } - ) - self.ddf = dd.from_pandas(pdf, npartitions=1) - self.patient = Patient(patient_id="p1", data_source=self.ddf) - - def test_returns_event_objects_by_default(self): - events = self.patient.get_events() - self.assertEqual(len(events), 3) - self.assertEqual( - sorted([e.event_type for e in events]), ["labs", "labs", "visit"] - ) - self.assertEqual(events[0].attr_dict["result"], 1.0) - - def test_return_df_flag(self): - labs_df = self.patient.get_events(event_type="labs", return_df=True) - labs_pdf = labs_df.compute() - self.assertEqual(len(labs_pdf), 2) - self.assertTrue((labs_pdf["event_type"] == "labs").all()) - - def test_time_range_filter(self): - start = datetime(2021, 1, 2) - end = datetime(2021, 1, 31) - events = self.patient.get_events(start=start, end=end) - self.assertEqual(len(events), 1) - self.assertEqual(events[0].timestamp, datetime(2021, 1, 5)) - - def test_event_type_and_attribute_filters(self): - filters = [("result", ">=", 2)] - events = self.patient.get_events(event_type="labs", filters=filters) - self.assertEqual(len(events), 1) - self.assertEqual(events[0].attr_dict["result"], 2.0) - - def test_filters_require_event_type(self): - with self.assertRaises(AssertionError): - self.patient.get_events(filters=[("result", "==", 1)]) - - def test_missing_column_in_filters_raises(self): - with self.assertRaises(KeyError): - self.patient.get_events( - event_type="labs", filters=[("does_not_exist", "==", 1)] - ) - - -if __name__ == "__main__": - unittest.main() From c1d9117d785d511cd5b1528f1fafc96ef65c8fbc Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 09:48:52 -0500 Subject: [PATCH 23/51] Revert task to use polars, as it's faster Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- .../mortality_prediction_stagenet_mimic4.py | 111 ++++++++++-------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index 7b5c49269..c23ee3225 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -1,13 +1,9 @@ from datetime import datetime from typing import Any, ClassVar, Dict, List, Type -import dask.dataframe as dd -import pandas as pd -import numpy as np import polars as pl from .base_task import BaseTask -from ..data.data import Patient, Event class MortalityPredictionStageNetMIMIC4(BaseTask): @@ -79,7 +75,7 @@ class MortalityPredictionStageNetMIMIC4(BaseTask): item for itemids in LAB_CATEGORIES.values() for item in itemids ] - def __call__(self, patient: Patient) -> List[Dict[str, Any]]: + def __call__(self, patient: Any) -> List[Dict[str, Any]]: """Process a patient to create mortality prediction samples. Creates ONE sample per patient with all admissions aggregated. @@ -93,12 +89,13 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: procedures, labs across visits, and final mortality label """ # Filter patients by age (>= 18) - demographics: List[Event] = patient.get_events(event_type="patients", return_df=False) + demographics = patient.get_events(event_type="patients") if not demographics: return [] + demographics = demographics[0] try: - anchor_age = int(demographics[0].anchor_age) + anchor_age = int(demographics.anchor_age) if anchor_age < 18: return [] except (ValueError, TypeError, AttributeError): @@ -106,7 +103,7 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: return [] # Get all admissions - admissions = patient.get_events(event_type="admissions", return_df=False) + admissions = patient.get_events(event_type="admissions") if len(admissions) < 1: return [] @@ -162,7 +159,6 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: diagnoses_icd = patient.get_events( event_type="diagnoses_icd", filters=[("hadm_id", "==", admission.hadm_id)], - return_df=False, ) visit_diagnoses = [ event.icd_code @@ -174,7 +170,6 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: procedures_icd = patient.get_events( event_type="procedures_icd", filters=[("hadm_id", "==", admission.hadm_id)], - return_df=False, ) visit_procedures = [ event.icd_code @@ -190,7 +185,7 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: all_icd_times.append(time_from_previous) # Get lab events for this admission - labevents_dd: dd.DataFrame = patient.get_events( + labevents_df = patient.get_events( event_type="labevents", start=admission_time, end=admission_dischtime, @@ -198,47 +193,63 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: ) # Filter to relevant lab items - labevents_dd: dd.DataFrame = labevents_dd[labevents_dd["labevents/itemid"].isin(self.LABITEMS)] - labevents_dd: dd.DataFrame = labevents_dd.assign(storetime_filter=dd.to_datetime( - labevents_dd["labevents/storetime"], - format="%Y-%m-%d %H:%M:%S", - errors="coerce" - )) - labevents_dd: dd.DataFrame = labevents_dd[labevents_dd["storetime_filter"] <= np.datetime64(admission_dischtime)] - labevents_dd: dd.DataFrame = labevents_dd[["timestamp", "labevents/itemid", "labevents/valuenum"]] - labevents_dd["labevents/valuenum"] = dd.to_numeric( - labevents_dd["labevents/valuenum"], errors="coerce" + labevents_df = labevents_df.filter( + pl.col("labevents/itemid").is_in(self.LABITEMS) ) - labevents_df: pd.DataFrame = labevents_dd.compute() - if not labevents_df.empty: - unique_timestamps = sorted(labevents_df["timestamp"].unique().tolist()) - for lab_ts in unique_timestamps: - # Get all lab events at this timestamp - ts_labs: pd.DataFrame = labevents_df[labevents_df["timestamp"] == lab_ts] - - # Create 10-dimensional vector (one per category) - lab_vector = [] - for category_name in self.LAB_CATEGORY_NAMES: - category_itemids = self.LAB_CATEGORIES[category_name] - - # Find first matching value for this category - category_value = None - for itemid in category_itemids: - matching = ts_labs[ts_labs["labevents/itemid"] == itemid] - if not matching.empty: - category_value = matching["labevents/valuenum"].iloc[0] - break - - lab_vector.append(category_value) - - # Calculate time from admission start (hours) - time_from_admission = ( - lab_ts - admission_time - ).total_seconds() / 3600.0 + # Parse storetime and filter + if labevents_df.height > 0: + labevents_df = labevents_df.with_columns( + pl.col("labevents/storetime").str.strptime( + pl.Datetime, "%Y-%m-%d %H:%M:%S" + ) + ) + labevents_df = labevents_df.filter( + pl.col("labevents/storetime") <= admission_dischtime + ) - all_lab_values.append(lab_vector) - all_lab_times.append(time_from_admission) + if labevents_df.height > 0: + # Select relevant columns + labevents_df = labevents_df.select( + pl.col("timestamp"), + pl.col("labevents/itemid"), + pl.col("labevents/valuenum").cast(pl.Float64), + ) + + # Group by timestamp and aggregate into 10D vectors + # For each timestamp, create vector of lab categories + unique_timestamps = sorted( + labevents_df["timestamp"].unique().to_list() + ) + + for lab_ts in unique_timestamps: + # Get all lab events at this timestamp + ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts) + + # Create 10-dimensional vector (one per category) + lab_vector = [] + for category_name in self.LAB_CATEGORY_NAMES: + category_itemids = self.LAB_CATEGORIES[category_name] + + # Find first matching value for this category + category_value = None + for itemid in category_itemids: + matching = ts_labs.filter( + pl.col("labevents/itemid") == itemid + ) + if matching.height > 0: + category_value = matching["labevents/valuenum"][0] + break + + lab_vector.append(category_value) + + # Calculate time from admission start (hours) + time_from_admission = ( + lab_ts - admission_time + ).total_seconds() / 3600.0 + + all_lab_values.append(lab_vector) + all_lab_times.append(time_from_admission) # Skip if no lab events (required for this task) if len(all_lab_values) == 0: @@ -262,4 +273,4 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: "labs": labs_data, "mortality": final_mortality, } - return [sample] + return [sample] \ No newline at end of file From 9905ddba55c69c2b67feb6853b05dd08209d8bb8 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 09:50:37 -0500 Subject: [PATCH 24/51] use pl.DataFrame in patient. Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d25738ed1..f8e2e7cde 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -418,7 +418,7 @@ def get_patient(self, patient_id: str) -> Patient: if not isinstance(df, dd.DataFrame): raise TypeError("collected_global_event_df must be a Dask DataFrame") - patient_df = df.loc[patient_id] + patient_df: pl.DataFrame = pl.from_pandas(df.loc[patient_id].compute()) return Patient(patient_id=patient_id, data_source=patient_df) def iter_patients(self, df: Optional[dd.DataFrame] = None) -> Iterator[Patient]: From 91143b6993f14536d032fc34d6fba40b9a07e8cc Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 10:04:23 -0500 Subject: [PATCH 25/51] Fix type conversion issues Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/data/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 2d1113257..c9b88b1a6 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -148,13 +148,13 @@ def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime] if start is None and end is None: return df df = df.filter(pl.col("timestamp").is_not_null()) - ts_col = df["timestamp"].to_numpy() + ts_col = df["timestamp"].dt.epoch("s").to_numpy() start_idx = 0 end_idx = len(ts_col) if start is not None: - start_idx = np.searchsorted(ts_col, start, side="left") + start_idx = np.searchsorted(ts_col, start.timestamp(), side="left") if end is not None: - end_idx = np.searchsorted(ts_col, end, side="right") + end_idx = np.searchsorted(ts_col, end.timestamp(), side="right") return df.slice(start_idx, end_idx - start_idx) def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: From 4091f7f27bd901b24aab676cfea97509947d7541 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 10:05:19 -0500 Subject: [PATCH 26/51] Add litdata Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b0501bc4c..356c12d96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "pandarallel~=1.6.5", "pydantic~=2.11.7", "dask[dataframe,distributed]~=2025.11.0", + "litdata~=0.2.58", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From 5a9424f87dd2de774b30a926a3545e0b2d333737 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 16:28:24 -0500 Subject: [PATCH 27/51] Fix Mimic4 Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/mimic4.py | 12 ++++-------- .../tasks/mortality_prediction_stagenet_mimic4.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index aca96f16c..e95c9934e 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -276,18 +276,12 @@ def __init__( ) self.setup_cache_dir(cache_dir=cache_dir, subfolder=subfolder) - # Combine data from all sub-datasets - log_memory_usage("Before combining data") - self.global_event_df = self._combine_data() - log_memory_usage("After combining data") - # Cache attributes - self._collected_global_event_df = None self._unique_patient_ids = None log_memory_usage("Completed MIMIC4Dataset init") - def _combine_data(self) -> dd.DataFrame: + def load_data(self) -> dd.DataFrame: """ Combines data from all initialized sub-datasets into a unified global event dataframe. @@ -298,8 +292,10 @@ def _combine_data(self) -> dd.DataFrame: # Collect global event dataframes from all sub-datasets for dataset_type, dataset in self.sub_datasets.items(): + dataset_type: str + dataset: BaseDataset logger.info(f"Combining data from {dataset_type} dataset") - frames.append(dataset.global_event_df) + frames.append(dataset.load_data()) # Concatenate all frames logger.info("Creating combined dataframe") diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index c23ee3225..25da630be 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -213,7 +213,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: labevents_df = labevents_df.select( pl.col("timestamp"), pl.col("labevents/itemid"), - pl.col("labevents/valuenum").cast(pl.Float64), + pl.col("labevents/valuenum").str.strip_chars().replace("", None).cast(pl.Float64), ) # Group by timestamp and aggregate into 10D vectors From 1a0b6a0e39c1ddd6255d53edeaa5e243aab3ff0c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 16:41:01 -0500 Subject: [PATCH 28/51] Works for single worker --- pyhealth/datasets/base_dataset.py | 356 +++++++++++++++--------------- 1 file changed, 177 insertions(+), 179 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index f8e2e7cde..afc523c48 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -1,15 +1,15 @@ import logging import os -import pickle from abc import ABC -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Any, Callable from urllib.parse import urlparse, urlunparse import uuid import json import functools import operator +from collections import namedtuple +import pickle import polars as pl import pandas as pd @@ -18,7 +18,6 @@ import pyarrow.csv as pv import pyarrow.parquet as pq import requests -from tqdm import tqdm import platformdirs from ..data import Patient @@ -26,7 +25,6 @@ from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config from .sample_dataset import SampleDataset -from .utils import _convert_for_cache, _restore_from_cache logger = logging.getLogger(__name__) @@ -65,6 +63,7 @@ def path_exists(path: str) -> bool: else: return Path(path).exists() + def alt_path(path: str) -> str: """ Get the alternative path by switching between .csv.gz, .csv, .tsv.gz, and .tsv extensions. @@ -86,6 +85,7 @@ def alt_path(path: str) -> str: else: raise ValueError(f"Path does not have expected extension: {path}") + def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: """ Scan a CSV.gz, CSV, TSV.gz, or TSV file and returns a LazyFrame. @@ -124,6 +124,17 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") +def _transform_fn(input: tuple[str, str, BaseTask]) -> Iterator[Dict[str, Any]]: + (patient_id, path, task) = input + patient = Patient( + patient_id=patient_id, + data_source=pl.read_parquet(path).filter(pl.col("patient_id") == patient_id), + ) + for sample in task(patient): + sample = {k: pickle.dumps(v) for k,v in sample.items()} + yield sample + + class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -169,36 +180,42 @@ def __init__( self.config = load_yaml_config(config_path) self.dev = dev - subfolder = self.cache_subfolder(self.root, self.tables, self.dataset_name, self.dev) + subfolder = self.cache_subfolder( + self.root, self.tables, self.dataset_name, self.dev + ) self.setup_cache_dir(cache_dir=cache_dir, subfolder=subfolder) logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" ) - self.global_event_df = self.load_data() - # Cached attributes - self._collected_global_event_df = None self._unique_patient_ids = None @staticmethod - def cache_subfolder(root: str, tables: List[str], dataset_name: str, dev: bool) -> str: - """Generates a unique identifier for the dataset instance. This is used for creating + def cache_subfolder( + root: str, tables: List[str], dataset_name: str, dev: bool + ) -> str: + """Generates a unique identifier for the dataset instance. This is used for creating cache directories. The UUID is based on the root path, tables, dataset name, and dev mode. Returns: str: A unique identifier string. """ - id_str = json.dumps({ - "root": root, - "tables": sorted(tables), - "dataset_name": dataset_name, - "dev": dev, - }, sort_keys=True) + id_str = json.dumps( + { + "root": root, + "tables": sorted(tables), + "dataset_name": dataset_name, + "dev": dev, + }, + sort_keys=True, + ) return str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) - def setup_cache_dir(self, cache_dir: str | Path | None = None, subfolder: str = str(uuid.uuid4())) -> None: + def setup_cache_dir( + self, cache_dir: str | Path | None = None, subfolder: str = str(uuid.uuid4()) + ) -> None: """Creates the cache directory structure. Args: @@ -207,42 +224,63 @@ def setup_cache_dir(self, cache_dir: str | Path | None = None, subfolder: str = subfolder (str): Subfolder name for this dataset instance's cache. """ if cache_dir is None: - cache_dir = platformdirs.user_cache_dir(appname='pyhealth') - logger.info(f"No cache_dir provided. Using default cache for PyHealth: {cache_dir}") + cache_dir = platformdirs.user_cache_dir(appname="pyhealth") + logger.info( + f"No cache_dir provided. Using default cache for PyHealth: {cache_dir}" + ) cache_dir = Path(cache_dir) self.cache_dir = cache_dir / subfolder + logger.info( + f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}" + ) + + def _table_cache(self, table_name: str) -> str: + """Generates the cache path for a specific table. - self.cache_dir.mkdir(parents=True, exist_ok=True) - # Create tables subdirectory to store cached table files + Args: + table_name (str): The name of the table. + + Returns: + str: The cache path for the table. + """ (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) - # Create global_event_df subdirectory to store cached global event dataframe - (self.cache_dir / "global_event_df").mkdir(parents=True, exist_ok=True) - - logger.info(f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}") + return str(self.cache_dir / "tables" / f"{table_name}.parquet") + + def _dataset_cache(self) -> str: + """Generates the cache path for the global event dataframe. + + Returns: + str: The cache path for the global event dataframe. + """ + return str(self.cache_dir / "global_event_df.parquet") + + def _task_cache(self, task_name: str) -> str: + """Generates the cache path for a specific task. + + Args: + task_name (str): The name of the task. + Returns: + str: The cache path for the task. + """ + return str(self.cache_dir / "tasks" / task_name) @property - def collected_global_event_df(self) -> dd.DataFrame: + def collected_global_event_df(self) -> pl.LazyFrame: """Collects and returns the global event data frame. Returns: dd.DataFrame: The collected global event data frame. """ - path = self.cache_dir / "global_event_df" / "cached.parquet" - if not path_exists(str(path)): + if not path_exists(self._dataset_cache()): + global_event_df = self.load_data() if self.dev: - patients = self.global_event_df["patient_id"].unique().head(1000).tolist() - filter = self.global_event_df["patient_id"].isin(patients) - self.global_event_df[filter].to_parquet(path) - else: - self.global_event_df.to_parquet(path) - - # This is imporant for fast fetch by patient_id - df = dd.read_parquet(str(path)) - df["index"] = df["patient_id"] - df = df.set_index("index") - - return df + patients = global_event_df["patient_id"].unique().head(1000).tolist() + filter = global_event_df["patient_id"].isin(patients) + global_event_df: dd.DataFrame = global_event_df[filter] + global_event_df = global_event_df.sort_values(by=["patient_id"]) + global_event_df.to_parquet(self._dataset_cache()) + return pl.scan_parquet(self._dataset_cache()) def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. @@ -291,12 +329,16 @@ def load_table(self, table_name: str) -> dd.DataFrame: other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df: dd.DataFrame = self.load_csv_or_tsv(f"{table_name}_join_{i}", other_csv_path) + join_df: dd.DataFrame = self.load_csv_or_tsv( + f"{table_name}_join_{i}", other_csv_path + ) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how - df: dd.DataFrame = df.merge(join_df[[join_key] + columns], on=join_key, how=how) + df: dd.DataFrame = df.merge( + join_df[[join_key] + columns], on=join_key, how=how + ) patient_id_col = table_cfg.patient_id timestamp_col = table_cfg.timestamp @@ -307,7 +349,9 @@ def load_table(self, table_name: str) -> dd.DataFrame: if timestamp_col: if isinstance(timestamp_col, list): # Concatenate all timestamp parts in order with no separator - timestamp_series: dd.Series = functools.reduce(operator.add, (df[col].astype(str) for col in timestamp_col)) + timestamp_series: dd.Series = functools.reduce( + operator.add, (df[col].astype(str) for col in timestamp_col) + ) else: # Single timestamp column timestamp_series: dd.Series = df[timestamp_col].astype(str) @@ -316,7 +360,9 @@ def load_table(self, table_name: str) -> dd.DataFrame: format=timestamp_format, errors="raise", ) - df: dd.DataFrame = df.assign(timestamp=timestamp_series.astype("datetime64[ms]")) + df: dd.DataFrame = df.assign( + timestamp=timestamp_series.astype("datetime64[ms]") + ) else: df: dd.DataFrame = df.assign(timestamp=pd.NaT) @@ -326,7 +372,7 @@ def load_table(self, table_name: str) -> dd.DataFrame: else: df: dd.DataFrame = df.reset_index(drop=True) df: dd.DataFrame = df.assign(patient_id=df.index.astype(str)) - + df: dd.DataFrame = df.assign(event_type=table_name) rename_attr = {attr: f"{table_name}/{attr}" for attr in attribute_cols} @@ -347,38 +393,44 @@ def load_csv_or_tsv(self, table_name: str, path: str) -> dd.DataFrame: Returns: dd.DataFrame: The loaded Dask DataFrame. """ - parquet_path = self.cache_dir / "tables" / f"{table_name}.parquet" - - if not path_exists(str(parquet_path)): + if not path_exists(self._table_cache(table_name)): # convert .gz file to .parquet file since Dask cannot split on gz files directly if not path_exists(path): if not path_exists(alt_path(path)): - raise FileNotFoundError(f"Neither path exists: {path} or {alt_path(path)}") + raise FileNotFoundError( + f"Neither path exists: {path} or {alt_path(path)}" + ) path = alt_path(path) - - delimiter = '\t' if path.endswith(".tsv") or path.endswith(".tsv.gz") else ',' + + delimiter = ( + "\t" if path.endswith(".tsv") or path.endswith(".tsv.gz") else "," + ) # Always infer schema as string to avoid incorrect type inference schema_reader = pv.open_csv( - path, - read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB - parse_options=pv.ParseOptions(delimiter=delimiter) + path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + ) + schema = pa.schema( + [pa.field(name, pa.string()) for name in schema_reader.schema.names] ) - schema = pa.schema([pa.field(name, pa.string()) for name in schema_reader.schema.names]) csv_reader = pv.open_csv( - path, - read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB - parse_options=pv.ParseOptions(delimiter=delimiter), - convert_options=pv.ConvertOptions(column_types=schema) + path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + convert_options=pv.ConvertOptions(column_types=schema), ) - with pq.ParquetWriter(parquet_path, csv_reader.schema) as writer: + with pq.ParquetWriter( + self._table_cache(table_name), csv_reader.schema + ) as writer: for batch in csv_reader: writer.write_batch(batch) pass return dd.read_parquet( - self.cache_dir / "tables" / f"{table_name}.parquet", - split_row_groups=True, # type: ignore + self._table_cache(table_name), + split_row_groups=True, # type: ignore blocksize="64MB", ) @@ -391,10 +443,11 @@ def unique_patient_ids(self) -> List[str]: """ if self._unique_patient_ids is None: self._unique_patient_ids = ( - self.collected_global_event_df.index + self.collected_global_event_df.select(pl.col("patient_id")) .unique() - .compute() - .tolist() + .collect() + .to_series() + .to_list() ) logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs") return self._unique_patient_ids @@ -414,18 +467,13 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - df = self.collected_global_event_df - if not isinstance(df, dd.DataFrame): - raise TypeError("collected_global_event_df must be a Dask DataFrame") - - patient_df: pl.DataFrame = pl.from_pandas(df.loc[patient_id].compute()) + patient_df: pl.DataFrame = self.collected_global_event_df.filter( + pl.col("patient_id") == patient_id + ).collect() return Patient(patient_id=patient_id, data_source=patient_df) - def iter_patients(self, df: Optional[dd.DataFrame] = None) -> Iterator[Patient]: - """Yields Patient objects for each unique patient in the dataset. - This method is inefficient, you should prefer to use - `self.colllected_global_event_df.groupby(("patient_id", )).apply(...)` directly - if possible. + def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: + """Yields Patient objects for each unique patient in the dataset. Yields: Iterator[Patient]: An iterator over Patient objects. @@ -440,12 +488,28 @@ def stats(self) -> None: """Prints statistics about the dataset.""" df = self.collected_global_event_df n_patients = len(self.unique_patient_ids) - n_events = df.shape[0].compute() + n_events = df.select(pl.count()).collect().item() print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") print(f"Number of patients: {n_patients}") print(f"Number of events: {n_events}") + def transform_fn(self, task: BaseTask) -> Callable[[str], Iterator[Dict[str, Any]]]: + ctx = namedtuple("ctx", ["data", "task"])( + data=self._dataset_cache(), + task=task, + ) + + def f(patient_id: str) -> Iterator[Dict[str, Any]]: + patient_df: pl.DataFrame = pl.read_parquet(ctx.data).filter( + pl.col("patient_id") == patient_id + ) + patient = Patient(patient_id=patient_id, data_source=patient_df) + for sample in ctx.task(patient): + yield sample + + return f + @property def default_task(self) -> Optional[BaseTask]: """Returns the default task for the dataset. @@ -459,7 +523,7 @@ def set_task( self, task: Optional[BaseTask] = None, num_workers: int = 1, - cache_dir: Optional[str] = None, + cache_dir: str | Path | None = None, cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, @@ -496,107 +560,41 @@ def set_task( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) - if cache_dir is not None: - logger.warning(f"This argument cache_dir is deprecated. Use dataset cache_dir instead.") if cache_format != "parquet": - logger.warning(f"Only 'parquet' cache_format is officially supported now.") - - # Check for cached data if cache_dir is provided - samples = None - if cache_dir is not None: - cache_filename = f"{task.task_name}.{cache_format}" - cache_path = Path(cache_dir) / cache_filename - if cache_path.exists(): - logger.info(f"Loading cached samples from {cache_path}") - try: - if cache_format == "parquet": - # Load samples from parquet file - cached_df = pl.read_parquet(cache_path) - samples = [ - _restore_from_cache(row) for row in cached_df.to_dicts() - ] - elif cache_format == "pickle": - # Load samples from pickle file - with open(cache_path, "rb") as f: - samples = pickle.load(f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Loaded {len(samples)} cached samples") - except Exception as e: - logger.warning( - "Failed to load cached data: %s. Regenerating...", - e, - ) - samples = None - - # Generate samples if not loaded from cache - if samples is None: - logger.info(f"Generating samples with {num_workers} worker(s)...") - filtered_global_event_df = task.pre_filter(self.collected_global_event_df) - samples = [] - - if num_workers == 1: - # single-threading (by default) - for patient in tqdm( - self.iter_patients(filtered_global_event_df), - total=filtered_global_event_df["patient_id"].n_unique(), - desc=(f"Generating samples for {task.task_name} " "with 1 worker"), - smoothing=0, - ): - samples.extend(task(patient)) - else: - # multi-threading (not recommended) - logger.info( - f"Generating samples for {task.task_name} with " - f"{num_workers} workers" - ) - patients = list(self.iter_patients(filtered_global_event_df)) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(task, patient) for patient in patients] - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=( - f"Collecting samples for {task.task_name} " - f"from {num_workers} workers" - ), - ): - samples.extend(future.result()) - - # Cache the samples if cache_dir is provided - if cache_dir is not None: - cache_path = Path(cache_dir) / cache_filename - cache_path.parent.mkdir(parents=True, exist_ok=True) - logger.info(f"Caching samples to {cache_path}") - try: - if cache_format == "parquet": - # Save samples as parquet file - samples_for_cache = [ - _convert_for_cache(sample) for sample in samples - ] - samples_df = pl.DataFrame(samples_for_cache) - samples_df.write_parquet(cache_path) - elif cache_format == "pickle": - # Save samples as pickle file - with open(cache_path, "wb") as f: - pickle.dump(samples, f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Successfully cached {len(samples)} samples") - except Exception as e: - logger.warning(f"Failed to cache samples: {e}") - - sample_dataset = SampleDataset( - samples, - input_schema=task.input_schema, - output_schema=task.output_schema, - dataset_name=self.dataset_name, - task_name=task, - input_processors=input_processors, - output_processors=output_processors, - ) + logger.warning( + f"This argument is no longer supported: cache_format={cache_format}" + ) + if cache_dir is None: + cache_dir = self._task_cache(task.task_name) + logger.info( + "No cache_dir provided. Using default task cache dir: %s", cache_dir + ) + + if not path_exists(str(cache_dir)): + import litdata as ld + + ld.optimize( + fn=_transform_fn, + inputs=[ + (patient_id, self._dataset_cache(), task) + for patient_id in self.unique_patient_ids + ], + output_dir=str(cache_dir), + num_workers=num_workers, + chunk_bytes="64MB", + ) + + sample_dataset = None + + # SampleDataset( + # samples, + # input_schema=task.input_schema, + # output_schema=task.output_schema, + # dataset_name=self.dataset_name, + # task_name=task, + # input_processors=input_processors, + # output_processors=output_processors, + # ) - logger.info(f"Generated {len(samples)} samples for task {task.task_name}") + # logger.info(f"Generated {len(samples)} samples for task {task.task_name}") return sample_dataset From f7ea6454c04d557f179b0f24f51988f60359f5aa Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 17:08:21 -0500 Subject: [PATCH 29/51] Change SampleDataset to IterableDataset --- examples/memtest.py | 99 ++++++++++++++++------- pyhealth/datasets/base_dataset.py | 33 ++++---- pyhealth/datasets/sample_dataset.py | 29 ++++--- pyhealth/processors/base_processor.py | 4 +- pyhealth/processors/label_processor.py | 4 +- pyhealth/processors/stagenet_processor.py | 18 ++--- 6 files changed, 120 insertions(+), 67 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 8a63090e8..56d3f748c 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -1,32 +1,75 @@ # %% -import psutil, os, time, threading -PEAK_MEM_USAGE = 0 -SELF_PROC = psutil.Process(os.getpid()) +import multiprocessing as mp +mp.set_start_method("spawn", force=True) -def track_mem(): - global PEAK_MEM_USAGE - while True: - m = SELF_PROC.memory_info().rss - if m > PEAK_MEM_USAGE: - PEAK_MEM_USAGE = m - time.sleep(0.1) +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +import torch -threading.Thread(target=track_mem, daemon=True).start() -print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB") +if __name__ == "__main__": + # STEP 1: Load MIMIC-IV base dataset + base_dataset = MIMIC4Dataset( + ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + ) -# %% -from pyhealth.datasets import MIMIC4Dataset -DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1" -dataset = MIMIC4Dataset( - ehr_root=DATASET_DIR, - ehr_tables=[ - "patients", - "admissions", - "diagnoses_icd", - "procedures_icd", - "prescriptions", - "labevents", - ], -) -print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB") -# %% + # STEP 2: Apply StageNet mortality prediction task + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=4, + cache_dir="../../mimic4_stagenet_cache", + ) + + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + sample = next(iter(sample_dataset)) + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f"ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") + + # Create dataloaders + train_loader = get_dataloader(sample_dataset, batch_size=256, shuffle=True) + + # STEP 4: Initialize StageNet model + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") + + # STEP 5: Train the model + trainer = Trainer( + model=model, + device="cuda:5", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=train_loader, + epochs=50, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, + ) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index afc523c48..5e722e9da 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -19,6 +19,7 @@ import pyarrow.parquet as pq import requests import platformdirs +from litdata.streaming import StreamingDataset from ..data import Patient from ..tasks import BaseTask @@ -123,6 +124,11 @@ def scan_file(file_path: str) -> pl.LazyFrame: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") +def _pickle(datum: dict[str, Any]) -> dict[str, bytes]: + return {k: pickle.dumps(v) for k,v in datum.items()} + +def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]: + return {k: pickle.loads(v) for k,v in datum.items()} def _transform_fn(input: tuple[str, str, BaseTask]) -> Iterator[Dict[str, Any]]: (patient_id, path, task) = input @@ -131,9 +137,8 @@ def _transform_fn(input: tuple[str, str, BaseTask]) -> Iterator[Dict[str, Any]]: data_source=pl.read_parquet(path).filter(pl.col("patient_id") == patient_id), ) for sample in task(patient): - sample = {k: pickle.dumps(v) for k,v in sample.items()} - yield sample - + # Schema is too complex to be handled by LitData, so we pickle the sample here + yield _pickle(sample) class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -584,17 +589,17 @@ def set_task( chunk_bytes="64MB", ) - sample_dataset = None + streaming_dataset = StreamingDataset(str(cache_dir), transform=_unpickle) - # SampleDataset( - # samples, - # input_schema=task.input_schema, - # output_schema=task.output_schema, - # dataset_name=self.dataset_name, - # task_name=task, - # input_processors=input_processors, - # output_processors=output_processors, - # ) + sample_dataset = SampleDataset( + streaming_dataset, + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore + dataset_name=self.dataset_name, + task_name=task.task_name, + input_processors=input_processors, + output_processors=output_processors, + ) - # logger.info(f"Generated {len(samples)} samples for task {task.task_name}") + logger.info(f"Generated {len(sample_dataset)} samples for task {task.task_name}") return sample_dataset diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 54c40420c..1ee4811b7 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,14 +1,15 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, Type +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type import inspect -from torch.utils.data import Dataset +from torch.utils.data import IterableDataset +from litdata.streaming import StreamingDataset from tqdm import tqdm from ..processors import get_processor from ..processors.base_processor import FeatureProcessor -class SampleDataset(Dataset): +class SampleDataset(IterableDataset): """Sample dataset class for handling and processing data samples. Attributes: @@ -23,7 +24,7 @@ class SampleDataset(Dataset): def __init__( self, - samples: List[Dict], + dataset: StreamingDataset, input_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], output_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], dataset_name: Optional[str] = None, @@ -56,7 +57,7 @@ def __init__( dataset_name = "" if task_name is None: task_name = "" - self.samples = samples + self.dataset = dataset self.input_schema = input_schema self.output_schema = output_schema self.input_processors = input_processors if input_processors is not None else {} @@ -69,7 +70,7 @@ def __init__( self.patient_to_index = {} self.record_to_index = {} - for i, sample in enumerate(samples): + for i, sample in enumerate(iter(self.dataset)): # Create patient_to_index mapping patient_id = sample.get("patient_id") if patient_id is not None: @@ -128,7 +129,7 @@ def validate(self) -> None: """Validates that the samples match the input and output schemas.""" input_keys = set(self.input_schema.keys()) output_keys = set(self.output_schema.keys()) - for s in self.samples: + for s in iter(self.dataset): assert input_keys.issubset(s.keys()), "Input schema does not match samples." assert output_keys.issubset(s.keys()), ( "Output schema does not match samples." @@ -141,13 +142,13 @@ def build(self) -> None: if not self.input_processors: for k, v in self.input_schema.items(): self.input_processors[k] = self._get_processor_instance(v) - self.input_processors[k].fit(self.samples, k) + self.input_processors[k].fit(iter(self.dataset), k) if not self.output_processors: for k, v in self.output_schema.items(): self.output_processors[k] = self._get_processor_instance(v) - self.output_processors[k].fit(self.samples, k) + self.output_processors[k].fit(iter(self.dataset), k) # Always process samples with the (fitted) processors - for sample in tqdm(self.samples, desc="Processing samples"): + for sample in tqdm(iter(self.dataset), desc="Processing samples"): for k, v in sample.items(): if k in self.input_processors: sample[k] = self.input_processors[k].process(v) @@ -155,6 +156,10 @@ def build(self) -> None: sample[k] = self.output_processors[k].process(v) return + def __iter__(self) -> Iterator: + # TODO: transform samples on the fly + return self.dataset.__iter__() + def __getitem__(self, index: int) -> Dict: """Returns a sample by index. @@ -166,7 +171,7 @@ def __getitem__(self, index: int) -> Dict: task-specific attributes as key. Conversion to index/tensor will be done in the model. """ - return self.samples[index] + return self.dataset.__getitem__(index) def __str__(self) -> str: """Returns a string representation of the dataset. @@ -182,4 +187,4 @@ def __len__(self) -> int: Returns: int: The number of samples. """ - return len(self.samples) + return self.dataset.__len__() diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 050cb5357..d207f0220 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Iterator class Processor(ABC): @@ -33,7 +33,7 @@ class FeatureProcessor(Processor): Example: Tokenization, image loading, normalization. """ - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: """Fit the processor to the samples. Args: diff --git a/pyhealth/processors/label_processor.py b/pyhealth/processors/label_processor.py index ad2df1897..a9f9937d9 100644 --- a/pyhealth/processors/label_processor.py +++ b/pyhealth/processors/label_processor.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List +from typing import Any, Dict, Iterator, List import torch @@ -19,7 +19,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: all_labels = set([sample[field] for sample in samples]) if len(all_labels) != 2: raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}") diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index cbbafac94..420163f51 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -61,7 +61,7 @@ def __init__(self, padding: int = 0): self._max_nested_len = None # Max inner sequence length for nested codes self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterator[dict[str, Any]], field: str) -> None: """Build vocabulary and determine input structure. Args: @@ -70,9 +70,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for codes if isinstance(value_data, list) and len(value_data) > 0: @@ -90,9 +90,9 @@ def fit(self, samples: List[Dict], key: str) -> None: # Build vocabulary for codes and find max nested length max_inner_len = 0 for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] if self._is_nested: # Nested codes @@ -256,7 +256,7 @@ def __init__(self): self._size = None # Feature dimension (set during fit) self._is_nested = None - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterator[dict[str, Any]], field: str) -> None: """Determine input structure. Args: @@ -265,9 +265,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for numerics if isinstance(value_data, list) and len(value_data) > 0: From a71218e4673745ece8b078e428f27d82be5c4d92 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 19:41:36 -0500 Subject: [PATCH 30/51] Distributed Progress Bar, Bucekt partition Carshing Co-authored-by: John Wu <54558896+jhnwu3@users.noreply.github.com> --- pyhealth/datasets/base_dataset.py | 47 +++++++++++++++++++------------ pyproject.toml | 3 +- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5e722e9da..bab3cda6d 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -11,6 +11,7 @@ from collections import namedtuple import pickle +from distributed import get_client import polars as pl import pandas as pd import dask.dataframe as dd @@ -20,6 +21,8 @@ import requests import platformdirs from litdata.streaming import StreamingDataset +from dask.distributed import progress +import xxhash from ..data import Patient from ..tasks import BaseTask @@ -132,6 +135,10 @@ def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]: def _transform_fn(input: tuple[str, str, BaseTask]) -> Iterator[Dict[str, Any]]: (patient_id, path, task) = input + with open(f"{path}/index.json", "rb") as f: + n_partitions = json.load(f)["n_partitions"] + bucket = xxhash.xxh64_intdigest(patient_id) % n_partitions + path = f"{path}/bucket={bucket}" patient = Patient( patient_id=patient_id, data_source=pl.read_parquet(path).filter(pl.col("patient_id") == patient_id), @@ -283,8 +290,28 @@ def collected_global_event_df(self) -> pl.LazyFrame: patients = global_event_df["patient_id"].unique().head(1000).tolist() filter = global_event_df["patient_id"].isin(patients) global_event_df: dd.DataFrame = global_event_df[filter] - global_event_df = global_event_df.sort_values(by=["patient_id"]) - global_event_df.to_parquet(self._dataset_cache()) + + mem_usage = global_event_df.memory_usage(deep=False).compute().sum() + n_partitions = mem_usage // (256 * 1024 * 1024) + 1 + bucket = global_event_df["patient_id"].apply( + xxhash.xxh64_intdigest, meta=("patient_id", "int") + ) % n_partitions + global_event_df = global_event_df.assign(bucket=bucket) + + logger.info(f"Estimated full global event dataframe size {mem_usage / (1024**3):.2f} GB") + logger.info(f"Repartitioning global event dataframe into {n_partitions} partitions for caching.") + + client = get_client() + handle = global_event_df.to_parquet( + self._dataset_cache(), + partition_on=["bucket"], + write_index=False, + compute=False, + ) + future = client.compute(handle) + progress(future) + with open(self._dataset_cache() + "/index.json", "w") as future: + json.dump({"n_partitions": int(n_partitions)}, future) return pl.scan_parquet(self._dataset_cache()) def load_data(self) -> dd.DataFrame: @@ -499,22 +526,6 @@ def stats(self) -> None: print(f"Number of patients: {n_patients}") print(f"Number of events: {n_events}") - def transform_fn(self, task: BaseTask) -> Callable[[str], Iterator[Dict[str, Any]]]: - ctx = namedtuple("ctx", ["data", "task"])( - data=self._dataset_cache(), - task=task, - ) - - def f(patient_id: str) -> Iterator[Dict[str, Any]]: - patient_df: pl.DataFrame = pl.read_parquet(ctx.data).filter( - pl.col("patient_id") == patient_id - ) - patient = Patient(patient_id=patient_id, data_source=patient_df) - for sample in ctx.task(patient): - yield sample - - return f - @property def default_task(self) -> Optional[BaseTask]: """Returns the default task for the dataset. diff --git a/pyproject.toml b/pyproject.toml index 356c12d96..9c3ac0f7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,8 +39,9 @@ dependencies = [ "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", - "dask[dataframe,distributed]~=2025.11.0", + "dask[complete]~=2025.11.0", "litdata~=0.2.58", + "xxhash~=3.6.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From 4c6aec8c094506194e1daa59bceace2605eb87fd Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 20:09:02 -0500 Subject: [PATCH 31/51] Better cache system --- pyhealth/datasets/base_dataset.py | 228 +++++++++++++----------------- 1 file changed, 99 insertions(+), 129 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index bab3cda6d..b8e27978c 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -90,59 +90,27 @@ def alt_path(path: str) -> str: raise ValueError(f"Path does not have expected extension: {path}") -def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: - """ - Scan a CSV.gz, CSV, TSV.gz, or TSV file and returns a LazyFrame. - It will fall back to the other extension if not found. - - Args: - path (str): URL or local path to a .csv, .csv.gz, .tsv, or .tsv.gz file - - Returns: - pl.LazyFrame: The LazyFrame for the CSV.gz, CSV, TSV.gz, or TSV file. - """ - - def scan_file(file_path: str) -> pl.LazyFrame: - separator = "\t" if ".tsv" in file_path else "," - return pl.scan_csv(file_path, separator=separator, infer_schema=False) - - if path_exists(path): - return scan_file(path) - - # Try the alternative extension - if path.endswith(".csv.gz"): - alt_path = path[:-3] # Remove .gz -> try .csv - elif path.endswith(".csv"): - alt_path = f"{path}.gz" # Add .gz -> try .csv.gz - elif path.endswith(".tsv.gz"): - alt_path = path[:-3] # Remove .gz -> try .tsv - elif path.endswith(".tsv"): - alt_path = f"{path}.gz" # Add .gz -> try .tsv.gz - else: - raise FileNotFoundError(f"Path does not have expected extension: {path}") - - if path_exists(alt_path): - logger.info(f"Original path does not exist. Using alternative: {alt_path}") - return scan_file(alt_path) - - raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") - def _pickle(datum: dict[str, Any]) -> dict[str, bytes]: - return {k: pickle.dumps(v) for k,v in datum.items()} + return {k: pickle.dumps(v) for k, v in datum.items()} + def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]: - return {k: pickle.loads(v) for k,v in datum.items()} + return {k: pickle.loads(v) for k, v in datum.items()} -def _transform_fn(input: tuple[str, str, BaseTask]) -> Iterator[Dict[str, Any]]: - (patient_id, path, task) = input - with open(f"{path}/index.json", "rb") as f: +def _get_patient(merged_cache: str, patient_id: str) -> Patient: + with open(merged_cache + "/index.json", "rb") as f: n_partitions = json.load(f)["n_partitions"] bucket = xxhash.xxh64_intdigest(patient_id) % n_partitions - path = f"{path}/bucket={bucket}" + path = f"{merged_cache}/bucket={bucket}" patient = Patient( patient_id=patient_id, data_source=pl.read_parquet(path).filter(pl.col("patient_id") == patient_id), ) + return patient + +def _transform_fn(input: tuple[str, str, BaseTask]) -> Iterator[Dict[str, Any]]: + (patient_id, path, task) = input + patient = _get_patient(path, patient_id) for sample in task(patient): # Schema is too complex to be handled by LitData, so we pickle the sample here yield _pickle(sample) @@ -201,9 +169,6 @@ def __init__( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" ) - # Cached attributes - self._unique_patient_ids = None - @staticmethod def cache_subfolder( root: str, tables: List[str], dataset_name: str, dev: bool @@ -246,26 +211,6 @@ def setup_cache_dir( f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}" ) - def _table_cache(self, table_name: str) -> str: - """Generates the cache path for a specific table. - - Args: - table_name (str): The name of the table. - - Returns: - str: The cache path for the table. - """ - (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) - return str(self.cache_dir / "tables" / f"{table_name}.parquet") - - def _dataset_cache(self) -> str: - """Generates the cache path for the global event dataframe. - - Returns: - str: The cache path for the global event dataframe. - """ - return str(self.cache_dir / "global_event_df.parquet") - def _task_cache(self, task_name: str) -> str: """Generates the cache path for a specific task. @@ -276,43 +221,61 @@ def _task_cache(self, task_name: str) -> str: """ return str(self.cache_dir / "tasks" / task_name) - @property - def collected_global_event_df(self) -> pl.LazyFrame: + def _merged_cache(self) -> str: """Collects and returns the global event data frame. Returns: dd.DataFrame: The collected global event data frame. """ + ret_path = str(self.cache_dir / "global_event_df.parquet") - if not path_exists(self._dataset_cache()): + if not path_exists(ret_path): global_event_df = self.load_data() + # In dev mode, limit to 1000 patients if self.dev: patients = global_event_df["patient_id"].unique().head(1000).tolist() filter = global_event_df["patient_id"].isin(patients) global_event_df: dd.DataFrame = global_event_df[filter] + # Collect unique patient IDs + patients = global_event_df["patient_id"].unique().compute().tolist() + n_patients = len(patients) + n_events = global_event_df.shape[0].compute() + logger.info(f"Collected {n_events} events for {n_patients} patients.") + + # Estimate memory usage and partitioning mem_usage = global_event_df.memory_usage(deep=False).compute().sum() - n_partitions = mem_usage // (256 * 1024 * 1024) + 1 - bucket = global_event_df["patient_id"].apply( - xxhash.xxh64_intdigest, meta=("patient_id", "int") - ) % n_partitions + n_partitions = mem_usage // (256 * 1024 * 1024) + 1 # 256 MB per partition + bucket = ( + global_event_df["patient_id"].apply( + xxhash.xxh64_intdigest, meta=("patient_id", "int") + ) + % n_partitions + ) global_event_df = global_event_df.assign(bucket=bucket) - - logger.info(f"Estimated full global event dataframe size {mem_usage / (1024**3):.2f} GB") - logger.info(f"Repartitioning global event dataframe into {n_partitions} partitions for caching.") + logger.info( + f"Estimated size {mem_usage / (1024**3):.2f} GB, write to {n_partitions} partitions." + ) client = get_client() handle = global_event_df.to_parquet( - self._dataset_cache(), + ret_path, partition_on=["bucket"], write_index=False, compute=False, ) future = client.compute(handle) progress(future) - with open(self._dataset_cache() + "/index.json", "w") as future: - json.dump({"n_partitions": int(n_partitions)}, future) - return pl.scan_parquet(self._dataset_cache()) + with open(ret_path + "/index.json", "w") as future: + json.dump({ + "n_partitions": int(n_partitions), + "n_patients": int(n_patients), + "n_events": int(n_events), + "patients": patients, + "algorithm": "xxhash64_modulo_partitioning", + }, future) + + return ret_path def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. @@ -344,7 +307,11 @@ def load_table(self, table_name: str) -> dd.DataFrame: csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") - df: dd.DataFrame = self.load_csv_or_tsv(table_name, csv_path) + df: dd.DataFrame = dd.read_parquet( + self._table_cache(table_name, source_path=csv_path), + split_row_groups=True, # type: ignore + blocksize="64MB", + ) # Check if there is a preprocessing function for this table # TODO: we need to update the preprocess function to work with Dask DataFrame @@ -361,8 +328,10 @@ def load_table(self, table_name: str) -> dd.DataFrame: other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df: dd.DataFrame = self.load_csv_or_tsv( - f"{table_name}_join_{i}", other_csv_path + join_df: dd.DataFrame = dd.read_parquet( + self._table_cache(f"{table_name}_join_{i}", other_csv_path), + split_row_groups=True, # type: ignore + blocksize="64MB", ) join_key = join_cfg.on columns = join_cfg.columns @@ -416,55 +385,65 @@ def load_table(self, table_name: str) -> dd.DataFrame: return event_frame - def load_csv_or_tsv(self, table_name: str, path: str) -> dd.DataFrame: - """Loads a CSV.gz, CSV, TSV.gz, or TSV file into a Dask DataFrame. + def _table_cache(self, table_name: str, source_path: str | None = None) -> str: + """Generates the cache path for a specific table. If the cached Parquet file does not exist, + it will convert the source CSV/TSV file to Parquet and save it to the cache. Args: table_name (str): The name of the table. - path (str): The URL or local path to the .csv, .csv.gz, .tsv, or .tsv.gz file. + source_path (str | None): The source CSV/TSV file path. If None, it assumes the + Parquet file already exists in the cache. + Returns: - dd.DataFrame: The loaded Dask DataFrame. + str: The cache path for the table. """ - if not path_exists(self._table_cache(table_name)): - # convert .gz file to .parquet file since Dask cannot split on gz files directly - if not path_exists(path): - if not path_exists(alt_path(path)): + # Ensure the tables cache directory exists + (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) + ret_path = str(self.cache_dir / "tables" / f"{table_name}.parquet") + + if not path_exists(ret_path): + if source_path is None: + raise FileNotFoundError( + f"Table {table_name} not found in cache and no source_path provided." + ) + + # Check if source_path exists, else try alternative path + if not path_exists(source_path): + if not path_exists(alt_path(source_path)): raise FileNotFoundError( - f"Neither path exists: {path} or {alt_path(path)}" + f"Neither path exists: {source_path} or {alt_path(source_path)}" ) - path = alt_path(path) + source_path = alt_path(source_path) + # Determine delimiter based on file extension delimiter = ( - "\t" if path.endswith(".tsv") or path.endswith(".tsv.gz") else "," + "\t" + if source_path.endswith(".tsv") or source_path.endswith(".tsv.gz") + else "," ) # Always infer schema as string to avoid incorrect type inference schema_reader = pv.open_csv( - path, + source_path, read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB parse_options=pv.ParseOptions(delimiter=delimiter), ) schema = pa.schema( [pa.field(name, pa.string()) for name in schema_reader.schema.names] ) + + # Convert CSV/TSV to Parquet csv_reader = pv.open_csv( - path, + source_path, read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB parse_options=pv.ParseOptions(delimiter=delimiter), convert_options=pv.ConvertOptions(column_types=schema), ) - with pq.ParquetWriter( - self._table_cache(table_name), csv_reader.schema - ) as writer: + with pq.ParquetWriter(ret_path, csv_reader.schema) as writer: for batch in csv_reader: writer.write_batch(batch) - pass - return dd.read_parquet( - self._table_cache(table_name), - split_row_groups=True, # type: ignore - blocksize="64MB", - ) + return ret_path @property def unique_patient_ids(self) -> List[str]: @@ -473,16 +452,9 @@ def unique_patient_ids(self) -> List[str]: Returns: List[str]: List of unique patient IDs. """ - if self._unique_patient_ids is None: - self._unique_patient_ids = ( - self.collected_global_event_df.select(pl.col("patient_id")) - .unique() - .collect() - .to_series() - .to_list() - ) - logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs") - return self._unique_patient_ids + with open(self._merged_cache() + "/index.json", "r") as f: + index_info = json.load(f) + return index_info["patients"] def get_patient(self, patient_id: str) -> Patient: """Retrieves a Patient object for the given patient ID. @@ -499,10 +471,7 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - patient_df: pl.DataFrame = self.collected_global_event_df.filter( - pl.col("patient_id") == patient_id - ).collect() - return Patient(patient_id=patient_id, data_source=patient_df) + return _get_patient(self._merged_cache(), patient_id) def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: """Yields Patient objects for each unique patient in the dataset. @@ -510,17 +479,16 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: Yields: Iterator[Patient]: An iterator over Patient objects. """ - if df is None: - df = self.collected_global_event_df - for patitent_id in self.unique_patient_ids: yield self.get_patient(patitent_id) def stats(self) -> None: """Prints statistics about the dataset.""" - df = self.collected_global_event_df - n_patients = len(self.unique_patient_ids) - n_events = df.select(pl.count()).collect().item() + with open(self._merged_cache() + "/index.json", "r") as f: + index_info = json.load(f) + n_patients = index_info["n_patients"] + n_events = index_info["n_events"] + print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") print(f"Number of patients: {n_patients}") @@ -592,7 +560,7 @@ def set_task( ld.optimize( fn=_transform_fn, inputs=[ - (patient_id, self._dataset_cache(), task) + (patient_id, self._merged_cache(), task) for patient_id in self.unique_patient_ids ], output_dir=str(cache_dir), @@ -604,13 +572,15 @@ def set_task( sample_dataset = SampleDataset( streaming_dataset, - input_schema=task.input_schema, # type: ignore - output_schema=task.output_schema, # type: ignore + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore dataset_name=self.dataset_name, task_name=task.task_name, input_processors=input_processors, output_processors=output_processors, ) - logger.info(f"Generated {len(sample_dataset)} samples for task {task.task_name}") + logger.info( + f"Generated {len(sample_dataset)} samples for task {task.task_name}" + ) return sample_dataset From 741f9a66b255294254134e9d272353fdf3c27068 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 20:32:51 -0500 Subject: [PATCH 32/51] Better apply task --- examples/memtest.py | 14 ++++++- pyhealth/datasets/base_dataset.py | 68 +++++++++++++++++++------------ 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 56d3f748c..1cfe79079 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -11,8 +11,18 @@ from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 from pyhealth.trainer import Trainer import torch +import dask.config +from dask.distributed import Client, LocalCluster if __name__ == "__main__": + dask.config.set({"temporary-directory": "/mnt/tmpfs/"}) + cluster = LocalCluster( + n_workers=4, + threads_per_worker=1, + memory_limit="8GB", + ) + client = Client(cluster) + # STEP 1: Load MIMIC-IV base dataset base_dataset = MIMIC4Dataset( ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", @@ -25,11 +35,11 @@ ], ) + print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") + # STEP 2: Apply StageNet mortality prediction task sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), - num_workers=4, - cache_dir="../../mimic4_stagenet_cache", ) print(f"Total samples: {len(sample_dataset)}") diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index b8e27978c..97639bd9e 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -97,10 +97,16 @@ def _pickle(datum: dict[str, Any]) -> dict[str, bytes]: def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]: return {k: pickle.loads(v) for k, v in datum.items()} + +def _patient_bucket(patient_id: str, n_partitions: int) -> int: + bucket = xxhash.xxh64_intdigest(patient_id) % n_partitions + return bucket + + def _get_patient(merged_cache: str, patient_id: str) -> Patient: with open(merged_cache + "/index.json", "rb") as f: n_partitions = json.load(f)["n_partitions"] - bucket = xxhash.xxh64_intdigest(patient_id) % n_partitions + bucket = _patient_bucket(patient_id, n_partitions) path = f"{merged_cache}/bucket={bucket}" patient = Patient( patient_id=patient_id, @@ -108,12 +114,21 @@ def _get_patient(merged_cache: str, patient_id: str) -> Patient: ) return patient -def _transform_fn(input: tuple[str, str, BaseTask]) -> Iterator[Dict[str, Any]]: - (patient_id, path, task) = input - patient = _get_patient(path, patient_id) - for sample in task(patient): - # Schema is too complex to be handled by LitData, so we pickle the sample here - yield _pickle(sample) + +def _transform_fn( + input: tuple[int, str, BaseTask], +) -> Iterator[Dict[str, Any]]: + (bucket_id, merged_cache, task) = input + path = f"{merged_cache}/bucket={bucket_id}" + # This is more efficient than reading patient by patient + grouped = pl.read_parquet(path).group_by("patient_id") + + for patient_id, patient_df in grouped: + patient = Patient(patient_id=str(patient_id[0]), data_source=patient_df) + for sample in task(patient): + # Schema is too complex to be handled by LitData, so we pickle the sample here + yield _pickle(sample) + class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -245,12 +260,10 @@ def _merged_cache(self) -> str: # Estimate memory usage and partitioning mem_usage = global_event_df.memory_usage(deep=False).compute().sum() - n_partitions = mem_usage // (256 * 1024 * 1024) + 1 # 256 MB per partition - bucket = ( - global_event_df["patient_id"].apply( - xxhash.xxh64_intdigest, meta=("patient_id", "int") - ) - % n_partitions + n_partitions = mem_usage // (256 * 1024 * 1024) + 1 # 256 MB per partition + bucket = global_event_df["patient_id"].apply( + lambda pid: _patient_bucket(pid, n_partitions), + meta=("patient_id", "int"), ) global_event_df = global_event_df.assign(bucket=bucket) logger.info( @@ -267,14 +280,17 @@ def _merged_cache(self) -> str: future = client.compute(handle) progress(future) with open(ret_path + "/index.json", "w") as future: - json.dump({ - "n_partitions": int(n_partitions), - "n_patients": int(n_patients), - "n_events": int(n_events), - "patients": patients, - "algorithm": "xxhash64_modulo_partitioning", - }, future) - + json.dump( + { + "n_partitions": int(n_partitions), + "n_patients": int(n_patients), + "n_events": int(n_events), + "patients": patients, + "algorithm": "xxhash64_modulo_partitioning", + }, + future, + ) + return ret_path def load_data(self) -> dd.DataFrame: @@ -557,12 +573,14 @@ def set_task( if not path_exists(str(cache_dir)): import litdata as ld + with open(self._merged_cache() + "/index.json", "r") as f: + index_info = json.load(f) + n_partitions = index_info["n_partitions"] + inputs = [(i, self._merged_cache(), task) for i in range(n_partitions)] + ld.optimize( fn=_transform_fn, - inputs=[ - (patient_id, self._merged_cache(), task) - for patient_id in self.unique_patient_ids - ], + inputs=inputs, output_dir=str(cache_dir), num_workers=num_workers, chunk_bytes="64MB", From 525c1217bde0c06ac814122d8518a015e1152dee Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 21:05:27 -0500 Subject: [PATCH 33/51] Fix bug --- pyhealth/datasets/base_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 97639bd9e..d82782839 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -99,7 +99,7 @@ def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]: def _patient_bucket(patient_id: str, n_partitions: int) -> int: - bucket = xxhash.xxh64_intdigest(patient_id) % n_partitions + bucket = int(xxhash.xxh64_intdigest(patient_id) % n_partitions) return bucket @@ -234,6 +234,7 @@ def _task_cache(self, task_name: str) -> str: Returns: str: The cache path for the task. """ + (self.cache_dir / "tasks").mkdir(parents=True, exist_ok=True) return str(self.cache_dir / "tasks" / task_name) def _merged_cache(self) -> str: From 791bbda91596597178472849697f1d199a221ce3 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 21:21:25 -0500 Subject: [PATCH 34/51] Fixup --- pyhealth/datasets/base_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d82782839..f8b469cd5 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -574,6 +574,8 @@ def set_task( if not path_exists(str(cache_dir)): import litdata as ld + get_client().close() # Close existing client to avoid conflicts with LitData + with open(self._merged_cache() + "/index.json", "r") as f: index_info = json.load(f) n_partitions = index_info["n_partitions"] From b51cc26d073e3ad4ef5e1b591c96d5aa7e7ebbcc Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 21:26:19 -0500 Subject: [PATCH 35/51] Fixup --- pyhealth/datasets/base_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index f8b469cd5..c1f3dd70a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -291,6 +291,8 @@ def _merged_cache(self) -> str: }, future, ) + + client.shutdown() return ret_path @@ -574,8 +576,6 @@ def set_task( if not path_exists(str(cache_dir)): import litdata as ld - get_client().close() # Close existing client to avoid conflicts with LitData - with open(self._merged_cache() + "/index.json", "r") as f: index_info = json.load(f) n_partitions = index_info["n_partitions"] From 54e201813cc0c7a764ff0217555f424f2db5ac4e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 21:27:34 -0500 Subject: [PATCH 36/51] Fixup --- pyhealth/datasets/base_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index c1f3dd70a..0cdcd2e41 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -291,9 +291,9 @@ def _merged_cache(self) -> str: }, future, ) - - client.shutdown() - + + # Ensure the Dask client is properly closed to avoid resource leaks + get_client().shutdown() return ret_path def load_data(self) -> dd.DataFrame: From 44243593e440d1c21a96298a7ac46942d96442d1 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Fri, 21 Nov 2025 21:42:38 -0500 Subject: [PATCH 37/51] Fixup --- examples/memtest.py | 38 ++++++++++++++--------------- pyhealth/datasets/base_dataset.py | 4 +-- pyhealth/datasets/sample_dataset.py | 13 ++++++++++ pyhealth/datasets/utils.py | 5 ++-- 4 files changed, 35 insertions(+), 25 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 1cfe79079..d0772ee39 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -1,7 +1,4 @@ # %% -import multiprocessing as mp -mp.set_start_method("spawn", force=True) - from pyhealth.datasets import ( MIMIC4Dataset, get_dataloader, @@ -16,32 +13,33 @@ if __name__ == "__main__": dask.config.set({"temporary-directory": "/mnt/tmpfs/"}) - cluster = LocalCluster( - n_workers=4, + with LocalCluster( + n_workers=16, threads_per_worker=1, memory_limit="8GB", - ) - client = Client(cluster) - - # STEP 1: Load MIMIC-IV base dataset - base_dataset = MIMIC4Dataset( - ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", - ehr_tables=[ - "patients", - "admissions", - "diagnoses_icd", - "procedures_icd", - "labevents", - ], - ) + ) as cluster: + with Client(cluster) as client: + # STEP 1: Load MIMIC-IV base dataset + base_dataset = MIMIC4Dataset( + ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + dev=True + ) + base_dataset._merged_cache() print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") # STEP 2: Apply StageNet mortality prediction task sample_dataset = base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), + num_workers=4, ) - print(f"Total samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") print(f"Output schema: {sample_dataset.output_schema}") diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 0cdcd2e41..d82782839 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -291,9 +291,7 @@ def _merged_cache(self) -> str: }, future, ) - - # Ensure the Dask client is properly closed to avoid resource leaks - get_client().shutdown() + return ret_path def load_data(self) -> dd.DataFrame: diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 1ee4811b7..d4e098926 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -88,6 +88,19 @@ def __init__( self.validate() self.build() + def set_shuffle(self, shuffle: bool) -> None: + """Sets whether to shuffle the dataset. + + Args: + shuffle (bool): Whether to shuffle the dataset. + """ + if hasattr(self.dataset, "set_shuffle"): + self.dataset.set_shuffle(shuffle) + else: + raise NotImplementedError( + "Shuffle is not implemented for this dataset type." + ) + def _get_processor_instance(self, processor_spec): """Get processor instance from either string alias, class reference, processor instance, or tuple with kwargs. diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 63ca4152a..33878a8bc 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -10,6 +10,7 @@ from torch.utils.data import DataLoader from pyhealth import BASE_CACHE_PATH +from pyhealth.datasets.sample_dataset import SampleDataset from pyhealth.utils import create_directory MODULE_CACHE_PATH = os.path.join(BASE_CACHE_PATH, "datasets") @@ -319,7 +320,7 @@ def collate_fn_dict_with_padding(batch: List[dict]) -> dict: def get_dataloader( - dataset: torch.utils.data.Dataset, batch_size: int, shuffle: bool = False + dataset: SampleDataset, batch_size: int, shuffle: bool = False ) -> DataLoader: """Creates a DataLoader for a given dataset. @@ -331,10 +332,10 @@ def get_dataloader( Returns: A DataLoader instance for the dataset. """ + dataset.set_shuffle(shuffle) dataloader = DataLoader( dataset, batch_size=batch_size, - shuffle=shuffle, collate_fn=collate_fn_dict_with_padding, ) From c14af095471426b8f6ed296551db4a0cf222f57e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 22 Nov 2025 20:04:37 -0500 Subject: [PATCH 38/51] Move actual compute ctor Co-authored-by: John Wu <54558896+jhnwu3 --- examples/memtest.py | 115 ++++++------ pyhealth/datasets/base_dataset.py | 298 ++++++++++++++++-------------- pyhealth/datasets/mimic4.py | 52 +++--- 3 files changed, 242 insertions(+), 223 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index d0772ee39..7b0d51d44 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -13,71 +13,66 @@ if __name__ == "__main__": dask.config.set({"temporary-directory": "/mnt/tmpfs/"}) - with LocalCluster( - n_workers=16, - threads_per_worker=1, - memory_limit="8GB", - ) as cluster: - with Client(cluster) as client: - # STEP 1: Load MIMIC-IV base dataset - base_dataset = MIMIC4Dataset( - ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", - ehr_tables=[ - "patients", - "admissions", - "diagnoses_icd", - "procedures_icd", - "labevents", - ], - dev=True - ) - base_dataset._merged_cache() + + base_dataset = MIMIC4Dataset( + ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + num_workers=8, + mem_per_worker="8GB", + dev=True + ) - print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") + # print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") - # STEP 2: Apply StageNet mortality prediction task - sample_dataset = base_dataset.set_task( - MortalityPredictionStageNetMIMIC4(), - num_workers=4, - ) - print(f"Total samples: {len(sample_dataset)}") - print(f"Input schema: {sample_dataset.input_schema}") - print(f"Output schema: {sample_dataset.output_schema}") + # # STEP 2: Apply StageNet mortality prediction task + # sample_dataset = base_dataset.set_task( + # MortalityPredictionStageNetMIMIC4(), + # num_workers=4, + # ) + # print(f"Total samples: {len(sample_dataset)}") + # print(f"Input schema: {sample_dataset.input_schema}") + # print(f"Output schema: {sample_dataset.output_schema}") - # Inspect a sample - sample = next(iter(sample_dataset)) - print("\nSample structure:") - print(f" Patient ID: {sample['patient_id']}") - print(f"ICD Codes: {sample['icd_codes']}") - print(f" Labs shape: {len(sample['labs'][0])} timesteps") - print(f" Mortality: {sample['mortality']}") + # # Inspect a sample + # sample = next(iter(sample_dataset)) + # print("\nSample structure:") + # print(f" Patient ID: {sample['patient_id']}") + # print(f"ICD Codes: {sample['icd_codes']}") + # print(f" Labs shape: {len(sample['labs'][0])} timesteps") + # print(f" Mortality: {sample['mortality']}") - # Create dataloaders - train_loader = get_dataloader(sample_dataset, batch_size=256, shuffle=True) + # # Create dataloaders + # train_loader = get_dataloader(sample_dataset, batch_size=256, shuffle=True) - # STEP 4: Initialize StageNet model - model = StageNet( - dataset=sample_dataset, - embedding_dim=128, - chunk_size=128, - levels=3, - dropout=0.3, - ) + # # STEP 4: Initialize StageNet model + # model = StageNet( + # dataset=sample_dataset, + # embedding_dim=128, + # chunk_size=128, + # levels=3, + # dropout=0.3, + # ) - num_params = sum(p.numel() for p in model.parameters()) - print(f"\nModel initialized with {num_params} parameters") + # num_params = sum(p.numel() for p in model.parameters()) + # print(f"\nModel initialized with {num_params} parameters") - # STEP 5: Train the model - trainer = Trainer( - model=model, - device="cuda:5", # or "cpu" - metrics=["pr_auc", "roc_auc", "accuracy", "f1"], - ) + # # STEP 5: Train the model + # trainer = Trainer( + # model=model, + # device="cuda:5", # or "cpu" + # metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + # ) - trainer.train( - train_dataloader=train_loader, - val_dataloader=train_loader, - epochs=50, - monitor="roc_auc", - optimizer_params={"lr": 1e-5}, - ) + # trainer.train( + # train_dataloader=train_loader, + # val_dataloader=train_loader, + # epochs=50, + # monitor="roc_auc", + # optimizer_params={"lr": 1e-5}, + # ) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d82782839..dd97a8ffe 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -11,7 +11,8 @@ from collections import namedtuple import pickle -from distributed import get_client +from distributed import Client, LocalCluster, get_client +from dask.utils import parse_bytes import polars as pl import pandas as pd import dask.dataframe as dd @@ -30,6 +31,8 @@ from .configs import load_yaml_config from .sample_dataset import SampleDataset +# Set logging level for distributed to ERROR to reduce verbosity +logging.getLogger("distributed").setLevel(logging.ERROR) logger = logging.getLogger(__name__) @@ -99,6 +102,7 @@ def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]: def _patient_bucket(patient_id: str, n_partitions: int) -> int: + """Hash patient_id to a bucket number.""" bucket = int(xxhash.xxh64_intdigest(patient_id) % n_partitions) return bucket @@ -149,6 +153,9 @@ def __init__( dataset_name: str | None = None, config_path: str | None = None, cache_dir: str | Path | None = None, + num_workers: int = 1, + mem_per_worker: str | int = "8GB", + compute: bool = True, dev: bool = False, ): """Initializes the BaseDataset. @@ -162,9 +169,6 @@ def __init__( cache directory will be created under the platform's cache directory. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ - if config_path is None: - raise ValueError("config_path must be provided") - if len(set(tables)) != len(tables): logger.warning("Duplicate table names in tables list. Removing duplicates.") tables = list(set(tables)) @@ -172,59 +176,52 @@ def __init__( self.root = root self.tables = tables self.dataset_name = dataset_name or self.__class__.__name__ - self.config = load_yaml_config(config_path) self.dev = dev + if config_path is not None: + self.config = load_yaml_config(config_path) - subfolder = self.cache_subfolder( - self.root, self.tables, self.dataset_name, self.dev - ) - self.setup_cache_dir(cache_dir=cache_dir, subfolder=subfolder) + # Resource allocation for table joining and processing + self.num_workers = num_workers + if isinstance(mem_per_worker, str): + self.mem_per_worker = parse_bytes(mem_per_worker) + else: + self.mem_per_worker = mem_per_worker + + # Cached value for property + self._cache_dir = cache_dir logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" ) + if compute: + _ = self._joined_cache() - @staticmethod - def cache_subfolder( - root: str, tables: List[str], dataset_name: str, dev: bool - ) -> str: - """Generates a unique identifier for the dataset instance. This is used for creating - cache directories. The UUID is based on the root path, tables, dataset name, and dev mode. + @property + def cache_dir(self) -> Path: + """Returns the cache directory path. Returns: - str: A unique identifier string. + Path: The cache directory path. """ + if self._cache_dir is not None: + return Path(self._cache_dir) + id_str = json.dumps( { - "root": root, - "tables": sorted(tables), - "dataset_name": dataset_name, - "dev": dev, + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, }, sort_keys=True, ) - return str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str)) - - def setup_cache_dir( - self, cache_dir: str | Path | None = None, subfolder: str = str(uuid.uuid4()) - ) -> None: - """Creates the cache directory structure. - - Args: - cache_dir (str | Path | None): The base cache directory. If None, a default cache - directory will be created under the platform's cache directory. - subfolder (str): Subfolder name for this dataset instance's cache. - """ - if cache_dir is None: - cache_dir = platformdirs.user_cache_dir(appname="pyhealth") - logger.info( - f"No cache_dir provided. Using default cache for PyHealth: {cache_dir}" - ) - cache_dir = Path(cache_dir) - self.cache_dir = cache_dir / subfolder - logger.info( - f"Initializing {self.dataset_name} dataset cache directory to {self.cache_dir}" + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( + uuid.uuid5(uuid.NAMESPACE_DNS, id_str) ) + print(f"No cache_dir provided. Using default cache dir: {cache_dir}") + self._cache_dir = cache_dir + + return cache_dir def _task_cache(self, task_name: str) -> str: """Generates the cache path for a specific task. @@ -237,49 +234,130 @@ def _task_cache(self, task_name: str) -> str: (self.cache_dir / "tasks").mkdir(parents=True, exist_ok=True) return str(self.cache_dir / "tasks" / task_name) - def _merged_cache(self) -> str: - """Collects and returns the global event data frame. + def _table_cache(self, table_name: str, source_path: str | None = None) -> str: + """Generates the cache path for a specific table. If the cached Parquet file does not exist, + it will convert the source CSV/TSV file to Parquet and save it to the cache. + + Args: + table_name (str): The name of the table. + source_path (str | None): The source CSV/TSV file path. If None, it assumes the + Parquet file already exists in the cache. Returns: - dd.DataFrame: The collected global event data frame. + str: The cache path for the table. """ - ret_path = str(self.cache_dir / "global_event_df.parquet") + # Ensure the tables cache directory exists + (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) + ret_path = str(self.cache_dir / "tables" / f"{table_name}.parquet") if not path_exists(ret_path): - global_event_df = self.load_data() - # In dev mode, limit to 1000 patients - if self.dev: - patients = global_event_df["patient_id"].unique().head(1000).tolist() - filter = global_event_df["patient_id"].isin(patients) - global_event_df: dd.DataFrame = global_event_df[filter] - - # Collect unique patient IDs - patients = global_event_df["patient_id"].unique().compute().tolist() - n_patients = len(patients) - n_events = global_event_df.shape[0].compute() - logger.info(f"Collected {n_events} events for {n_patients} patients.") - - # Estimate memory usage and partitioning - mem_usage = global_event_df.memory_usage(deep=False).compute().sum() - n_partitions = mem_usage // (256 * 1024 * 1024) + 1 # 256 MB per partition - bucket = global_event_df["patient_id"].apply( - lambda pid: _patient_bucket(pid, n_partitions), - meta=("patient_id", "int"), + if source_path is None: + raise FileNotFoundError( + f"Table {table_name} not found in cache and no source_path provided." + ) + + # Check if source_path exists, else try alternative path + if not path_exists(source_path): + if not path_exists(alt_path(source_path)): + raise FileNotFoundError( + f"Neither path exists: {source_path} or {alt_path(source_path)}" + ) + source_path = alt_path(source_path) + + # Determine delimiter based on file extension + delimiter = ( + "\t" + if source_path.endswith(".tsv") or source_path.endswith(".tsv.gz") + else "," ) - global_event_df = global_event_df.assign(bucket=bucket) - logger.info( - f"Estimated size {mem_usage / (1024**3):.2f} GB, write to {n_partitions} partitions." + + # Always infer schema as string to avoid incorrect type inference + schema_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + ) + schema = pa.schema( + [pa.field(name, pa.string()) for name in schema_reader.schema.names] ) - client = get_client() - handle = global_event_df.to_parquet( - ret_path, - partition_on=["bucket"], - write_index=False, - compute=False, + # Convert CSV/TSV to Parquet + csv_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + convert_options=pv.ConvertOptions(column_types=schema), ) - future = client.compute(handle) - progress(future) + with pq.ParquetWriter(ret_path, csv_reader.schema) as writer: + for batch in csv_reader: + writer.write_batch(batch) + + return ret_path + + def _joined_cache(self) -> str: + """Collects and returns the global event data frame. + + Returns: + dd.DataFrame: The collected global event data frame. + """ + ret_path = str(self.cache_dir / "global_event_df.parquet") + + if not path_exists(ret_path): + with LocalCluster( + n_workers=self.num_workers, + threads_per_worker=1, + memory_limit=self.mem_per_worker, + config={"distributed.nanny.terminate_timeout": "60s"}, + ) as cluster: + with Client(cluster) as client: + global_event_df = self.load_data() + # In dev mode, limit to 1000 patients + if self.dev: + patients = ( + global_event_df["patient_id"].unique().head(1000).tolist() + ) + filter = global_event_df["patient_id"].isin(patients) + global_event_df: dd.DataFrame = global_event_df[filter] + + # Collect unique patient IDs + patients = global_event_df["patient_id"].unique().compute().tolist() + n_patients = len(patients) + n_events = global_event_df.shape[0].compute() + logger.info( + f"Collected {n_events} events for {n_patients} patients." + ) + + # Estimate memory usage and partitioning + partition_size = ( + self.mem_per_worker // 16 + ) # Use 1/16 of worker memory for safety + estimated_mem_usage = ( + global_event_df.memory_usage(deep=False).compute().sum() + ) + n_partitions = ( + estimated_mem_usage // partition_size + ) + 1 # Calculate number of partitions based on memory usage + n_partitions = max( + n_partitions, self.num_workers + ) # At least num_workers partitions + + bucket = global_event_df["patient_id"].apply( + lambda pid: _patient_bucket(pid, n_partitions), + meta=("patient_id", "int"), + ) + global_event_df = global_event_df.assign(bucket=bucket) + logger.info( + f"Estimated size {estimated_mem_usage / (1024**3):.2f} GB, write to {n_partitions} partitions." + ) + + handle = global_event_df.to_parquet( + ret_path, + partition_on=["bucket"], + write_index=False, + compute=False, + ) + future = client.compute(handle) + progress(future) with open(ret_path + "/index.json", "w") as future: json.dump( { @@ -402,66 +480,6 @@ def load_table(self, table_name: str) -> dd.DataFrame: return event_frame - def _table_cache(self, table_name: str, source_path: str | None = None) -> str: - """Generates the cache path for a specific table. If the cached Parquet file does not exist, - it will convert the source CSV/TSV file to Parquet and save it to the cache. - - Args: - table_name (str): The name of the table. - source_path (str | None): The source CSV/TSV file path. If None, it assumes the - Parquet file already exists in the cache. - - Returns: - str: The cache path for the table. - """ - # Ensure the tables cache directory exists - (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) - ret_path = str(self.cache_dir / "tables" / f"{table_name}.parquet") - - if not path_exists(ret_path): - if source_path is None: - raise FileNotFoundError( - f"Table {table_name} not found in cache and no source_path provided." - ) - - # Check if source_path exists, else try alternative path - if not path_exists(source_path): - if not path_exists(alt_path(source_path)): - raise FileNotFoundError( - f"Neither path exists: {source_path} or {alt_path(source_path)}" - ) - source_path = alt_path(source_path) - - # Determine delimiter based on file extension - delimiter = ( - "\t" - if source_path.endswith(".tsv") or source_path.endswith(".tsv.gz") - else "," - ) - - # Always infer schema as string to avoid incorrect type inference - schema_reader = pv.open_csv( - source_path, - read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB - parse_options=pv.ParseOptions(delimiter=delimiter), - ) - schema = pa.schema( - [pa.field(name, pa.string()) for name in schema_reader.schema.names] - ) - - # Convert CSV/TSV to Parquet - csv_reader = pv.open_csv( - source_path, - read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB - parse_options=pv.ParseOptions(delimiter=delimiter), - convert_options=pv.ConvertOptions(column_types=schema), - ) - with pq.ParquetWriter(ret_path, csv_reader.schema) as writer: - for batch in csv_reader: - writer.write_batch(batch) - - return ret_path - @property def unique_patient_ids(self) -> List[str]: """Returns a list of unique patient IDs. @@ -469,7 +487,7 @@ def unique_patient_ids(self) -> List[str]: Returns: List[str]: List of unique patient IDs. """ - with open(self._merged_cache() + "/index.json", "r") as f: + with open(self._joined_cache() + "/index.json", "r") as f: index_info = json.load(f) return index_info["patients"] @@ -488,7 +506,7 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - return _get_patient(self._merged_cache(), patient_id) + return _get_patient(self._joined_cache(), patient_id) def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: """Yields Patient objects for each unique patient in the dataset. @@ -501,7 +519,7 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: def stats(self) -> None: """Prints statistics about the dataset.""" - with open(self._merged_cache() + "/index.json", "r") as f: + with open(self._joined_cache() + "/index.json", "r") as f: index_info = json.load(f) n_patients = index_info["n_patients"] n_events = index_info["n_events"] @@ -574,10 +592,10 @@ def set_task( if not path_exists(str(cache_dir)): import litdata as ld - with open(self._merged_cache() + "/index.json", "r") as f: + with open(self._joined_cache() + "/index.json", "r") as f: index_info = json.load(f) n_partitions = index_info["n_partitions"] - inputs = [(i, self._merged_cache(), task) for i in range(n_partitions)] + inputs = [(i, self._joined_cache(), task) for i in range(n_partitions)] ld.optimize( fn=_transform_fn, diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index e95c9934e..03ef61a1a 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -21,7 +21,7 @@ def log_memory_usage(tag=""): """Log current memory usage if psutil is available.""" if HAS_PSUTIL: - process = psutil.Process(os.getpid()) + process = psutil.Process(os.getpid()) # type: ignore mem_info = process.memory_info() logger.info(f"Memory usage {tag}: {mem_info.rss / (1024 * 1024):.1f} MB") else: @@ -210,21 +210,27 @@ def __init__( ehr_config_path: Optional[str] = None, note_config_path: Optional[str] = None, cxr_config_path: Optional[str] = None, - cache_dir: str | Path | None = None, dataset_name: str = "mimic4", + cache_dir: str | Path | None = None, + num_workers: int = 1, + mem_per_worker: str | int = "8GB", + compute: bool = True, dev: bool = False, ): + super().__init__( + root=f"{str(ehr_root)},{str(note_root)},{str(cxr_root)}", + tables=(ehr_tables or []) + (note_tables or []) + (cxr_tables or []), + dataset_name=dataset_name, + cache_dir=cache_dir, + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first + dev=dev, + ) log_memory_usage("Starting MIMIC4Dataset init") # Initialize child datasets - self.dataset_name = dataset_name self.sub_datasets = {} - self.root = None - self.tables = None - self.config = None - # Dev flag is only used in the MIMIC4Dataset class - # to ensure the same set of patients are used for all sub-datasets. - self.dev = dev # We need at least one root directory if not any([ehr_root, note_root, cxr_root]): @@ -242,7 +248,10 @@ def __init__( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, - cache_dir=cache_dir, + cache_dir=f"{self.cache_dir}/ehr", + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first ) log_memory_usage("After EHR dataset initialization") @@ -253,7 +262,10 @@ def __init__( root=note_root, tables=note_tables, config_path=note_config_path, - cache_dir=cache_dir, + cache_dir=f"{self.cache_dir}/note", + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first ) log_memory_usage("After Note dataset initialization") @@ -264,21 +276,15 @@ def __init__( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, - cache_dir=cache_dir, + cache_dir=f"{self.cache_dir}/cxr", + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first ) log_memory_usage("After CXR dataset initialization") - subfolder = BaseDataset.cache_subfolder( - str(ehr_root) + str(note_root) + str(cxr_root), - ehr_tables + note_tables + cxr_tables, - self.dataset_name, - self.dev - ) - self.setup_cache_dir(cache_dir=cache_dir, subfolder=subfolder) - - # Cache attributes - self._unique_patient_ids = None - + if compute: + _ = self._joined_cache() log_memory_usage("Completed MIMIC4Dataset init") def load_data(self) -> dd.DataFrame: From 86d04c633a2b3f751e54a960948b4141223482aa Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 22 Nov 2025 21:57:26 -0500 Subject: [PATCH 39/51] Update Patient to use pandas, because polars will have nested mp issues. Co-authored-by: John Wu <54558896+jhnwu3 --- pyhealth/data/data.py | 53 ++++++++-------- tests/core/test_patient_event.py | 103 +++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 25 deletions(-) create mode 100644 tests/core/test_patient_event.py diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index c9b88b1a6..c9dfe9b2f 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -5,7 +5,7 @@ from typing import Dict, List, Mapping, Optional, Union, Any import numpy as np -import polars as pl +import pandas as pd @dataclass(frozen=True) @@ -119,54 +119,57 @@ class Patient: Attributes: patient_id (str): Unique patient identifier. - data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp. - event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. + data_source (pd.DataFrame): DataFrame containing all events, sorted by timestamp. + event_type_partitions (Dict[str, pd.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. """ - def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None: + def __init__(self, patient_id: str, data_source: pd.DataFrame) -> None: """ Initialize a Patient instance. Args: patient_id (str): Unique patient identifier. - data_source (pl.DataFrame): DataFrame containing all events. + data_source (pd.DataFrame): DataFrame containing all events. """ self.patient_id = patient_id - self.data_source = data_source.sort("timestamp") - self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True) + self.data_source = data_source.sort_values("timestamp") + self.event_type_partitions = { + (event_type,): group.copy() + for event_type, group in self.data_source.groupby("event_type", sort=False) + } - def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + def _filter_by_time_range_regular(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame: """Regular filtering by time. Time complexity: O(n).""" if start is not None: - df = df.filter(pl.col("timestamp") >= start) + df = df[df["timestamp"] >= start] if end is not None: - df = df.filter(pl.col("timestamp") <= end) + df = df[df["timestamp"] <= end] return df - def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + def _filter_by_time_range_fast(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame: """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n).""" if start is None and end is None: return df - df = df.filter(pl.col("timestamp").is_not_null()) - ts_col = df["timestamp"].dt.epoch("s").to_numpy() + df = df[df["timestamp"].notna()].sort_values("timestamp") + ts_col = pd.to_datetime(df["timestamp"]).to_numpy(dtype="datetime64[ns]") start_idx = 0 end_idx = len(ts_col) if start is not None: - start_idx = np.searchsorted(ts_col, start.timestamp(), side="left") + start_idx = np.searchsorted(ts_col, np.datetime64(start, "ns"), side="left") if end is not None: - end_idx = np.searchsorted(ts_col, end.timestamp(), side="right") - return df.slice(start_idx, end_idx - start_idx) + end_idx = np.searchsorted(ts_col, np.datetime64(end, "ns"), side="right") + return df.iloc[start_idx:end_idx] - def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + def _filter_by_event_type_regular(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame: """Regular filtering by event type. Time complexity: O(n).""" if event_type: - df = df.filter(pl.col("event_type") == event_type) + df = df[df["event_type"] == event_type] return df - def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + def _filter_by_event_type_fast(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame: """Fast filtering by event type using pre-built event type index. Time complexity: O(1).""" if event_type: - return self.event_type_partitions.get((event_type,), df[:0]) + return self.event_type_partitions.get((event_type,), df.iloc[0:0]) else: return df @@ -177,7 +180,7 @@ def get_events( end: Optional[datetime] = None, filters: Optional[List[tuple]] = None, return_df: bool = False, - ) -> Union[pl.DataFrame, List[Event]]: + ) -> Union[pd.DataFrame, List[Event]]: """Get events with optional type and time filters. Args: @@ -191,7 +194,7 @@ def get_events( and time filters. The logic is "AND" between different filters. Returns: - Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame + Union[pd.DataFrame, List[Event]]: Filtered events as a DataFrame or a list of Event objects. """ # faster filtering (by default) @@ -213,7 +216,7 @@ def get_events( f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" ) attr, op, val = filt - col_expr = pl.col(f"{event_type}/{attr}") + col_expr = df[f"{event_type}/{attr}"] # Build operator expression if op == "==": exprs.append(col_expr == val) @@ -230,7 +233,7 @@ def get_events( else: raise ValueError(f"Unsupported operator: {op} in filter {filt}") if exprs: - df = df.filter(reduce(operator.and_, exprs)) + df = df.loc[reduce(operator.and_, exprs)] if return_df: return df - return [Event.from_dict(d) for d in df.to_dicts()] \ No newline at end of file + return [Event.from_dict(d) for d in df.to_dict(orient="records")] diff --git a/tests/core/test_patient_event.py b/tests/core/test_patient_event.py new file mode 100644 index 000000000..1419f4b0c --- /dev/null +++ b/tests/core/test_patient_event.py @@ -0,0 +1,103 @@ +from datetime import datetime, timedelta +import pandas as pd + +from tests.base import BaseTestCase +from pyhealth.data.data import Event, Patient + + +class PatientEventTestCase(BaseTestCase): + def setUp(self): + base_time = datetime(2024, 1, 1, 12, 0, 0) + self.event_rows = [ + { + "patient_id": "p1", + "event_type": "med", + "timestamp": base_time + timedelta(days=1), + "med/dose": 10, + }, + { + "patient_id": "p1", + "event_type": "diag", + "timestamp": base_time + timedelta(days=2), + "diag/code": "A", + }, + { + "patient_id": "p1", + "event_type": "med", + "timestamp": base_time + timedelta(days=4), + "med/dose": 20, + }, + { + "patient_id": "p1", + "event_type": "diag", + "timestamp": base_time + timedelta(days=3), + "diag/code": "B", + "diag/severity": 2, + }, + { + "patient_id": "p1", + "event_type": "lab", + "timestamp": pd.NaT, + "lab/value": 99, + }, + ] + unsorted_df = pd.DataFrame(self.event_rows) + self.patient = Patient(patient_id="p1", data_source=unsorted_df) + self.base_time = base_time + + def test_event_accessors_and_from_dict(self): + ts = datetime(2024, 5, 1, 8, 0, 0) + event = Event(event_type="diag", timestamp=ts, code="X", value=1) + self.assertEqual(event["timestamp"], ts) + self.assertEqual(event["event_type"], "diag") + self.assertEqual(event["code"], "X") + self.assertIn("value", event) + self.assertEqual(event.value, 1) + + raw = {"event_type": "diag", "timestamp": ts, "diag/code": "Y", "diag/score": 5} + reconstructed = Event.from_dict(raw) + self.assertEqual(reconstructed.event_type, "diag") + self.assertEqual(reconstructed.attr_dict, {"code": "Y", "score": 5}) + + def test_patient_sorting_and_partitions(self): + # timestamps should be sorted with NaT at the end + timestamps = list(self.patient.data_source["timestamp"]) + sorted_without_nat = sorted([ts for ts in timestamps if pd.notna(ts)]) + self.assertEqual(timestamps[:-1], sorted_without_nat) + self.assertTrue(pd.isna(timestamps[-1])) + + diag_partition = self.patient.event_type_partitions[("diag",)] + self.assertListEqual(list(diag_partition["diag/code"]), ["A", "B"]) + + med_partition = self.patient.event_type_partitions[("med",)] + self.assertListEqual(list(med_partition["med/dose"]), [10, 20]) + + def test_get_events_by_type_and_time(self): + diag_df = self.patient.get_events(event_type="diag", return_df=True) + self.assertEqual(len(diag_df), 2) + self.assertListEqual(list(diag_df["diag/code"]), ["A", "B"]) + + # Time filtering should include bounds and drop NaT + start = self.base_time + timedelta(days=2) + end = self.base_time + timedelta(days=4) + ranged = self.patient.get_events(start=start, end=end, return_df=True) + self.assertListEqual(list(ranged["event_type"]), ["diag", "diag", "med"]) + self.assertTrue(ranged["timestamp"].between(start, end, inclusive="both").all()) + self.assertFalse(ranged["timestamp"].isna().any()) + + def test_attribute_filters(self): + filtered_events = self.patient.get_events( + event_type="diag", filters=[("code", "==", "B")] + ) + self.assertEqual(len(filtered_events), 1) + self.assertEqual(filtered_events[0].attr_dict, {"code": "B", "severity": 2}) + + filtered_df = self.patient.get_events( + event_type="diag", filters=[("code", "!=", "B")], return_df=True + ) + self.assertEqual(len(filtered_df), 1) + self.assertEqual(filtered_df.iloc[0]["diag/code"], "A") + + def test_filters_require_event_type(self): + with self.assertRaises(AssertionError): + self.patient.get_events(filters=[("code", "==", "A")]) From c82c6be3be527714466db57a382aa0a28800cae3 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 22 Nov 2025 21:58:05 -0500 Subject: [PATCH 40/51] Fix typing --- pyhealth/data/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index c9dfe9b2f..08b7da8b3 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -236,4 +236,4 @@ def get_events( df = df.loc[reduce(operator.and_, exprs)] if return_df: return df - return [Event.from_dict(d) for d in df.to_dict(orient="records")] + return [Event.from_dict(d) for d in df.to_dict(orient="records")] # type: ignore From 23084c9ebc1f950eba612e0f773defebeff6ac65 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 22 Nov 2025 22:16:26 -0500 Subject: [PATCH 41/51] Add type safe method for Patient --- pyhealth/data/data.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 08b7da8b3..bb305a608 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -173,6 +173,14 @@ def _filter_by_event_type_fast(self, df: pd.DataFrame, event_type: Optional[str] else: return df + def get_events_py(self, **kawargs) -> List[Event]: + """Type-safe wrapper for get_events.""" + return self.get_events(**kawargs, return_df=False) # type: ignore + + def get_events_df(self, **kawargs) -> pd.DataFrame: + """DataFrame wrapper for get_events.""" + return self.get_events(**kawargs, return_df=True) # type: ignore + def get_events( self, event_type: Optional[str] = None, From 1f1e2cc468072d075c096dc62c8185ff4c6cdb8c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 22 Nov 2025 22:37:08 -0500 Subject: [PATCH 42/51] Fix MortalityPredictionStageNetMIMIC4 to pandas --- .../mortality_prediction_stagenet_mimic4.py | 73 +++++++++++-------- 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index 25da630be..437671067 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -1,7 +1,9 @@ from datetime import datetime from typing import Any, ClassVar, Dict, List, Type -import polars as pl +import pandas as pd + +from pyhealth.data.data import Patient from .base_task import BaseTask @@ -75,7 +77,7 @@ class MortalityPredictionStageNetMIMIC4(BaseTask): item for itemids in LAB_CATEGORIES.values() for item in itemids ] - def __call__(self, patient: Any) -> List[Dict[str, Any]]: + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: """Process a patient to create mortality prediction samples. Creates ONE sample per patient with all admissions aggregated. @@ -89,7 +91,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: procedures, labs across visits, and final mortality label """ # Filter patients by age (>= 18) - demographics = patient.get_events(event_type="patients") + demographics = patient.get_events_py(event_type="patients") if not demographics: return [] @@ -103,7 +105,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return [] # Get all admissions - admissions = patient.get_events(event_type="admissions") + admissions = patient.get_events_py(event_type="admissions") if len(admissions) < 1: return [] @@ -156,7 +158,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: pass # Get diagnosis codes for this admission using hadm_id - diagnoses_icd = patient.get_events( + diagnoses_icd = patient.get_events_py( event_type="diagnoses_icd", filters=[("hadm_id", "==", admission.hadm_id)], ) @@ -167,7 +169,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: ] # Get procedure codes for this admission using hadm_id - procedures_icd = patient.get_events( + procedures_icd = patient.get_events_py( event_type="procedures_icd", filters=[("hadm_id", "==", admission.hadm_id)], ) @@ -185,46 +187,51 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: all_icd_times.append(time_from_previous) # Get lab events for this admission - labevents_df = patient.get_events( + labevents_df = patient.get_events_df( event_type="labevents", start=admission_time, end=admission_dischtime, - return_df=True, ) # Filter to relevant lab items - labevents_df = labevents_df.filter( - pl.col("labevents/itemid").is_in(self.LABITEMS) - ) + labevents_df = labevents_df[ + labevents_df["labevents/itemid"].isin(self.LABITEMS) + ] # Parse storetime and filter - if labevents_df.height > 0: - labevents_df = labevents_df.with_columns( - pl.col("labevents/storetime").str.strptime( - pl.Datetime, "%Y-%m-%d %H:%M:%S" - ) - ) - labevents_df = labevents_df.filter( - pl.col("labevents/storetime") <= admission_dischtime + if len(labevents_df) > 0: + labevents_df = labevents_df.copy() + labevents_df["labevents/storetime"] = pd.to_datetime( + labevents_df["labevents/storetime"], + format="%Y-%m-%d %H:%M:%S", + errors="coerce", ) + labevents_df = labevents_df[ + labevents_df["labevents/storetime"] <= admission_dischtime + ] - if labevents_df.height > 0: + if len(labevents_df) > 0: # Select relevant columns - labevents_df = labevents_df.select( - pl.col("timestamp"), - pl.col("labevents/itemid"), - pl.col("labevents/valuenum").str.strip_chars().replace("", None).cast(pl.Float64), + labevents_df = labevents_df.loc[ + :, ["timestamp", "labevents/itemid", "labevents/valuenum"] + ].copy() + labevents_df["labevents/valuenum"] = pd.to_numeric( + labevents_df["labevents/valuenum"] + .astype(str) + .str.strip() + .replace("", pd.NA), # type: ignore + errors="coerce", ) # Group by timestamp and aggregate into 10D vectors # For each timestamp, create vector of lab categories unique_timestamps = sorted( - labevents_df["timestamp"].unique().to_list() + labevents_df["timestamp"].unique().tolist() ) for lab_ts in unique_timestamps: # Get all lab events at this timestamp - ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts) + ts_labs = labevents_df[labevents_df["timestamp"] == lab_ts] # Create 10-dimensional vector (one per category) lab_vector = [] @@ -234,11 +241,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: # Find first matching value for this category category_value = None for itemid in category_itemids: - matching = ts_labs.filter( - pl.col("labevents/itemid") == itemid - ) - if matching.height > 0: - category_value = matching["labevents/valuenum"][0] + matching = ts_labs[ + ts_labs["labevents/itemid"] == itemid + ] + if len(matching) > 0: + category_value = matching[ + "labevents/valuenum" + ].iloc[0] break lab_vector.append(category_value) @@ -273,4 +282,4 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: "labs": labs_data, "mortality": final_mortality, } - return [sample] \ No newline at end of file + return [sample] From cc42dd387b585d226efe3391f06abf555af574a8 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 22 Nov 2025 22:41:06 -0500 Subject: [PATCH 43/51] Fix litdata.optimize hang --- pyhealth/datasets/base_dataset.py | 43 +++++++++++-------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index dd97a8ffe..a1f4015df 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -107,28 +107,16 @@ def _patient_bucket(patient_id: str, n_partitions: int) -> int: return bucket -def _get_patient(merged_cache: str, patient_id: str) -> Patient: - with open(merged_cache + "/index.json", "rb") as f: - n_partitions = json.load(f)["n_partitions"] - bucket = _patient_bucket(patient_id, n_partitions) - path = f"{merged_cache}/bucket={bucket}" - patient = Patient( - patient_id=patient_id, - data_source=pl.read_parquet(path).filter(pl.col("patient_id") == patient_id), - ) - return patient - - def _transform_fn( input: tuple[int, str, BaseTask], ) -> Iterator[Dict[str, Any]]: (bucket_id, merged_cache, task) = input path = f"{merged_cache}/bucket={bucket_id}" # This is more efficient than reading patient by patient - grouped = pl.read_parquet(path).group_by("patient_id") + grouped = pd.read_parquet(path).groupby("patient_id") for patient_id, patient_df in grouped: - patient = Patient(patient_id=str(patient_id[0]), data_source=patient_df) + patient = Patient(patient_id=str(patient_id), data_source=patient_df) for sample in task(patient): # Schema is too complex to be handled by LitData, so we pickle the sample here yield _pickle(sample) @@ -223,17 +211,6 @@ def cache_dir(self) -> Path: return cache_dir - def _task_cache(self, task_name: str) -> str: - """Generates the cache path for a specific task. - - Args: - task_name (str): The name of the task. - Returns: - str: The cache path for the task. - """ - (self.cache_dir / "tasks").mkdir(parents=True, exist_ok=True) - return str(self.cache_dir / "tasks" / task_name) - def _table_cache(self, table_name: str, source_path: str | None = None) -> str: """Generates the cache path for a specific table. If the cached Parquet file does not exist, it will convert the source CSV/TSV file to Parquet and save it to the cache. @@ -307,7 +284,6 @@ def _joined_cache(self) -> str: n_workers=self.num_workers, threads_per_worker=1, memory_limit=self.mem_per_worker, - config={"distributed.nanny.terminate_timeout": "60s"}, ) as cluster: with Client(cluster) as client: global_event_df = self.load_data() @@ -506,7 +482,18 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - return _get_patient(self._joined_cache(), patient_id) + + path = self._joined_cache() + with open(f"{path}/index.json", "rb") as f: + n_partitions = json.load(f)["n_partitions"] + bucket = _patient_bucket(patient_id, n_partitions) + path = f"{path}/bucket={bucket}" + df = pd.read_parquet(path) + patient = Patient( + patient_id=patient_id, + data_source=df[df["patient_id"] == patient_id], + ) + return patient def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: """Yields Patient objects for each unique patient in the dataset. @@ -584,7 +571,7 @@ def set_task( f"This argument is no longer supported: cache_format={cache_format}" ) if cache_dir is None: - cache_dir = self._task_cache(task.task_name) + cache_dir = str(self.cache_dir / "tasks" / task.task_name) logger.info( "No cache_dir provided. Using default task cache dir: %s", cache_dir ) From cd9a79fd98d5d7c0eb980747741293021c073933 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 23 Nov 2025 00:11:37 -0500 Subject: [PATCH 44/51] Fix sample dataset --- examples/memtest.py | 91 +++++++++++++++-------------- pyhealth/datasets/sample_dataset.py | 36 ++++++------ 2 files changed, 65 insertions(+), 62 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 7b0d51d44..1662ed63a 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -1,4 +1,5 @@ # %% +from pyhealth.data.data import Patient from pyhealth.datasets import ( MIMIC4Dataset, get_dataloader, @@ -7,9 +8,10 @@ from pyhealth.models import StageNet from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 from pyhealth.trainer import Trainer -import torch import dask.config -from dask.distributed import Client, LocalCluster + +import warnings +warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API") if __name__ == "__main__": dask.config.set({"temporary-directory": "/mnt/tmpfs/"}) @@ -23,56 +25,57 @@ "procedures_icd", "labevents", ], - num_workers=8, + num_workers=2, mem_per_worker="8GB", dev=True ) - # print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") + print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") - # # STEP 2: Apply StageNet mortality prediction task - # sample_dataset = base_dataset.set_task( - # MortalityPredictionStageNetMIMIC4(), - # num_workers=4, - # ) - # print(f"Total samples: {len(sample_dataset)}") - # print(f"Input schema: {sample_dataset.input_schema}") - # print(f"Output schema: {sample_dataset.output_schema}") + # STEP 2: Apply StageNet mortality prediction task + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=2, + ) + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") - # # Inspect a sample - # sample = next(iter(sample_dataset)) - # print("\nSample structure:") - # print(f" Patient ID: {sample['patient_id']}") - # print(f"ICD Codes: {sample['icd_codes']}") - # print(f" Labs shape: {len(sample['labs'][0])} timesteps") - # print(f" Mortality: {sample['mortality']}") + # Inspect a sample + sample = next(iter(sample_dataset)) + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f"ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") - # # Create dataloaders - # train_loader = get_dataloader(sample_dataset, batch_size=256, shuffle=True) + # Create dataloaders + train_loader = get_dataloader(sample_dataset, batch_size=256, shuffle=True) - # # STEP 4: Initialize StageNet model - # model = StageNet( - # dataset=sample_dataset, - # embedding_dim=128, - # chunk_size=128, - # levels=3, - # dropout=0.3, - # ) + # STEP 4: Initialize StageNet model + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) - # num_params = sum(p.numel() for p in model.parameters()) - # print(f"\nModel initialized with {num_params} parameters") + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") - # # STEP 5: Train the model - # trainer = Trainer( - # model=model, - # device="cuda:5", # or "cpu" - # metrics=["pr_auc", "roc_auc", "accuracy", "f1"], - # ) + # STEP 5: Train the model + trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + enable_logging=False, + ) + print("\nStarting training...") - # trainer.train( - # train_dataloader=train_loader, - # val_dataloader=train_loader, - # epochs=50, - # monitor="roc_auc", - # optimizer_params={"lr": 1e-5}, - # ) + trainer.train( + train_dataloader=train_loader, + epochs=50, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, + ) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index d4e098926..c895a151f 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -4,6 +4,7 @@ from torch.utils.data import IterableDataset from litdata.streaming import StreamingDataset from tqdm import tqdm +import pickle from ..processors import get_processor from ..processors.base_processor import FeatureProcessor @@ -88,6 +89,13 @@ def __init__( self.validate() self.build() + # Apply processors + self.dataset = StreamingDataset( + input_dir=self.dataset.input_dir, + cache_dir=self.dataset.cache_dir, + transform=self.transform, + ) + def set_shuffle(self, shuffle: bool) -> None: """Sets whether to shuffle the dataset. @@ -160,30 +168,22 @@ def build(self) -> None: for k, v in self.output_schema.items(): self.output_processors[k] = self._get_processor_instance(v) self.output_processors[k].fit(iter(self.dataset), k) - # Always process samples with the (fitted) processors - for sample in tqdm(iter(self.dataset), desc="Processing samples"): - for k, v in sample.items(): - if k in self.input_processors: - sample[k] = self.input_processors[k].process(v) - elif k in self.output_processors: - sample[k] = self.output_processors[k].process(v) return + def transform(self, sample) -> Dict: + for k, v in sample.items(): + if k in self.input_processors: + sample[k] = self.input_processors[k].process(pickle.loads(v)) + elif k in self.output_processors: + sample[k] = self.output_processors[k].process(pickle.loads(v)) + else: + sample[k] = pickle.loads(v) + return sample + def __iter__(self) -> Iterator: - # TODO: transform samples on the fly return self.dataset.__iter__() def __getitem__(self, index: int) -> Dict: - """Returns a sample by index. - - Args: - index (int): Index of the sample to retrieve. - - Returns: - Dict: A dict with patient_id, visit_id/record_id, and other - task-specific attributes as key. Conversion to index/tensor - will be done in the model. - """ return self.dataset.__getitem__(index) def __str__(self) -> str: From d7a1fcfaadaee50cb25715e0efaabfa1b5fa3c6c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 23 Nov 2025 08:00:43 -0500 Subject: [PATCH 45/51] add SampleSubset --- pyhealth/datasets/sample_dataset.py | 119 ++++++++++++++++++++++------ tests/core/test_sample_dataset.py | 111 ++++++++++++++++++++++++++ 2 files changed, 207 insertions(+), 23 deletions(-) create mode 100644 tests/core/test_sample_dataset.py diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index c895a151f..01f8450fc 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,8 +1,10 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type +from bisect import bisect_right import inspect from torch.utils.data import IterableDataset from litdata.streaming import StreamingDataset +from litdata.utilities.train_test_split import deepcopy_dataset from tqdm import tqdm import pickle @@ -96,19 +98,6 @@ def __init__( transform=self.transform, ) - def set_shuffle(self, shuffle: bool) -> None: - """Sets whether to shuffle the dataset. - - Args: - shuffle (bool): Whether to shuffle the dataset. - """ - if hasattr(self.dataset, "set_shuffle"): - self.dataset.set_shuffle(shuffle) - else: - raise NotImplementedError( - "Shuffle is not implemented for this dataset type." - ) - def _get_processor_instance(self, processor_spec): """Get processor instance from either string alias, class reference, processor instance, or tuple with kwargs. @@ -171,6 +160,7 @@ def build(self) -> None: return def transform(self, sample) -> Dict: + """Applies the input and output processors to a sample.""" for k, v in sample.items(): if k in self.input_processors: sample[k] = self.input_processors[k].process(pickle.loads(v)) @@ -181,23 +171,106 @@ def transform(self, sample) -> Dict: return sample def __iter__(self) -> Iterator: + """Returns an iterator over the dataset samples.""" return self.dataset.__iter__() def __getitem__(self, index: int) -> Dict: + """Gets a sample by index. Try to use iterator for better performance.""" return self.dataset.__getitem__(index) def __str__(self) -> str: - """Returns a string representation of the dataset. - - Returns: - str: A string with the dataset and task names. - """ + """String representation of the SampleDataset.""" return f"Sample dataset {self.dataset_name} {self.task_name}" def __len__(self) -> int: - """Returns the number of samples in the dataset. - - Returns: - int: The number of samples. - """ + """Returns the number of samples in the dataset.""" return self.dataset.__len__() + +class SampleSubset(IterableDataset): + """A subset of the SampleDataset. + + Args: + sample_dataset (SampleDataset): The original SampleDataset. + indices (List[int]): List of indices to include in the subset. + """ + + def __init__(self, dataset: SampleDataset, indices: List[int]) -> None: + self.dataset_name = dataset.dataset_name + self.task_name = dataset.task_name + base_dataset = deepcopy_dataset(dataset.dataset) + + if len(base_dataset.subsampled_files) != len(base_dataset.region_of_interest): + raise ValueError( + "The provided dataset has mismatched subsampled_files and region_of_interest lengths." + ) + + dataset_length = sum( + end - start for start, end in base_dataset.region_of_interest + ) + if any(idx < 0 or idx >= dataset_length for idx in indices): + raise ValueError( + f"Subset indices must be in [0, {dataset_length - 1}] for the provided dataset." + ) + + # Build chunk boundaries so we can translate global indices into + # chunk-local (start, end) pairs that litdata understands. + chunk_starts: List[int] = [] + chunk_boundaries: List[Tuple[str, int, int, int, int]] = [] + cursor = 0 + for filename, (roi_start, roi_end) in zip( + base_dataset.subsampled_files, base_dataset.region_of_interest + ): + chunk_len = roi_end - roi_start + if chunk_len <= 0: + continue + chunk_starts.append(cursor) + chunk_boundaries.append( + (filename, roi_start, roi_end, cursor, cursor + chunk_len) + ) + cursor += chunk_len + + new_subsampled_files: List[str] = [] + new_roi: List[Tuple[int, int]] = [] + prev_chunk_idx: Optional[int] = None + + for idx in indices: + chunk_idx = bisect_right(chunk_starts, idx) - 1 + if chunk_idx < 0 or idx >= chunk_boundaries[chunk_idx][4]: + raise ValueError(f"Index {idx} is out of bounds for the dataset.") + + filename, roi_start, _, global_start, _ = chunk_boundaries[chunk_idx] + offset_in_chunk = roi_start + (idx - global_start) + + if ( + new_roi + and prev_chunk_idx == chunk_idx + and offset_in_chunk == new_roi[-1][1] + ): + new_roi[-1] = (new_roi[-1][0], new_roi[-1][1] + 1) + else: + new_subsampled_files.append(filename) + new_roi.append((offset_in_chunk, offset_in_chunk + 1)) + + prev_chunk_idx = chunk_idx + + self.dataset: StreamingDataset = base_dataset + self.dataset.subsampled_files = new_subsampled_files + self.dataset.region_of_interest = new_roi + self.dataset.reset() + self._length = sum(end - start for start, end in new_roi) + + def __iter__(self) -> Iterator: + """Returns an iterator over the dataset samples.""" + return self.dataset.__iter__() + + def __getitem__(self, index: int) -> Dict: + """Gets a sample by index. Try to use iterator for better performance.""" + return self.dataset.__getitem__(index) + + def __str__(self) -> str: + """String representation of the SampleDataset.""" + return f"Sample dataset {self.dataset_name} {self.task_name} subset" + + def __len__(self) -> int: + """Returns the number of samples in the dataset.""" + return self._length diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py new file mode 100644 index 000000000..22ecad2dc --- /dev/null +++ b/tests/core/test_sample_dataset.py @@ -0,0 +1,111 @@ +import pickle +import tempfile +import unittest +from pathlib import Path +from typing import Any, Dict, Iterator, List + +from litdata import StreamingDataset, optimize + +from pyhealth.datasets.sample_dataset import SampleDataset, SampleSubset +from pyhealth.processors.base_processor import FeatureProcessor + +# Top-level identity function for litdata.optimize (must be picklable). +def _identity(sample: Dict[str, Any]) -> Dict[str, Any]: + return sample + + +class RecordingProcessor(FeatureProcessor): + """Processor that records fit/process calls and prefixes outputs.""" + + def __init__(self, prefix: str) -> None: + self.prefix = prefix + self.fit_called = False + self.fit_seen: List[Any] = [] + self.process_seen: List[Any] = [] + + def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: + self.fit_called = True + for sample in samples: + self.fit_seen.append(pickle.loads(sample[field])) + + def process(self, value: Any) -> Any: + self.process_seen.append(value) + return f"{self.prefix}-{value}" + + +class TestSampleDatasetAndSubset(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.tmpdir = tempfile.TemporaryDirectory() + self.output_dir = Path(self.tmpdir.name) / "stream" + + raw_samples = [ + {"patient_id": "p1", "record_id": "r1", "x": 1, "y": 10}, + {"patient_id": "p2", "record_id": "r2", "x": 2, "y": 20}, + {"patient_id": "p3", "record_id": "r3", "x": 3, "y": 30}, + ] + pickled_samples = [{k: pickle.dumps(v) for k, v in sample.items()} for sample in raw_samples] + + optimize( + fn=_identity, + inputs=pickled_samples, + output_dir=str(self.output_dir), + chunk_size=len(pickled_samples), + num_workers=1, + verbose=False, + ) + + streaming_dataset = StreamingDataset( + input_dir=str(self.output_dir), + cache_dir=str(self.output_dir), + ) + + self.sample_dataset = SampleDataset( + dataset=streaming_dataset, + input_schema={"x": (RecordingProcessor, {"prefix": "in"})}, + output_schema={"y": (RecordingProcessor, {"prefix": "out"})}, + dataset_name="test_dataset", + task_name="task", + ) + + self.raw_samples = raw_samples + + def tearDown(self) -> None: + self.tmpdir.cleanup() + super().tearDown() + + def test_sample_dataset_builds_processors(self) -> None: + self.assertIn("x", self.sample_dataset.input_processors) + self.assertIn("y", self.sample_dataset.output_processors) + + input_proc: RecordingProcessor = self.sample_dataset.input_processors["x"] # type: ignore + output_proc: RecordingProcessor = self.sample_dataset.output_processors["y"] # type: ignore + + self.assertTrue(input_proc.fit_called) + self.assertTrue(output_proc.fit_called) + self.assertEqual(input_proc.fit_seen, [1, 2, 3]) + self.assertEqual(output_proc.fit_seen, [10, 20, 30]) + + def test_sample_dataset_returns_processed_items(self) -> None: + item = self.sample_dataset[0] + self.assertEqual(item["x"], "in-1") + self.assertEqual(item["y"], "out-10") + self.assertEqual(item["patient_id"], "p1") + self.assertEqual(item["record_id"], "r1") + + def test_sample_subset_respects_indices_and_processing(self) -> None: + subset_indices = [1, 2] + subset = SampleSubset(self.sample_dataset, subset_indices) + + self.assertEqual(len(subset), len(subset_indices)) + + second = subset[0] + third = subset[1] + + self.assertEqual(second["x"], "in-2") + self.assertEqual(second["y"], "out-20") + self.assertEqual(second["patient_id"], "p2") + + self.assertEqual(third["x"], "in-3") + self.assertEqual(third["y"], "out-30") + self.assertEqual(third["patient_id"], "p3") From 2d47cc36e98c5e011f3d6d50088b676baaaddda0 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 23 Nov 2025 08:03:34 -0500 Subject: [PATCH 46/51] Add more test cases --- tests/core/test_sample_dataset.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py index 22ecad2dc..0592d5b41 100644 --- a/tests/core/test_sample_dataset.py +++ b/tests/core/test_sample_dataset.py @@ -87,18 +87,27 @@ def test_sample_dataset_builds_processors(self) -> None: self.assertEqual(output_proc.fit_seen, [10, 20, 30]) def test_sample_dataset_returns_processed_items(self) -> None: + # __getitem__ path item = self.sample_dataset[0] self.assertEqual(item["x"], "in-1") self.assertEqual(item["y"], "out-10") self.assertEqual(item["patient_id"], "p1") self.assertEqual(item["record_id"], "r1") + # __iter__ path + self.sample_dataset.dataset.reset() + items = list(iter(self.sample_dataset)) + self.assertEqual([s["x"] for s in items], ["in-1", "in-2", "in-3"]) + self.assertEqual([s["y"] for s in items], ["out-10", "out-20", "out-30"]) + self.assertEqual([s["patient_id"] for s in items], ["p1", "p2", "p3"]) + def test_sample_subset_respects_indices_and_processing(self) -> None: subset_indices = [1, 2] subset = SampleSubset(self.sample_dataset, subset_indices) self.assertEqual(len(subset), len(subset_indices)) + # __getitem__ path second = subset[0] third = subset[1] @@ -109,3 +118,11 @@ def test_sample_subset_respects_indices_and_processing(self) -> None: self.assertEqual(third["x"], "in-3") self.assertEqual(third["y"], "out-30") self.assertEqual(third["patient_id"], "p3") + + # __iter__ path + subset.dataset.reset() + iter_items = list(iter(subset)) + self.assertEqual(len(iter_items), 2) + self.assertEqual([s["x"] for s in iter_items], ["in-2", "in-3"]) + self.assertEqual([s["y"] for s in iter_items], ["out-20", "out-30"]) + self.assertEqual([s["patient_id"] for s in iter_items], ["p2", "p3"]) From d4f32169e9360193fb1f75bc27d341ddb583a381 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 23 Nov 2025 08:06:52 -0500 Subject: [PATCH 47/51] Refactor SampleSubset __init__ --- pyhealth/datasets/sample_dataset.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 01f8450fc..d5854f22d 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -197,8 +197,16 @@ class SampleSubset(IterableDataset): def __init__(self, dataset: SampleDataset, indices: List[int]) -> None: self.dataset_name = dataset.dataset_name self.task_name = dataset.task_name + base_dataset = deepcopy_dataset(dataset.dataset) + self.dataset, self._length = self._build_subset_dataset( + base_dataset, indices + ) + def _build_subset_dataset( + self, base_dataset: StreamingDataset, indices: List[int] + ) -> Tuple[StreamingDataset, int]: + """Create a StreamingDataset restricted to the provided indices.""" if len(base_dataset.subsampled_files) != len(base_dataset.region_of_interest): raise ValueError( "The provided dataset has mismatched subsampled_files and region_of_interest lengths." @@ -253,12 +261,13 @@ def __init__(self, dataset: SampleDataset, indices: List[int]) -> None: prev_chunk_idx = chunk_idx - self.dataset: StreamingDataset = base_dataset - self.dataset.subsampled_files = new_subsampled_files - self.dataset.region_of_interest = new_roi - self.dataset.reset() - self._length = sum(end - start for start, end in new_roi) + base_dataset.subsampled_files = new_subsampled_files + base_dataset.region_of_interest = new_roi + base_dataset.reset() + subset_length = sum(end - start for start, end in new_roi) + return base_dataset, subset_length + def __iter__(self) -> Iterator: """Returns an iterator over the dataset samples.""" return self.dataset.__iter__() From a93d1a6b43c15c02c4f93f0f751cb0848235f0c2 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 23 Nov 2025 08:16:53 -0500 Subject: [PATCH 48/51] support set_shuffle, add testcase --- pyhealth/datasets/sample_dataset.py | 21 +++++++++++++++++- tests/core/test_sample_dataset.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index d5854f22d..d138fa85f 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -170,6 +170,15 @@ def transform(self, sample) -> Dict: sample[k] = pickle.loads(v) return sample + def set_shuffle(self, shuffle: bool) -> None: + """Sets whether to shuffle the dataset during iteration. + + Args: + shuffle (bool): Whether to shuffle the dataset. + """ + self.dataset.set_shuffle(shuffle) + return + def __iter__(self) -> Iterator: """Returns an iterator over the dataset samples.""" return self.dataset.__iter__() @@ -198,7 +207,8 @@ def __init__(self, dataset: SampleDataset, indices: List[int]) -> None: self.dataset_name = dataset.dataset_name self.task_name = dataset.task_name - base_dataset = deepcopy_dataset(dataset.dataset) + base_dataset: StreamingDataset = deepcopy_dataset(dataset.dataset) + base_dataset.set_shuffle(False) # Disable shuffling when creating subset self.dataset, self._length = self._build_subset_dataset( base_dataset, indices ) @@ -268,6 +278,15 @@ def _build_subset_dataset( return base_dataset, subset_length + def set_shuffle(self, shuffle: bool) -> None: + """Sets whether to shuffle the dataset during iteration. + + Args: + shuffle (bool): Whether to shuffle the dataset. + """ + self.dataset.set_shuffle(shuffle) + return + def __iter__(self) -> Iterator: """Returns an iterator over the dataset samples.""" return self.dataset.__iter__() diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py index 0592d5b41..04431278f 100644 --- a/tests/core/test_sample_dataset.py +++ b/tests/core/test_sample_dataset.py @@ -126,3 +126,36 @@ def test_sample_subset_respects_indices_and_processing(self) -> None: self.assertEqual([s["x"] for s in iter_items], ["in-2", "in-3"]) self.assertEqual([s["y"] for s in iter_items], ["out-20", "out-30"]) self.assertEqual([s["patient_id"] for s in iter_items], ["p2", "p3"]) + + def test_shuffle_behavior_and_isolation(self) -> None: + # Baseline (no shuffle) + baseline = [s["patient_id"] for s in iter(self.sample_dataset)] + self.sample_dataset.dataset.reset() + + # Shuffle affects iteration but not __getitem__ + self.sample_dataset.set_shuffle(True) + shuffled_iter = [s["patient_id"] for s in iter(self.sample_dataset)] + self.assertCountEqual(shuffled_iter, baseline) + if len(baseline) > 1: + self.assertNotEqual(shuffled_iter, baseline) + self.sample_dataset.dataset.reset() + self.assertEqual(self.sample_dataset[0]["patient_id"], "p1") + + # Subset created from shuffled dataset should disable shuffle during construction + subset = SampleSubset(self.sample_dataset, [0, 1]) + self.assertFalse(subset.dataset.shuffle) + subset_items = [s["patient_id"] for s in iter(subset)] + self.assertEqual(subset_items, ["p1", "p2"]) + subset.dataset.reset() + self.assertEqual(subset[0]["patient_id"], "p1") + + # Shuffling one subset doesn't affect dataset or other subsets + subset2 = SampleSubset(self.sample_dataset, [1, 2]) + subset.set_shuffle(True) + shuffled_subset_iter = [s["patient_id"] for s in iter(subset)] + self.assertCountEqual(shuffled_subset_iter, ["p1", "p2"]) + if len(shuffled_subset_iter) > 1: + self.assertNotEqual(shuffled_subset_iter, ["p1", "p2"]) + self.assertFalse(subset2.dataset.shuffle) + self.assertEqual(subset2[0]["patient_id"], "p2") + self.assertTrue(self.sample_dataset.dataset.shuffle) From 4dbc4f4e75c079d5f4df1378867624d611c071ca Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 23 Nov 2025 08:22:42 -0500 Subject: [PATCH 49/51] Fix incorrect typehint --- pyhealth/datasets/sample_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index d138fa85f..f1bcd8288 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,11 +1,11 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type +from collections.abc import Sequence from bisect import bisect_right import inspect from torch.utils.data import IterableDataset from litdata.streaming import StreamingDataset from litdata.utilities.train_test_split import deepcopy_dataset -from tqdm import tqdm import pickle from ..processors import get_processor @@ -203,7 +203,7 @@ class SampleSubset(IterableDataset): indices (List[int]): List of indices to include in the subset. """ - def __init__(self, dataset: SampleDataset, indices: List[int]) -> None: + def __init__(self, dataset: SampleDataset, indices: Sequence[int]) -> None: self.dataset_name = dataset.dataset_name self.task_name = dataset.task_name @@ -214,7 +214,7 @@ def __init__(self, dataset: SampleDataset, indices: List[int]) -> None: ) def _build_subset_dataset( - self, base_dataset: StreamingDataset, indices: List[int] + self, base_dataset: StreamingDataset, indices: Sequence[int] ) -> Tuple[StreamingDataset, int]: """Create a StreamingDataset restricted to the provided indices.""" if len(base_dataset.subsampled_files) != len(base_dataset.region_of_interest): From 66e9e64b1ef153fb71419e561d043f30b1eea032 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 23 Nov 2025 08:33:17 -0500 Subject: [PATCH 50/51] Fix splitter --- examples/memtest.py | 28 +++++++++++- pyhealth/datasets/splitter.py | 83 ++++++++++++----------------------- pyhealth/datasets/utils.py | 4 +- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 1662ed63a..0b795296c 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -1,4 +1,5 @@ # %% +import torch from pyhealth.data.data import Patient from pyhealth.datasets import ( MIMIC4Dataset, @@ -49,8 +50,15 @@ print(f" Labs shape: {len(sample['labs'][0])} timesteps") print(f" Mortality: {sample['mortality']}") + # STEP 3: Split dataset + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + # Create dataloaders - train_loader = get_dataloader(sample_dataset, batch_size=256, shuffle=True) + train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=32, shuffle=False) # STEP 4: Initialize StageNet model model = StageNet( @@ -75,7 +83,23 @@ trainer.train( train_dataloader=train_loader, - epochs=50, + val_dataloader=val_loader, + epochs=5, monitor="roc_auc", optimizer_params={"lr": 1e-5}, ) + + # STEP 6: Evaluate on test set + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Inspect model predictions + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print("\nSample predictions:") + print(f" Predicted probabilities: {output['y_prob'][:5]}") + print(f" True labels: {output['y_true'][:5]}") diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index c70df5660..f9e3c7213 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -4,7 +4,7 @@ import numpy as np import torch -from .sample_dataset import SampleDataset +from .sample_dataset import SampleDataset, SampleSubset # TODO: train_dataset.dataset still access the whole dataset which may leak information # TODO: add more splitting methods @@ -24,11 +24,7 @@ def split_by_visit( Returns: train_dataset, val_dataset, test_dataset: three subsets of the dataset of - type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, and `test_dataset.dataset`. + type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -40,9 +36,9 @@ def split_by_visit( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, test_dataset @@ -60,11 +56,7 @@ def split_by_patient( Returns: train_dataset, val_dataset, test_dataset: three subsets of the dataset of - type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, and `test_dataset.dataset`. + type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -82,9 +74,9 @@ def split_by_patient( ) val_index = list(chain(*[dataset.patient_to_index[i] for i in val_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) + val_dataset = SampleSubset(dataset, val_index) + test_dataset = SampleSubset(dataset, test_index) return train_dataset, val_dataset, test_dataset @@ -103,11 +95,7 @@ def split_by_sample( Returns: train_dataset, val_dataset, test_dataset: three subsets of the dataset of - type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, and `test_dataset.dataset`. + type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -119,9 +107,9 @@ def split_by_sample( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index.tolist()) + val_dataset = SampleSubset(dataset, val_index.tolist()) + test_dataset = SampleSubset(dataset, test_index.tolist()) if get_index: return ( @@ -147,12 +135,7 @@ def split_by_visit_conformal( Returns: train_dataset, val_dataset, cal_dataset, test_dataset: four subsets - of the dataset of type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, `cal_dataset.dataset`, and - `test_dataset.dataset`. + of the dataset of type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -172,10 +155,10 @@ def split_by_visit_conformal( cal_index = index[val_end:cal_end] test_index = index[cal_end:] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + cal_dataset = SampleSubset(dataset, cal_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -194,12 +177,7 @@ def split_by_patient_conformal( Returns: train_dataset, val_dataset, cal_dataset, test_dataset: four subsets - of the dataset of type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, `cal_dataset.dataset`, and - `test_dataset.dataset`. + of the dataset of type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -227,10 +205,10 @@ def split_by_patient_conformal( cal_index = list(chain(*[dataset.patient_to_index[i] for i in cal_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + cal_dataset = SampleSubset(dataset, cal_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -251,13 +229,8 @@ def split_by_sample_conformal( Returns: train_dataset, val_dataset, cal_dataset, test_dataset: four subsets - of the dataset of type `torch.utils.data.Subset`, or four tensors + of the dataset of type `SampleSubset`, or four tensors of indices if get_index=True. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, `cal_dataset.dataset`, and - `test_dataset.dataset`. """ if seed is not None: np.random.seed(seed) @@ -285,8 +258,8 @@ def split_by_sample_conformal( torch.tensor(test_index), ) else: - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + cal_dataset = SampleSubset(dataset, cal_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 33878a8bc..f45ca9c98 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader from pyhealth import BASE_CACHE_PATH -from pyhealth.datasets.sample_dataset import SampleDataset +from pyhealth.datasets.sample_dataset import SampleDataset, SampleSubset from pyhealth.utils import create_directory MODULE_CACHE_PATH = os.path.join(BASE_CACHE_PATH, "datasets") @@ -320,7 +320,7 @@ def collate_fn_dict_with_padding(batch: List[dict]) -> dict: def get_dataloader( - dataset: SampleDataset, batch_size: int, shuffle: bool = False + dataset: SampleDataset | SampleSubset, batch_size: int, shuffle: bool = False ) -> DataLoader: """Creates a DataLoader for a given dataset. From 790cb420869ea5dc8017edebcbf9b20f7a7be0b3 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 26 Nov 2025 05:48:52 -0500 Subject: [PATCH 51/51] Fix task.pre_filter --- pyhealth/datasets/base_dataset.py | 7 ++++++- pyhealth/tasks/base_task.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index a1f4015df..787debe88 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -112,11 +112,16 @@ def _transform_fn( ) -> Iterator[Dict[str, Any]]: (bucket_id, merged_cache, task) = input path = f"{merged_cache}/bucket={bucket_id}" + df = pd.read_parquet(path) + # TODO: Need to make sure pre_filter works with pandas DataFrame + df = task.pre_filter(df) + # This is more efficient than reading patient by patient - grouped = pd.read_parquet(path).groupby("patient_id") + grouped = df.groupby("patient_id") for patient_id, patient_df in grouped: patient = Patient(patient_id=str(patient_id), data_source=patient_df) + # TODO: Need to make sure task(patient) works with pandas DataFrame for sample in task(patient): # Schema is too complex to be handled by LitData, so we pickle the sample here yield _pickle(sample) diff --git a/pyhealth/tasks/base_task.py b/pyhealth/tasks/base_task.py index 9026e1e24..cd8c79adf 100644 --- a/pyhealth/tasks/base_task.py +++ b/pyhealth/tasks/base_task.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union, Type -import dask.dataframe as dd +import pandas as pd class BaseTask(ABC): @@ -9,7 +9,7 @@ class BaseTask(ABC): input_schema: Dict[str, Union[str, Type]] output_schema: Dict[str, Union[str, Type]] - def pre_filter(self, df: dd.DataFrame) -> dd.DataFrame: + def pre_filter(self, df: pd.DataFrame) -> pd.DataFrame: return df @abstractmethod