diff --git a/examples/memtest.py b/examples/memtest.py index 8a63090e8..0b795296c 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -1,32 +1,105 @@ # %% -import psutil, os, time, threading -PEAK_MEM_USAGE = 0 -SELF_PROC = psutil.Process(os.getpid()) +import torch +from pyhealth.data.data import Patient +from pyhealth.datasets import ( + MIMIC4Dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +import dask.config -def track_mem(): - global PEAK_MEM_USAGE - while True: - m = SELF_PROC.memory_info().rss - if m > PEAK_MEM_USAGE: - PEAK_MEM_USAGE = m - time.sleep(0.1) +import warnings +warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API") -threading.Thread(target=track_mem, daemon=True).start() -print(f"[MEM] start={PEAK_MEM_USAGE / (1024**3)} GB") +if __name__ == "__main__": + dask.config.set({"temporary-directory": "/mnt/tmpfs/"}) + + base_dataset = MIMIC4Dataset( + ehr_root="/home/logic/physionet.org/files/mimiciv/3.1", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + num_workers=2, + mem_per_worker="8GB", + dev=True + ) -# %% -from pyhealth.datasets import MIMIC4Dataset -DATASET_DIR = "/home/logic/physionet.org/files/mimiciv/3.1" -dataset = MIMIC4Dataset( - ehr_root=DATASET_DIR, - ehr_tables=[ - "patients", - "admissions", - "diagnoses_icd", - "procedures_icd", - "prescriptions", - "labevents", - ], -) -print(f"[MEM] __init__={PEAK_MEM_USAGE / (1024**3):.3f} GB") -# %% + print(f"Patients: {base_dataset.unique_patient_ids[:10]}, ...") + + # STEP 2: Apply StageNet mortality prediction task + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=2, + ) + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + # Inspect a sample + sample = next(iter(sample_dataset)) + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f"ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") + + # STEP 3: Split dataset + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + # Create dataloaders + train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=32, shuffle=False) + + # STEP 4: Initialize StageNet model + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") + + # STEP 5: Train the model + trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + enable_logging=False, + ) + print("\nStarting training...") + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=5, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, + ) + + # STEP 6: Evaluate on test set + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + # STEP 7: Inspect model predictions + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print("\nSample predictions:") + print(f" Predicted probabilities: {output['y_prob'][:5]}") + print(f" True labels: {output['y_true'][:5]}") diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py index 2a6d3a45c..bb305a608 100644 --- a/pyhealth/data/data.py +++ b/pyhealth/data/data.py @@ -2,10 +2,10 @@ from dataclasses import dataclass, field from datetime import datetime from functools import reduce -from typing import Dict, List, Mapping, Optional, Union +from typing import Dict, List, Mapping, Optional, Union, Any import numpy as np -import polars as pl +import pandas as pd @dataclass(frozen=True) @@ -20,9 +20,9 @@ class Event: event_type: str timestamp: datetime - attr_dict: Mapping[str, any] = field(default_factory=dict) + attr_dict: Mapping[str, Any] = field(default_factory=dict) - def __init__(self, event_type: str, timestamp: datetime = None, **kwargs): + def __init__(self, event_type: str, timestamp: datetime | None = None, **kwargs): """Initialize an Event instance. Args: @@ -50,7 +50,7 @@ def __init__(self, event_type: str, timestamp: datetime = None, **kwargs): object.__setattr__(self, "attr_dict", attr_dict) @classmethod - def from_dict(cls, d: Dict[str, any]) -> "Event": + def from_dict(cls, d: Dict[str, Any]) -> "Event": """Create an Event instance from a dictionary. Args: @@ -61,12 +61,12 @@ def from_dict(cls, d: Dict[str, any]) -> "Event": """ timestamp: datetime = d["timestamp"] event_type: str = d["event_type"] - attr_dict: Dict[str, any] = { + attr_dict: Dict[str, Any] = { k.split("/", 1)[1]: v for k, v in d.items() if k.split("/")[0] == event_type } return cls(event_type=event_type, timestamp=timestamp, attr_dict=attr_dict) - def __getitem__(self, key: str) -> any: + def __getitem__(self, key: str) -> Any: """Get an attribute by key. Args: @@ -95,7 +95,7 @@ def __contains__(self, key: str) -> bool: return True return key in self.attr_dict - def __getattr__(self, key: str) -> any: + def __getattr__(self, key: str) -> Any: """Get an attribute using dot notation. Args: @@ -119,57 +119,68 @@ class Patient: Attributes: patient_id (str): Unique patient identifier. - data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp. - event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. + data_source (pd.DataFrame): DataFrame containing all events, sorted by timestamp. + event_type_partitions (Dict[str, pd.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. """ - def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None: + def __init__(self, patient_id: str, data_source: pd.DataFrame) -> None: """ Initialize a Patient instance. Args: patient_id (str): Unique patient identifier. - data_source (pl.DataFrame): DataFrame containing all events. + data_source (pd.DataFrame): DataFrame containing all events. """ self.patient_id = patient_id - self.data_source = data_source.sort("timestamp") - self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True) + self.data_source = data_source.sort_values("timestamp") + self.event_type_partitions = { + (event_type,): group.copy() + for event_type, group in self.data_source.groupby("event_type", sort=False) + } - def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + def _filter_by_time_range_regular(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame: """Regular filtering by time. Time complexity: O(n).""" if start is not None: - df = df.filter(pl.col("timestamp") >= start) + df = df[df["timestamp"] >= start] if end is not None: - df = df.filter(pl.col("timestamp") <= end) + df = df[df["timestamp"] <= end] return df - def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: + def _filter_by_time_range_fast(self, df: pd.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pd.DataFrame: """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n).""" if start is None and end is None: return df - df = df.filter(pl.col("timestamp").is_not_null()) - ts_col = df["timestamp"].to_numpy() + df = df[df["timestamp"].notna()].sort_values("timestamp") + ts_col = pd.to_datetime(df["timestamp"]).to_numpy(dtype="datetime64[ns]") start_idx = 0 end_idx = len(ts_col) if start is not None: - start_idx = np.searchsorted(ts_col, start, side="left") + start_idx = np.searchsorted(ts_col, np.datetime64(start, "ns"), side="left") if end is not None: - end_idx = np.searchsorted(ts_col, end, side="right") - return df.slice(start_idx, end_idx - start_idx) + end_idx = np.searchsorted(ts_col, np.datetime64(end, "ns"), side="right") + return df.iloc[start_idx:end_idx] - def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + def _filter_by_event_type_regular(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame: """Regular filtering by event type. Time complexity: O(n).""" if event_type: - df = df.filter(pl.col("event_type") == event_type) + df = df[df["event_type"] == event_type] return df - def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: + def _filter_by_event_type_fast(self, df: pd.DataFrame, event_type: Optional[str]) -> pd.DataFrame: """Fast filtering by event type using pre-built event type index. Time complexity: O(1).""" if event_type: - return self.event_type_partitions.get((event_type,), df[:0]) + return self.event_type_partitions.get((event_type,), df.iloc[0:0]) else: return df + def get_events_py(self, **kawargs) -> List[Event]: + """Type-safe wrapper for get_events.""" + return self.get_events(**kawargs, return_df=False) # type: ignore + + def get_events_df(self, **kawargs) -> pd.DataFrame: + """DataFrame wrapper for get_events.""" + return self.get_events(**kawargs, return_df=True) # type: ignore + def get_events( self, event_type: Optional[str] = None, @@ -177,7 +188,7 @@ def get_events( end: Optional[datetime] = None, filters: Optional[List[tuple]] = None, return_df: bool = False, - ) -> Union[pl.DataFrame, List[Event]]: + ) -> Union[pd.DataFrame, List[Event]]: """Get events with optional type and time filters. Args: @@ -191,7 +202,7 @@ def get_events( and time filters. The logic is "AND" between different filters. Returns: - Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame + Union[pd.DataFrame, List[Event]]: Filtered events as a DataFrame or a list of Event objects. """ # faster filtering (by default) @@ -213,7 +224,7 @@ def get_events( f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" ) attr, op, val = filt - col_expr = pl.col(f"{event_type}/{attr}") + col_expr = df[f"{event_type}/{attr}"] # Build operator expression if op == "==": exprs.append(col_expr == val) @@ -230,7 +241,7 @@ def get_events( else: raise ValueError(f"Unsupported operator: {op} in filter {filt}") if exprs: - df = df.filter(reduce(operator.and_, exprs)) + df = df.loc[reduce(operator.and_, exprs)] if return_df: return df - return [Event.from_dict(d) for d in df.to_dicts()] + return [Event.from_dict(d) for d in df.to_dict(orient="records")] # type: ignore diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 3390453ff..787debe88 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -1,23 +1,38 @@ import logging import os -import pickle from abc import ABC -from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Any, Callable from urllib.parse import urlparse, urlunparse +import uuid +import json +import functools +import operator +from collections import namedtuple +import pickle +from distributed import Client, LocalCluster, get_client +from dask.utils import parse_bytes import polars as pl +import pandas as pd +import dask.dataframe as dd +import pyarrow as pa +import pyarrow.csv as pv +import pyarrow.parquet as pq import requests -from tqdm import tqdm +import platformdirs +from litdata.streaming import StreamingDataset +from dask.distributed import progress +import xxhash from ..data import Patient from ..tasks import BaseTask from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config from .sample_dataset import SampleDataset -from .utils import _convert_for_cache, _restore_from_cache +# Set logging level for distributed to ERROR to reduce verbosity +logging.getLogger("distributed").setLevel(logging.ERROR) logger = logging.getLogger(__name__) @@ -56,42 +71,60 @@ def path_exists(path: str) -> bool: return Path(path).exists() -def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame: +def alt_path(path: str) -> str: """ - Scan a CSV.gz, CSV, TSV.gz, or TSV file and returns a LazyFrame. - It will fall back to the other extension if not found. + Get the alternative path by switching between .csv.gz, .csv, .tsv.gz, and .tsv extensions. Args: - path (str): URL or local path to a .csv, .csv.gz, .tsv, or .tsv.gz file + path (str): Original file path. Returns: - pl.LazyFrame: The LazyFrame for the CSV.gz, CSV, TSV.gz, or TSV file. + str: Alternative file path. """ - - def scan_file(file_path: str) -> pl.LazyFrame: - separator = "\t" if ".tsv" in file_path else "," - return pl.scan_csv(file_path, separator=separator, infer_schema=False) - - if path_exists(path): - return scan_file(path) - - # Try the alternative extension if path.endswith(".csv.gz"): - alt_path = path[:-3] # Remove .gz -> try .csv + return path[:-3] # Remove .gz -> try .csv elif path.endswith(".csv"): - alt_path = f"{path}.gz" # Add .gz -> try .csv.gz + return f"{path}.gz" # Add .gz -> try .csv.gz elif path.endswith(".tsv.gz"): - alt_path = path[:-3] # Remove .gz -> try .tsv + return path[:-3] # Remove .gz -> try .tsv elif path.endswith(".tsv"): - alt_path = f"{path}.gz" # Add .gz -> try .tsv.gz + return f"{path}.gz" # Add .gz -> try .tsv.gz else: - raise FileNotFoundError(f"Path does not have expected extension: {path}") + raise ValueError(f"Path does not have expected extension: {path}") + + +def _pickle(datum: dict[str, Any]) -> dict[str, bytes]: + return {k: pickle.dumps(v) for k, v in datum.items()} + + +def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]: + return {k: pickle.loads(v) for k, v in datum.items()} + + +def _patient_bucket(patient_id: str, n_partitions: int) -> int: + """Hash patient_id to a bucket number.""" + bucket = int(xxhash.xxh64_intdigest(patient_id) % n_partitions) + return bucket - if path_exists(alt_path): - logger.info(f"Original path does not exist. Using alternative: {alt_path}") - return scan_file(alt_path) - raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") +def _transform_fn( + input: tuple[int, str, BaseTask], +) -> Iterator[Dict[str, Any]]: + (bucket_id, merged_cache, task) = input + path = f"{merged_cache}/bucket={bucket_id}" + df = pd.read_parquet(path) + # TODO: Need to make sure pre_filter works with pandas DataFrame + df = task.pre_filter(df) + + # This is more efficient than reading patient by patient + grouped = df.groupby("patient_id") + + for patient_id, patient_df in grouped: + patient = Patient(patient_id=str(patient_id), data_source=patient_df) + # TODO: Need to make sure task(patient) works with pandas DataFrame + for sample in task(patient): + # Schema is too complex to be handled by LitData, so we pickle the sample here + yield _pickle(sample) class BaseDataset(ABC): @@ -110,8 +143,12 @@ def __init__( self, root: str, tables: List[str], - dataset_name: Optional[str] = None, - config_path: Optional[str] = None, + dataset_name: str | None = None, + config_path: str | None = None, + cache_dir: str | Path | None = None, + num_workers: int = 1, + mem_per_worker: str | int = "8GB", + compute: bool = True, dev: bool = False, ): """Initializes the BaseDataset. @@ -119,87 +156,220 @@ def __init__( Args: root (str): The root directory where dataset files are stored. tables (List[str]): List of table names to load. - dataset_name (Optional[str]): Name of the dataset. Defaults to class name. - config_path (Optional[str]): Path to the configuration YAML file. + dataset_name (str | None): Name of the dataset. Defaults to class name. + config_path (str | None): Path to the configuration YAML file. + cache_dir (str | Path | None): Directory to cache processed data. If None, a default + cache directory will be created under the platform's cache directory. dev (bool): Whether to run in dev mode (limits to 1000 patients). """ if len(set(tables)) != len(tables): logger.warning("Duplicate table names in tables list. Removing duplicates.") tables = list(set(tables)) + self.root = root self.tables = tables self.dataset_name = dataset_name or self.__class__.__name__ - self.config = load_yaml_config(config_path) self.dev = dev + if config_path is not None: + self.config = load_yaml_config(config_path) + + # Resource allocation for table joining and processing + self.num_workers = num_workers + if isinstance(mem_per_worker, str): + self.mem_per_worker = parse_bytes(mem_per_worker) + else: + self.mem_per_worker = mem_per_worker + + # Cached value for property + self._cache_dir = cache_dir logger.info( f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})" ) + if compute: + _ = self._joined_cache() - self.global_event_df = self.load_data() + @property + def cache_dir(self) -> Path: + """Returns the cache directory path. - # Cached attributes - self._collected_global_event_df = None - self._unique_patient_ids = None + Returns: + Path: The cache directory path. + """ + if self._cache_dir is not None: + return Path(self._cache_dir) + + id_str = json.dumps( + { + "root": self.root, + "tables": sorted(self.tables), + "dataset_name": self.dataset_name, + "dev": self.dev, + }, + sort_keys=True, + ) + cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str( + uuid.uuid5(uuid.NAMESPACE_DNS, id_str) + ) + print(f"No cache_dir provided. Using default cache dir: {cache_dir}") + self._cache_dir = cache_dir - @property - def collected_global_event_df(self) -> pl.DataFrame: - """Collects and returns the global event data frame. + return cache_dir + + def _table_cache(self, table_name: str, source_path: str | None = None) -> str: + """Generates the cache path for a specific table. If the cached Parquet file does not exist, + it will convert the source CSV/TSV file to Parquet and save it to the cache. + + Args: + table_name (str): The name of the table. + source_path (str | None): The source CSV/TSV file path. If None, it assumes the + Parquet file already exists in the cache. Returns: - pl.DataFrame: The collected global event data frame. + str: The cache path for the table. """ - if self._collected_global_event_df is None: - logger.info("Collecting global event dataframe...") - - # Collect the dataframe - with dev mode limiting if applicable - df = self.global_event_df - # TODO: dev doesn't seem to improve the speed / memory usage - if self.dev: - # Limit the number of patients in dev mode - logger.info("Dev mode enabled: limiting to 1000 patients") - limited_patients = df.select(pl.col("patient_id")).unique().limit(1000) - df = df.join(limited_patients, on="patient_id", how="inner") - - self._collected_global_event_df = df.collect() - - # Profile the Polars collect() operation (commented out by default) - # self._collected_global_event_df, profile = df.profile() - # profile = profile.with_columns([ - # (pl.col("end") - pl.col("start")).alias("duration"), - # ]) - # profile = profile.with_columns([ - # (pl.col("duration") / profile["duration"].sum() * 100).alias("percentage") - # ]) - # profile = profile.sort("duration", descending=True) - # with pl.Config() as cfg: - # cfg.set_tbl_rows(-1) - # cfg.set_fmt_str_lengths(200) - # print(profile) + # Ensure the tables cache directory exists + (self.cache_dir / "tables").mkdir(parents=True, exist_ok=True) + ret_path = str(self.cache_dir / "tables" / f"{table_name}.parquet") + + if not path_exists(ret_path): + if source_path is None: + raise FileNotFoundError( + f"Table {table_name} not found in cache and no source_path provided." + ) - logger.info( - f"Collected dataframe with shape: {self._collected_global_event_df.shape}" + # Check if source_path exists, else try alternative path + if not path_exists(source_path): + if not path_exists(alt_path(source_path)): + raise FileNotFoundError( + f"Neither path exists: {source_path} or {alt_path(source_path)}" + ) + source_path = alt_path(source_path) + + # Determine delimiter based on file extension + delimiter = ( + "\t" + if source_path.endswith(".tsv") or source_path.endswith(".tsv.gz") + else "," + ) + + # Always infer schema as string to avoid incorrect type inference + schema_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + ) + schema = pa.schema( + [pa.field(name, pa.string()) for name in schema_reader.schema.names] + ) + + # Convert CSV/TSV to Parquet + csv_reader = pv.open_csv( + source_path, + read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB + parse_options=pv.ParseOptions(delimiter=delimiter), + convert_options=pv.ConvertOptions(column_types=schema), ) + with pq.ParquetWriter(ret_path, csv_reader.schema) as writer: + for batch in csv_reader: + writer.write_batch(batch) + + return ret_path - return self._collected_global_event_df + def _joined_cache(self) -> str: + """Collects and returns the global event data frame. - def load_data(self) -> pl.LazyFrame: + Returns: + dd.DataFrame: The collected global event data frame. + """ + ret_path = str(self.cache_dir / "global_event_df.parquet") + + if not path_exists(ret_path): + with LocalCluster( + n_workers=self.num_workers, + threads_per_worker=1, + memory_limit=self.mem_per_worker, + ) as cluster: + with Client(cluster) as client: + global_event_df = self.load_data() + # In dev mode, limit to 1000 patients + if self.dev: + patients = ( + global_event_df["patient_id"].unique().head(1000).tolist() + ) + filter = global_event_df["patient_id"].isin(patients) + global_event_df: dd.DataFrame = global_event_df[filter] + + # Collect unique patient IDs + patients = global_event_df["patient_id"].unique().compute().tolist() + n_patients = len(patients) + n_events = global_event_df.shape[0].compute() + logger.info( + f"Collected {n_events} events for {n_patients} patients." + ) + + # Estimate memory usage and partitioning + partition_size = ( + self.mem_per_worker // 16 + ) # Use 1/16 of worker memory for safety + estimated_mem_usage = ( + global_event_df.memory_usage(deep=False).compute().sum() + ) + n_partitions = ( + estimated_mem_usage // partition_size + ) + 1 # Calculate number of partitions based on memory usage + n_partitions = max( + n_partitions, self.num_workers + ) # At least num_workers partitions + + bucket = global_event_df["patient_id"].apply( + lambda pid: _patient_bucket(pid, n_partitions), + meta=("patient_id", "int"), + ) + global_event_df = global_event_df.assign(bucket=bucket) + logger.info( + f"Estimated size {estimated_mem_usage / (1024**3):.2f} GB, write to {n_partitions} partitions." + ) + + handle = global_event_df.to_parquet( + ret_path, + partition_on=["bucket"], + write_index=False, + compute=False, + ) + future = client.compute(handle) + progress(future) + with open(ret_path + "/index.json", "w") as future: + json.dump( + { + "n_partitions": int(n_partitions), + "n_patients": int(n_patients), + "n_events": int(n_events), + "patients": patients, + "algorithm": "xxhash64_modulo_partitioning", + }, + future, + ) + + return ret_path + + def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. Returns: - pl.LazyFrame: A concatenated lazy frame of all tables. + dd.DataFrame: A concatenated lazy frame of all tables. """ frames = [self.load_table(table.lower()) for table in self.tables] - return pl.concat(frames, how="diagonal") + return dd.concat(frames, axis=0, join="outer") - def load_table(self, table_name: str) -> pl.LazyFrame: + def load_table(self, table_name: str) -> dd.DataFrame: """Loads a table and processes joins if specified. Args: table_name (str): The name of the table to load. Returns: - pl.LazyFrame: The processed lazy frame for the table. + dd.DataFrame: The processed lazy frame for the table. Raises: ValueError: If the table is not found in the config. @@ -208,42 +378,44 @@ def load_table(self, table_name: str) -> pl.LazyFrame: if table_name not in self.config.tables: raise ValueError(f"Table {table_name} not found in config") - def _to_lower(col_name: str) -> str: - lower_name = col_name.lower() - if lower_name != col_name: - logger.warning("Renaming column %s to lowercase %s", col_name, lower_name) - return lower_name - table_cfg = self.config.tables[table_name] csv_path = f"{self.root}/{table_cfg.file_path}" csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") - df = scan_csv_gz_or_csv_tsv(csv_path) - - # Convert column names to lowercase before calling preprocess_func - df = df.rename(_to_lower) + df: dd.DataFrame = dd.read_parquet( + self._table_cache(table_name, source_path=csv_path), + split_row_groups=True, # type: ignore + blocksize="64MB", + ) # Check if there is a preprocessing function for this table + # TODO: we need to update the preprocess function to work with Dask DataFrame + # for all datasets. Only care about MIMIC4 for now. preprocess_func = getattr(self, f"preprocess_{table_name}", None) if preprocess_func is not None: logger.info( f"Preprocessing table: {table_name} with {preprocess_func.__name__}" ) - df = preprocess_func(df) + df: dd.DataFrame = preprocess_func(df) # Handle joins - for join_cfg in table_cfg.join: + for i, join_cfg in enumerate(table_cfg.join): other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df = scan_csv_gz_or_csv_tsv(other_csv_path) - join_df = join_df.rename(_to_lower) + join_df: dd.DataFrame = dd.read_parquet( + self._table_cache(f"{table_name}_join_{i}", other_csv_path), + split_row_groups=True, # type: ignore + blocksize="64MB", + ) join_key = join_cfg.on columns = join_cfg.columns how = join_cfg.how - df = df.join(join_df.select([join_key] + columns), on=join_key, how=how) + df: dd.DataFrame = df.merge( + join_df[[join_key] + columns], on=join_key, how=how + ) patient_id_col = table_cfg.patient_id timestamp_col = table_cfg.timestamp @@ -254,37 +426,38 @@ def _to_lower(col_name: str) -> str: if timestamp_col: if isinstance(timestamp_col, list): # Concatenate all timestamp parts in order with no separator - combined_timestamp = pl.concat_str( - [pl.col(col) for col in timestamp_col] - ).str.strptime(pl.Datetime, format=timestamp_format, strict=True) - timestamp_expr = combined_timestamp + timestamp_series: dd.Series = functools.reduce( + operator.add, (df[col].astype(str) for col in timestamp_col) + ) else: # Single timestamp column - timestamp_expr = pl.col(timestamp_col).str.strptime( - pl.Datetime, format=timestamp_format, strict=True - ) + timestamp_series: dd.Series = df[timestamp_col].astype(str) + timestamp_series: dd.Series = dd.to_datetime( + timestamp_series, + format=timestamp_format, + errors="raise", + ) + df: dd.DataFrame = df.assign( + timestamp=timestamp_series.astype("datetime64[ms]") + ) else: - timestamp_expr = pl.lit(None, dtype=pl.Datetime) + df: dd.DataFrame = df.assign(timestamp=pd.NaT) # If patient_id_col is None, use row index as patient_id - patient_id_expr = ( - pl.col(patient_id_col).cast(pl.Utf8) - if patient_id_col - else pl.int_range(0, pl.count()).cast(pl.Utf8) - ) - base_columns = [ - patient_id_expr.alias("patient_id"), - pl.lit(table_name).cast(pl.Utf8).alias("event_type"), - # ms should be sufficient for most cases - timestamp_expr.cast(pl.Datetime(time_unit="ms")).alias("timestamp"), - ] + if patient_id_col: + df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype(str)) + else: + df: dd.DataFrame = df.reset_index(drop=True) + df: dd.DataFrame = df.assign(patient_id=df.index.astype(str)) + + df: dd.DataFrame = df.assign(event_type=table_name) - # Flatten attribute columns with event_type prefix - attribute_columns = [ - pl.col(attr.lower()).alias(f"{table_name}/{attr}") for attr in attribute_cols - ] + rename_attr = {attr: f"{table_name}/{attr}" for attr in attribute_cols} + df: dd.DataFrame = df.rename(columns=rename_attr) - event_frame = df.select(base_columns + attribute_columns) + attr_cols = [rename_attr[attr] for attr in attribute_cols] + final_cols = ["patient_id", "event_type", "timestamp"] + attr_cols + event_frame = df[final_cols] return event_frame @@ -295,15 +468,9 @@ def unique_patient_ids(self) -> List[str]: Returns: List[str]: List of unique patient IDs. """ - if self._unique_patient_ids is None: - self._unique_patient_ids = ( - self.collected_global_event_df.select("patient_id") - .unique() - .to_series() - .to_list() - ) - logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs") - return self._unique_patient_ids + with open(self._joined_cache() + "/index.json", "r") as f: + index_info = json.load(f) + return index_info["patients"] def get_patient(self, patient_id: str) -> Patient: """Retrieves a Patient object for the given patient ID. @@ -320,8 +487,18 @@ def get_patient(self, patient_id: str) -> Patient: assert ( patient_id in self.unique_patient_ids ), f"Patient {patient_id} not found in dataset" - df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id) - return Patient(patient_id=patient_id, data_source=df) + + path = self._joined_cache() + with open(f"{path}/index.json", "rb") as f: + n_partitions = json.load(f)["n_partitions"] + bucket = _patient_bucket(patient_id, n_partitions) + path = f"{path}/bucket={bucket}" + df = pd.read_parquet(path) + patient = Patient( + patient_id=patient_id, + data_source=df[df["patient_id"] == patient_id], + ) + return patient def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: """Yields Patient objects for each unique patient in the dataset. @@ -329,21 +506,20 @@ def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]: Yields: Iterator[Patient]: An iterator over Patient objects. """ - if df is None: - df = self.collected_global_event_df - grouped = df.group_by("patient_id") - - for patient_id, patient_df in grouped: - patient_id = patient_id[0] - yield Patient(patient_id=patient_id, data_source=patient_df) + for patitent_id in self.unique_patient_ids: + yield self.get_patient(patitent_id) def stats(self) -> None: """Prints statistics about the dataset.""" - df = self.collected_global_event_df + with open(self._joined_cache() + "/index.json", "r") as f: + index_info = json.load(f) + n_patients = index_info["n_patients"] + n_events = index_info["n_events"] + print(f"Dataset: {self.dataset_name}") print(f"Dev mode: {self.dev}") - print(f"Number of patients: {df['patient_id'].n_unique()}") - print(f"Number of events: {df.height}") + print(f"Number of patients: {n_patients}") + print(f"Number of events: {n_events}") @property def default_task(self) -> Optional[BaseTask]: @@ -358,7 +534,7 @@ def set_task( self, task: Optional[BaseTask] = None, num_workers: int = 1, - cache_dir: Optional[str] = None, + cache_dir: str | Path | None = None, cache_format: str = "parquet", input_processors: Optional[Dict[str, FeatureProcessor]] = None, output_processors: Optional[Dict[str, FeatureProcessor]] = None, @@ -395,102 +571,45 @@ def set_task( f"Setting task {task.task_name} for {self.dataset_name} base dataset..." ) - # Check for cached data if cache_dir is provided - samples = None - if cache_dir is not None: - cache_filename = f"{task.task_name}.{cache_format}" - cache_path = Path(cache_dir) / cache_filename - if cache_path.exists(): - logger.info(f"Loading cached samples from {cache_path}") - try: - if cache_format == "parquet": - # Load samples from parquet file - cached_df = pl.read_parquet(cache_path) - samples = [ - _restore_from_cache(row) for row in cached_df.to_dicts() - ] - elif cache_format == "pickle": - # Load samples from pickle file - with open(cache_path, "rb") as f: - samples = pickle.load(f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Loaded {len(samples)} cached samples") - except Exception as e: - logger.warning( - "Failed to load cached data: %s. Regenerating...", - e, - ) - samples = None - - # Generate samples if not loaded from cache - if samples is None: - logger.info(f"Generating samples with {num_workers} worker(s)...") - filtered_global_event_df = task.pre_filter(self.collected_global_event_df) - samples = [] - - if num_workers == 1: - # single-threading (by default) - for patient in tqdm( - self.iter_patients(filtered_global_event_df), - total=filtered_global_event_df["patient_id"].n_unique(), - desc=(f"Generating samples for {task.task_name} " "with 1 worker"), - smoothing=0, - ): - samples.extend(task(patient)) - else: - # multi-threading (not recommended) - logger.info( - f"Generating samples for {task.task_name} with " - f"{num_workers} workers" - ) - patients = list(self.iter_patients(filtered_global_event_df)) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(task, patient) for patient in patients] - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=( - f"Collecting samples for {task.task_name} " - f"from {num_workers} workers" - ), - ): - samples.extend(future.result()) - - # Cache the samples if cache_dir is provided - if cache_dir is not None: - cache_path = Path(cache_dir) / cache_filename - cache_path.parent.mkdir(parents=True, exist_ok=True) - logger.info(f"Caching samples to {cache_path}") - try: - if cache_format == "parquet": - # Save samples as parquet file - samples_for_cache = [ - _convert_for_cache(sample) for sample in samples - ] - samples_df = pl.DataFrame(samples_for_cache) - samples_df.write_parquet(cache_path) - elif cache_format == "pickle": - # Save samples as pickle file - with open(cache_path, "wb") as f: - pickle.dump(samples, f) - else: - msg = f"Unsupported cache format: {cache_format}" - raise ValueError(msg) - logger.info(f"Successfully cached {len(samples)} samples") - except Exception as e: - logger.warning(f"Failed to cache samples: {e}") + if cache_format != "parquet": + logger.warning( + f"This argument is no longer supported: cache_format={cache_format}" + ) + if cache_dir is None: + cache_dir = str(self.cache_dir / "tasks" / task.task_name) + logger.info( + "No cache_dir provided. Using default task cache dir: %s", cache_dir + ) + + if not path_exists(str(cache_dir)): + import litdata as ld + + with open(self._joined_cache() + "/index.json", "r") as f: + index_info = json.load(f) + n_partitions = index_info["n_partitions"] + inputs = [(i, self._joined_cache(), task) for i in range(n_partitions)] + + ld.optimize( + fn=_transform_fn, + inputs=inputs, + output_dir=str(cache_dir), + num_workers=num_workers, + chunk_bytes="64MB", + ) + + streaming_dataset = StreamingDataset(str(cache_dir), transform=_unpickle) sample_dataset = SampleDataset( - samples, - input_schema=task.input_schema, - output_schema=task.output_schema, + streaming_dataset, + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore dataset_name=self.dataset_name, - task_name=task, + task_name=task.task_name, input_processors=input_processors, output_processors=output_processors, ) - logger.info(f"Generated {len(samples)} samples for task {task.task_name}") + logger.info( + f"Generated {len(sample_dataset)} samples for task {task.task_name}" + ) return sample_dataset diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 05321dedb..03ef61a1a 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -1,10 +1,11 @@ import logging import os import warnings -from typing import Dict, List, Optional +from typing import List, Optional +from pathlib import Path import pandas as pd -import polars as pl +import dask.dataframe as dd try: import psutil @@ -20,7 +21,7 @@ def log_memory_usage(tag=""): """Log current memory usage if psutil is available.""" if HAS_PSUTIL: - process = psutil.Process(os.getpid()) + process = psutil.Process(os.getpid()) # type: ignore mem_info = process.memory_info() logger.info(f"Memory usage {tag}: {mem_info.rss / (1024 * 1024):.1f} MB") else: @@ -210,19 +211,26 @@ def __init__( note_config_path: Optional[str] = None, cxr_config_path: Optional[str] = None, dataset_name: str = "mimic4", + cache_dir: str | Path | None = None, + num_workers: int = 1, + mem_per_worker: str | int = "8GB", + compute: bool = True, dev: bool = False, ): + super().__init__( + root=f"{str(ehr_root)},{str(note_root)},{str(cxr_root)}", + tables=(ehr_tables or []) + (note_tables or []) + (cxr_tables or []), + dataset_name=dataset_name, + cache_dir=cache_dir, + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first + dev=dev, + ) log_memory_usage("Starting MIMIC4Dataset init") # Initialize child datasets - self.dataset_name = dataset_name self.sub_datasets = {} - self.root = None - self.tables = None - self.config = None - # Dev flag is only used in the MIMIC4Dataset class - # to ensure the same set of patients are used for all sub-datasets. - self.dev = dev # We need at least one root directory if not any([ehr_root, note_root, cxr_root]): @@ -240,6 +248,10 @@ def __init__( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, + cache_dir=f"{self.cache_dir}/ehr", + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first ) log_memory_usage("After EHR dataset initialization") @@ -250,6 +262,10 @@ def __init__( root=note_root, tables=note_tables, config_path=note_config_path, + cache_dir=f"{self.cache_dir}/note", + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first ) log_memory_usage("After Note dataset initialization") @@ -260,21 +276,18 @@ def __init__( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, + cache_dir=f"{self.cache_dir}/cxr", + num_workers=num_workers, + mem_per_worker=mem_per_worker, + compute=False, # defer compute to later, we need to aggregate all sub-datasets first ) log_memory_usage("After CXR dataset initialization") - # Combine data from all sub-datasets - log_memory_usage("Before combining data") - self.global_event_df = self._combine_data() - log_memory_usage("After combining data") - - # Cache attributes - self._collected_global_event_df = None - self._unique_patient_ids = None - + if compute: + _ = self._joined_cache() log_memory_usage("Completed MIMIC4Dataset init") - def _combine_data(self) -> pl.LazyFrame: + def load_data(self) -> dd.DataFrame: """ Combines data from all initialized sub-datasets into a unified global event dataframe. @@ -285,12 +298,14 @@ def _combine_data(self) -> pl.LazyFrame: # Collect global event dataframes from all sub-datasets for dataset_type, dataset in self.sub_datasets.items(): + dataset_type: str + dataset: BaseDataset logger.info(f"Combining data from {dataset_type} dataset") - frames.append(dataset.global_event_df) + frames.append(dataset.load_data()) # Concatenate all frames logger.info("Creating combined dataframe") if len(frames) == 1: return frames[0] else: - return pl.concat(frames, how="diagonal") + return dd.concat(frames, axis=0, join="outer") diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 54c40420c..f1bcd8288 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,14 +1,18 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, Type +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type +from collections.abc import Sequence +from bisect import bisect_right import inspect -from torch.utils.data import Dataset -from tqdm import tqdm +from torch.utils.data import IterableDataset +from litdata.streaming import StreamingDataset +from litdata.utilities.train_test_split import deepcopy_dataset +import pickle from ..processors import get_processor from ..processors.base_processor import FeatureProcessor -class SampleDataset(Dataset): +class SampleDataset(IterableDataset): """Sample dataset class for handling and processing data samples. Attributes: @@ -23,7 +27,7 @@ class SampleDataset(Dataset): def __init__( self, - samples: List[Dict], + dataset: StreamingDataset, input_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], output_schema: Dict[str, Union[str, Type[FeatureProcessor], FeatureProcessor, Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]], dataset_name: Optional[str] = None, @@ -56,7 +60,7 @@ def __init__( dataset_name = "" if task_name is None: task_name = "" - self.samples = samples + self.dataset = dataset self.input_schema = input_schema self.output_schema = output_schema self.input_processors = input_processors if input_processors is not None else {} @@ -69,7 +73,7 @@ def __init__( self.patient_to_index = {} self.record_to_index = {} - for i, sample in enumerate(samples): + for i, sample in enumerate(iter(self.dataset)): # Create patient_to_index mapping patient_id = sample.get("patient_id") if patient_id is not None: @@ -87,6 +91,13 @@ def __init__( self.validate() self.build() + # Apply processors + self.dataset = StreamingDataset( + input_dir=self.dataset.input_dir, + cache_dir=self.dataset.cache_dir, + transform=self.transform, + ) + def _get_processor_instance(self, processor_spec): """Get processor instance from either string alias, class reference, processor instance, or tuple with kwargs. @@ -128,7 +139,7 @@ def validate(self) -> None: """Validates that the samples match the input and output schemas.""" input_keys = set(self.input_schema.keys()) output_keys = set(self.output_schema.keys()) - for s in self.samples: + for s in iter(self.dataset): assert input_keys.issubset(s.keys()), "Input schema does not match samples." assert output_keys.issubset(s.keys()), ( "Output schema does not match samples." @@ -141,45 +152,153 @@ def build(self) -> None: if not self.input_processors: for k, v in self.input_schema.items(): self.input_processors[k] = self._get_processor_instance(v) - self.input_processors[k].fit(self.samples, k) + self.input_processors[k].fit(iter(self.dataset), k) if not self.output_processors: for k, v in self.output_schema.items(): self.output_processors[k] = self._get_processor_instance(v) - self.output_processors[k].fit(self.samples, k) - # Always process samples with the (fitted) processors - for sample in tqdm(self.samples, desc="Processing samples"): - for k, v in sample.items(): - if k in self.input_processors: - sample[k] = self.input_processors[k].process(v) - elif k in self.output_processors: - sample[k] = self.output_processors[k].process(v) + self.output_processors[k].fit(iter(self.dataset), k) return - def __getitem__(self, index: int) -> Dict: - """Returns a sample by index. + def transform(self, sample) -> Dict: + """Applies the input and output processors to a sample.""" + for k, v in sample.items(): + if k in self.input_processors: + sample[k] = self.input_processors[k].process(pickle.loads(v)) + elif k in self.output_processors: + sample[k] = self.output_processors[k].process(pickle.loads(v)) + else: + sample[k] = pickle.loads(v) + return sample - Args: - index (int): Index of the sample to retrieve. + def set_shuffle(self, shuffle: bool) -> None: + """Sets whether to shuffle the dataset during iteration. - Returns: - Dict: A dict with patient_id, visit_id/record_id, and other - task-specific attributes as key. Conversion to index/tensor - will be done in the model. + Args: + shuffle (bool): Whether to shuffle the dataset. """ - return self.samples[index] + self.dataset.set_shuffle(shuffle) + return - def __str__(self) -> str: - """Returns a string representation of the dataset. + def __iter__(self) -> Iterator: + """Returns an iterator over the dataset samples.""" + return self.dataset.__iter__() - Returns: - str: A string with the dataset and task names. - """ + def __getitem__(self, index: int) -> Dict: + """Gets a sample by index. Try to use iterator for better performance.""" + return self.dataset.__getitem__(index) + + def __str__(self) -> str: + """String representation of the SampleDataset.""" return f"Sample dataset {self.dataset_name} {self.task_name}" def __len__(self) -> int: - """Returns the number of samples in the dataset. + """Returns the number of samples in the dataset.""" + return self.dataset.__len__() - Returns: - int: The number of samples. +class SampleSubset(IterableDataset): + """A subset of the SampleDataset. + + Args: + sample_dataset (SampleDataset): The original SampleDataset. + indices (List[int]): List of indices to include in the subset. + """ + + def __init__(self, dataset: SampleDataset, indices: Sequence[int]) -> None: + self.dataset_name = dataset.dataset_name + self.task_name = dataset.task_name + + base_dataset: StreamingDataset = deepcopy_dataset(dataset.dataset) + base_dataset.set_shuffle(False) # Disable shuffling when creating subset + self.dataset, self._length = self._build_subset_dataset( + base_dataset, indices + ) + + def _build_subset_dataset( + self, base_dataset: StreamingDataset, indices: Sequence[int] + ) -> Tuple[StreamingDataset, int]: + """Create a StreamingDataset restricted to the provided indices.""" + if len(base_dataset.subsampled_files) != len(base_dataset.region_of_interest): + raise ValueError( + "The provided dataset has mismatched subsampled_files and region_of_interest lengths." + ) + + dataset_length = sum( + end - start for start, end in base_dataset.region_of_interest + ) + if any(idx < 0 or idx >= dataset_length for idx in indices): + raise ValueError( + f"Subset indices must be in [0, {dataset_length - 1}] for the provided dataset." + ) + + # Build chunk boundaries so we can translate global indices into + # chunk-local (start, end) pairs that litdata understands. + chunk_starts: List[int] = [] + chunk_boundaries: List[Tuple[str, int, int, int, int]] = [] + cursor = 0 + for filename, (roi_start, roi_end) in zip( + base_dataset.subsampled_files, base_dataset.region_of_interest + ): + chunk_len = roi_end - roi_start + if chunk_len <= 0: + continue + chunk_starts.append(cursor) + chunk_boundaries.append( + (filename, roi_start, roi_end, cursor, cursor + chunk_len) + ) + cursor += chunk_len + + new_subsampled_files: List[str] = [] + new_roi: List[Tuple[int, int]] = [] + prev_chunk_idx: Optional[int] = None + + for idx in indices: + chunk_idx = bisect_right(chunk_starts, idx) - 1 + if chunk_idx < 0 or idx >= chunk_boundaries[chunk_idx][4]: + raise ValueError(f"Index {idx} is out of bounds for the dataset.") + + filename, roi_start, _, global_start, _ = chunk_boundaries[chunk_idx] + offset_in_chunk = roi_start + (idx - global_start) + + if ( + new_roi + and prev_chunk_idx == chunk_idx + and offset_in_chunk == new_roi[-1][1] + ): + new_roi[-1] = (new_roi[-1][0], new_roi[-1][1] + 1) + else: + new_subsampled_files.append(filename) + new_roi.append((offset_in_chunk, offset_in_chunk + 1)) + + prev_chunk_idx = chunk_idx + + base_dataset.subsampled_files = new_subsampled_files + base_dataset.region_of_interest = new_roi + base_dataset.reset() + subset_length = sum(end - start for start, end in new_roi) + + return base_dataset, subset_length + + def set_shuffle(self, shuffle: bool) -> None: + """Sets whether to shuffle the dataset during iteration. + + Args: + shuffle (bool): Whether to shuffle the dataset. """ - return len(self.samples) + self.dataset.set_shuffle(shuffle) + return + + def __iter__(self) -> Iterator: + """Returns an iterator over the dataset samples.""" + return self.dataset.__iter__() + + def __getitem__(self, index: int) -> Dict: + """Gets a sample by index. Try to use iterator for better performance.""" + return self.dataset.__getitem__(index) + + def __str__(self) -> str: + """String representation of the SampleDataset.""" + return f"Sample dataset {self.dataset_name} {self.task_name} subset" + + def __len__(self) -> int: + """Returns the number of samples in the dataset.""" + return self._length diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index c70df5660..f9e3c7213 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -4,7 +4,7 @@ import numpy as np import torch -from .sample_dataset import SampleDataset +from .sample_dataset import SampleDataset, SampleSubset # TODO: train_dataset.dataset still access the whole dataset which may leak information # TODO: add more splitting methods @@ -24,11 +24,7 @@ def split_by_visit( Returns: train_dataset, val_dataset, test_dataset: three subsets of the dataset of - type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, and `test_dataset.dataset`. + type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -40,9 +36,9 @@ def split_by_visit( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, test_dataset @@ -60,11 +56,7 @@ def split_by_patient( Returns: train_dataset, val_dataset, test_dataset: three subsets of the dataset of - type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, and `test_dataset.dataset`. + type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -82,9 +74,9 @@ def split_by_patient( ) val_index = list(chain(*[dataset.patient_to_index[i] for i in val_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) + val_dataset = SampleSubset(dataset, val_index) + test_dataset = SampleSubset(dataset, test_index) return train_dataset, val_dataset, test_dataset @@ -103,11 +95,7 @@ def split_by_sample( Returns: train_dataset, val_dataset, test_dataset: three subsets of the dataset of - type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, and `test_dataset.dataset`. + type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -119,9 +107,9 @@ def split_by_sample( int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) ] test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index.tolist()) + val_dataset = SampleSubset(dataset, val_index.tolist()) + test_dataset = SampleSubset(dataset, test_index.tolist()) if get_index: return ( @@ -147,12 +135,7 @@ def split_by_visit_conformal( Returns: train_dataset, val_dataset, cal_dataset, test_dataset: four subsets - of the dataset of type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, `cal_dataset.dataset`, and - `test_dataset.dataset`. + of the dataset of type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -172,10 +155,10 @@ def split_by_visit_conformal( cal_index = index[val_end:cal_end] test_index = index[cal_end:] - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + cal_dataset = SampleSubset(dataset, cal_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -194,12 +177,7 @@ def split_by_patient_conformal( Returns: train_dataset, val_dataset, cal_dataset, test_dataset: four subsets - of the dataset of type `torch.utils.data.Subset`. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, `cal_dataset.dataset`, and - `test_dataset.dataset`. + of the dataset of type `SampleSubset`. """ if seed is not None: np.random.seed(seed) @@ -227,10 +205,10 @@ def split_by_patient_conformal( cal_index = list(chain(*[dataset.patient_to_index[i] for i in cal_patient_indx])) test_index = list(chain(*[dataset.patient_to_index[i] for i in test_patient_indx])) - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + cal_dataset = SampleSubset(dataset, cal_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset @@ -251,13 +229,8 @@ def split_by_sample_conformal( Returns: train_dataset, val_dataset, cal_dataset, test_dataset: four subsets - of the dataset of type `torch.utils.data.Subset`, or four tensors + of the dataset of type `SampleSubset`, or four tensors of indices if get_index=True. - - Note: - The original dataset can be accessed by `train_dataset.dataset`, - `val_dataset.dataset`, `cal_dataset.dataset`, and - `test_dataset.dataset`. """ if seed is not None: np.random.seed(seed) @@ -285,8 +258,8 @@ def split_by_sample_conformal( torch.tensor(test_index), ) else: - train_dataset = torch.utils.data.Subset(dataset, train_index) - val_dataset = torch.utils.data.Subset(dataset, val_index) - cal_dataset = torch.utils.data.Subset(dataset, cal_index) - test_dataset = torch.utils.data.Subset(dataset, test_index) + train_dataset = SampleSubset(dataset, train_index) # type: ignore + val_dataset = SampleSubset(dataset, val_index) # type: ignore + cal_dataset = SampleSubset(dataset, cal_index) # type: ignore + test_dataset = SampleSubset(dataset, test_index) # type: ignore return train_dataset, val_dataset, cal_dataset, test_dataset diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 63ca4152a..f45ca9c98 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -10,6 +10,7 @@ from torch.utils.data import DataLoader from pyhealth import BASE_CACHE_PATH +from pyhealth.datasets.sample_dataset import SampleDataset, SampleSubset from pyhealth.utils import create_directory MODULE_CACHE_PATH = os.path.join(BASE_CACHE_PATH, "datasets") @@ -319,7 +320,7 @@ def collate_fn_dict_with_padding(batch: List[dict]) -> dict: def get_dataloader( - dataset: torch.utils.data.Dataset, batch_size: int, shuffle: bool = False + dataset: SampleDataset | SampleSubset, batch_size: int, shuffle: bool = False ) -> DataLoader: """Creates a DataLoader for a given dataset. @@ -331,10 +332,10 @@ def get_dataloader( Returns: A DataLoader instance for the dataset. """ + dataset.set_shuffle(shuffle) dataloader = DataLoader( dataset, batch_size=batch_size, - shuffle=shuffle, collate_fn=collate_fn_dict_with_padding, ) diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 050cb5357..d207f0220 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Iterator class Processor(ABC): @@ -33,7 +33,7 @@ class FeatureProcessor(Processor): Example: Tokenization, image loading, normalization. """ - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: """Fit the processor to the samples. Args: diff --git a/pyhealth/processors/label_processor.py b/pyhealth/processors/label_processor.py index ad2df1897..a9f9937d9 100644 --- a/pyhealth/processors/label_processor.py +++ b/pyhealth/processors/label_processor.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List +from typing import Any, Dict, Iterator, List import torch @@ -19,7 +19,7 @@ def __init__(self): super().__init__() self.label_vocab: Dict[Any, int] = {} - def fit(self, samples: List[Dict[str, Any]], field: str) -> None: + def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: all_labels = set([sample[field] for sample in samples]) if len(all_labels) != 2: raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}") diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index cbbafac94..420163f51 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -61,7 +61,7 @@ def __init__(self, padding: int = 0): self._max_nested_len = None # Max inner sequence length for nested codes self._padding = padding # Additional padding beyond observed max - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterator[dict[str, Any]], field: str) -> None: """Build vocabulary and determine input structure. Args: @@ -70,9 +70,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for codes if isinstance(value_data, list) and len(value_data) > 0: @@ -90,9 +90,9 @@ def fit(self, samples: List[Dict], key: str) -> None: # Build vocabulary for codes and find max nested length max_inner_len = 0 for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] if self._is_nested: # Nested codes @@ -256,7 +256,7 @@ def __init__(self): self._size = None # Feature dimension (set during fit) self._is_nested = None - def fit(self, samples: List[Dict], key: str) -> None: + def fit(self, samples: Iterator[dict[str, Any]], field: str) -> None: """Determine input structure. Args: @@ -265,9 +265,9 @@ def fit(self, samples: List[Dict], key: str) -> None: """ # Examine first non-None sample to determine structure for sample in samples: - if key in sample and sample[key] is not None: + if field in sample and sample[field] is not None: # Unpack tuple: (time, values) - time_data, value_data = sample[key] + time_data, value_data = sample[field] # Determine nesting level for numerics if isinstance(value_data, list) and len(value_data) > 0: diff --git a/pyhealth/tasks/base_task.py b/pyhealth/tasks/base_task.py index 888c7e2e1..cd8c79adf 100644 --- a/pyhealth/tasks/base_task.py +++ b/pyhealth/tasks/base_task.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union, Type -import polars as pl +import pandas as pd class BaseTask(ABC): @@ -9,7 +9,7 @@ class BaseTask(ABC): input_schema: Dict[str, Union[str, Type]] output_schema: Dict[str, Union[str, Type]] - def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: + def pre_filter(self, df: pd.DataFrame) -> pd.DataFrame: return df @abstractmethod diff --git a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py index fc9c58f7f..437671067 100644 --- a/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py +++ b/pyhealth/tasks/mortality_prediction_stagenet_mimic4.py @@ -1,7 +1,9 @@ from datetime import datetime -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar, Dict, List, Type -import polars as pl +import pandas as pd + +from pyhealth.data.data import Patient from .base_task import BaseTask @@ -35,11 +37,11 @@ class MortalityPredictionStageNetMIMIC4(BaseTask): """ task_name: str = "MortalityPredictionStageNetMIMIC4" - input_schema: Dict[str, str] = { + input_schema: Dict[str, str | Type] = { "icd_codes": "stagenet", "labs": "stagenet_tensor", } - output_schema: Dict[str, str] = {"mortality": "binary"} + output_schema: Dict[str, str | Type] = {"mortality": "binary"} # Organize lab items by category # Each category will map to ONE dimension in the output vector @@ -75,7 +77,7 @@ class MortalityPredictionStageNetMIMIC4(BaseTask): item for itemids in LAB_CATEGORIES.values() for item in itemids ] - def __call__(self, patient: Any) -> List[Dict[str, Any]]: + def __call__(self, patient: Patient) -> List[Dict[str, Any]]: """Process a patient to create mortality prediction samples. Creates ONE sample per patient with all admissions aggregated. @@ -89,7 +91,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: procedures, labs across visits, and final mortality label """ # Filter patients by age (>= 18) - demographics = patient.get_events(event_type="patients") + demographics = patient.get_events_py(event_type="patients") if not demographics: return [] @@ -103,7 +105,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: return [] # Get all admissions - admissions = patient.get_events(event_type="admissions") + admissions = patient.get_events_py(event_type="admissions") if len(admissions) < 1: return [] @@ -156,7 +158,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: pass # Get diagnosis codes for this admission using hadm_id - diagnoses_icd = patient.get_events( + diagnoses_icd = patient.get_events_py( event_type="diagnoses_icd", filters=[("hadm_id", "==", admission.hadm_id)], ) @@ -167,7 +169,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: ] # Get procedure codes for this admission using hadm_id - procedures_icd = patient.get_events( + procedures_icd = patient.get_events_py( event_type="procedures_icd", filters=[("hadm_id", "==", admission.hadm_id)], ) @@ -185,46 +187,51 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: all_icd_times.append(time_from_previous) # Get lab events for this admission - labevents_df = patient.get_events( + labevents_df = patient.get_events_df( event_type="labevents", start=admission_time, end=admission_dischtime, - return_df=True, ) # Filter to relevant lab items - labevents_df = labevents_df.filter( - pl.col("labevents/itemid").is_in(self.LABITEMS) - ) + labevents_df = labevents_df[ + labevents_df["labevents/itemid"].isin(self.LABITEMS) + ] # Parse storetime and filter - if labevents_df.height > 0: - labevents_df = labevents_df.with_columns( - pl.col("labevents/storetime").str.strptime( - pl.Datetime, "%Y-%m-%d %H:%M:%S" - ) - ) - labevents_df = labevents_df.filter( - pl.col("labevents/storetime") <= admission_dischtime + if len(labevents_df) > 0: + labevents_df = labevents_df.copy() + labevents_df["labevents/storetime"] = pd.to_datetime( + labevents_df["labevents/storetime"], + format="%Y-%m-%d %H:%M:%S", + errors="coerce", ) + labevents_df = labevents_df[ + labevents_df["labevents/storetime"] <= admission_dischtime + ] - if labevents_df.height > 0: + if len(labevents_df) > 0: # Select relevant columns - labevents_df = labevents_df.select( - pl.col("timestamp"), - pl.col("labevents/itemid"), - pl.col("labevents/valuenum").cast(pl.Float64), + labevents_df = labevents_df.loc[ + :, ["timestamp", "labevents/itemid", "labevents/valuenum"] + ].copy() + labevents_df["labevents/valuenum"] = pd.to_numeric( + labevents_df["labevents/valuenum"] + .astype(str) + .str.strip() + .replace("", pd.NA), # type: ignore + errors="coerce", ) # Group by timestamp and aggregate into 10D vectors # For each timestamp, create vector of lab categories unique_timestamps = sorted( - labevents_df["timestamp"].unique().to_list() + labevents_df["timestamp"].unique().tolist() ) for lab_ts in unique_timestamps: # Get all lab events at this timestamp - ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts) + ts_labs = labevents_df[labevents_df["timestamp"] == lab_ts] # Create 10-dimensional vector (one per category) lab_vector = [] @@ -234,11 +241,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: # Find first matching value for this category category_value = None for itemid in category_itemids: - matching = ts_labs.filter( - pl.col("labevents/itemid") == itemid - ) - if matching.height > 0: - category_value = matching["labevents/valuenum"][0] + matching = ts_labs[ + ts_labs["labevents/itemid"] == itemid + ] + if len(matching) > 0: + category_value = matching[ + "labevents/valuenum" + ].iloc[0] break lab_vector.append(category_value) diff --git a/pyproject.toml b/pyproject.toml index ceedcdd0b..9c3ac0f7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,9 @@ dependencies = [ "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", + "dask[complete]~=2025.11.0", + "litdata~=0.2.58", + "xxhash~=3.6.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/core/test_patient_event.py b/tests/core/test_patient_event.py new file mode 100644 index 000000000..1419f4b0c --- /dev/null +++ b/tests/core/test_patient_event.py @@ -0,0 +1,103 @@ +from datetime import datetime, timedelta +import pandas as pd + +from tests.base import BaseTestCase +from pyhealth.data.data import Event, Patient + + +class PatientEventTestCase(BaseTestCase): + def setUp(self): + base_time = datetime(2024, 1, 1, 12, 0, 0) + self.event_rows = [ + { + "patient_id": "p1", + "event_type": "med", + "timestamp": base_time + timedelta(days=1), + "med/dose": 10, + }, + { + "patient_id": "p1", + "event_type": "diag", + "timestamp": base_time + timedelta(days=2), + "diag/code": "A", + }, + { + "patient_id": "p1", + "event_type": "med", + "timestamp": base_time + timedelta(days=4), + "med/dose": 20, + }, + { + "patient_id": "p1", + "event_type": "diag", + "timestamp": base_time + timedelta(days=3), + "diag/code": "B", + "diag/severity": 2, + }, + { + "patient_id": "p1", + "event_type": "lab", + "timestamp": pd.NaT, + "lab/value": 99, + }, + ] + unsorted_df = pd.DataFrame(self.event_rows) + self.patient = Patient(patient_id="p1", data_source=unsorted_df) + self.base_time = base_time + + def test_event_accessors_and_from_dict(self): + ts = datetime(2024, 5, 1, 8, 0, 0) + event = Event(event_type="diag", timestamp=ts, code="X", value=1) + self.assertEqual(event["timestamp"], ts) + self.assertEqual(event["event_type"], "diag") + self.assertEqual(event["code"], "X") + self.assertIn("value", event) + self.assertEqual(event.value, 1) + + raw = {"event_type": "diag", "timestamp": ts, "diag/code": "Y", "diag/score": 5} + reconstructed = Event.from_dict(raw) + self.assertEqual(reconstructed.event_type, "diag") + self.assertEqual(reconstructed.attr_dict, {"code": "Y", "score": 5}) + + def test_patient_sorting_and_partitions(self): + # timestamps should be sorted with NaT at the end + timestamps = list(self.patient.data_source["timestamp"]) + sorted_without_nat = sorted([ts for ts in timestamps if pd.notna(ts)]) + self.assertEqual(timestamps[:-1], sorted_without_nat) + self.assertTrue(pd.isna(timestamps[-1])) + + diag_partition = self.patient.event_type_partitions[("diag",)] + self.assertListEqual(list(diag_partition["diag/code"]), ["A", "B"]) + + med_partition = self.patient.event_type_partitions[("med",)] + self.assertListEqual(list(med_partition["med/dose"]), [10, 20]) + + def test_get_events_by_type_and_time(self): + diag_df = self.patient.get_events(event_type="diag", return_df=True) + self.assertEqual(len(diag_df), 2) + self.assertListEqual(list(diag_df["diag/code"]), ["A", "B"]) + + # Time filtering should include bounds and drop NaT + start = self.base_time + timedelta(days=2) + end = self.base_time + timedelta(days=4) + ranged = self.patient.get_events(start=start, end=end, return_df=True) + self.assertListEqual(list(ranged["event_type"]), ["diag", "diag", "med"]) + self.assertTrue(ranged["timestamp"].between(start, end, inclusive="both").all()) + self.assertFalse(ranged["timestamp"].isna().any()) + + def test_attribute_filters(self): + filtered_events = self.patient.get_events( + event_type="diag", filters=[("code", "==", "B")] + ) + self.assertEqual(len(filtered_events), 1) + self.assertEqual(filtered_events[0].attr_dict, {"code": "B", "severity": 2}) + + filtered_df = self.patient.get_events( + event_type="diag", filters=[("code", "!=", "B")], return_df=True + ) + self.assertEqual(len(filtered_df), 1) + self.assertEqual(filtered_df.iloc[0]["diag/code"], "A") + + def test_filters_require_event_type(self): + with self.assertRaises(AssertionError): + self.patient.get_events(filters=[("code", "==", "A")]) diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py new file mode 100644 index 000000000..04431278f --- /dev/null +++ b/tests/core/test_sample_dataset.py @@ -0,0 +1,161 @@ +import pickle +import tempfile +import unittest +from pathlib import Path +from typing import Any, Dict, Iterator, List + +from litdata import StreamingDataset, optimize + +from pyhealth.datasets.sample_dataset import SampleDataset, SampleSubset +from pyhealth.processors.base_processor import FeatureProcessor + +# Top-level identity function for litdata.optimize (must be picklable). +def _identity(sample: Dict[str, Any]) -> Dict[str, Any]: + return sample + + +class RecordingProcessor(FeatureProcessor): + """Processor that records fit/process calls and prefixes outputs.""" + + def __init__(self, prefix: str) -> None: + self.prefix = prefix + self.fit_called = False + self.fit_seen: List[Any] = [] + self.process_seen: List[Any] = [] + + def fit(self, samples: Iterator[Dict[str, Any]], field: str) -> None: + self.fit_called = True + for sample in samples: + self.fit_seen.append(pickle.loads(sample[field])) + + def process(self, value: Any) -> Any: + self.process_seen.append(value) + return f"{self.prefix}-{value}" + + +class TestSampleDatasetAndSubset(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.tmpdir = tempfile.TemporaryDirectory() + self.output_dir = Path(self.tmpdir.name) / "stream" + + raw_samples = [ + {"patient_id": "p1", "record_id": "r1", "x": 1, "y": 10}, + {"patient_id": "p2", "record_id": "r2", "x": 2, "y": 20}, + {"patient_id": "p3", "record_id": "r3", "x": 3, "y": 30}, + ] + pickled_samples = [{k: pickle.dumps(v) for k, v in sample.items()} for sample in raw_samples] + + optimize( + fn=_identity, + inputs=pickled_samples, + output_dir=str(self.output_dir), + chunk_size=len(pickled_samples), + num_workers=1, + verbose=False, + ) + + streaming_dataset = StreamingDataset( + input_dir=str(self.output_dir), + cache_dir=str(self.output_dir), + ) + + self.sample_dataset = SampleDataset( + dataset=streaming_dataset, + input_schema={"x": (RecordingProcessor, {"prefix": "in"})}, + output_schema={"y": (RecordingProcessor, {"prefix": "out"})}, + dataset_name="test_dataset", + task_name="task", + ) + + self.raw_samples = raw_samples + + def tearDown(self) -> None: + self.tmpdir.cleanup() + super().tearDown() + + def test_sample_dataset_builds_processors(self) -> None: + self.assertIn("x", self.sample_dataset.input_processors) + self.assertIn("y", self.sample_dataset.output_processors) + + input_proc: RecordingProcessor = self.sample_dataset.input_processors["x"] # type: ignore + output_proc: RecordingProcessor = self.sample_dataset.output_processors["y"] # type: ignore + + self.assertTrue(input_proc.fit_called) + self.assertTrue(output_proc.fit_called) + self.assertEqual(input_proc.fit_seen, [1, 2, 3]) + self.assertEqual(output_proc.fit_seen, [10, 20, 30]) + + def test_sample_dataset_returns_processed_items(self) -> None: + # __getitem__ path + item = self.sample_dataset[0] + self.assertEqual(item["x"], "in-1") + self.assertEqual(item["y"], "out-10") + self.assertEqual(item["patient_id"], "p1") + self.assertEqual(item["record_id"], "r1") + + # __iter__ path + self.sample_dataset.dataset.reset() + items = list(iter(self.sample_dataset)) + self.assertEqual([s["x"] for s in items], ["in-1", "in-2", "in-3"]) + self.assertEqual([s["y"] for s in items], ["out-10", "out-20", "out-30"]) + self.assertEqual([s["patient_id"] for s in items], ["p1", "p2", "p3"]) + + def test_sample_subset_respects_indices_and_processing(self) -> None: + subset_indices = [1, 2] + subset = SampleSubset(self.sample_dataset, subset_indices) + + self.assertEqual(len(subset), len(subset_indices)) + + # __getitem__ path + second = subset[0] + third = subset[1] + + self.assertEqual(second["x"], "in-2") + self.assertEqual(second["y"], "out-20") + self.assertEqual(second["patient_id"], "p2") + + self.assertEqual(third["x"], "in-3") + self.assertEqual(third["y"], "out-30") + self.assertEqual(third["patient_id"], "p3") + + # __iter__ path + subset.dataset.reset() + iter_items = list(iter(subset)) + self.assertEqual(len(iter_items), 2) + self.assertEqual([s["x"] for s in iter_items], ["in-2", "in-3"]) + self.assertEqual([s["y"] for s in iter_items], ["out-20", "out-30"]) + self.assertEqual([s["patient_id"] for s in iter_items], ["p2", "p3"]) + + def test_shuffle_behavior_and_isolation(self) -> None: + # Baseline (no shuffle) + baseline = [s["patient_id"] for s in iter(self.sample_dataset)] + self.sample_dataset.dataset.reset() + + # Shuffle affects iteration but not __getitem__ + self.sample_dataset.set_shuffle(True) + shuffled_iter = [s["patient_id"] for s in iter(self.sample_dataset)] + self.assertCountEqual(shuffled_iter, baseline) + if len(baseline) > 1: + self.assertNotEqual(shuffled_iter, baseline) + self.sample_dataset.dataset.reset() + self.assertEqual(self.sample_dataset[0]["patient_id"], "p1") + + # Subset created from shuffled dataset should disable shuffle during construction + subset = SampleSubset(self.sample_dataset, [0, 1]) + self.assertFalse(subset.dataset.shuffle) + subset_items = [s["patient_id"] for s in iter(subset)] + self.assertEqual(subset_items, ["p1", "p2"]) + subset.dataset.reset() + self.assertEqual(subset[0]["patient_id"], "p1") + + # Shuffling one subset doesn't affect dataset or other subsets + subset2 = SampleSubset(self.sample_dataset, [1, 2]) + subset.set_shuffle(True) + shuffled_subset_iter = [s["patient_id"] for s in iter(subset)] + self.assertCountEqual(shuffled_subset_iter, ["p1", "p2"]) + if len(shuffled_subset_iter) > 1: + self.assertNotEqual(shuffled_subset_iter, ["p1", "p2"]) + self.assertFalse(subset2.dataset.shuffle) + self.assertEqual(subset2[0]["patient_id"], "p2") + self.assertTrue(self.sample_dataset.dataset.shuffle)