
Conversation

@LogicFan
Collaborator

@LogicFan LogicFan commented Nov 21, 2025

This should close #332.

LogicFan and others added 30 commits November 21, 2025 02:12
Co-authored-by: John Wu <[email protected]>
@LogicFan LogicFan changed the title Mem 5 [Memory] Better Memory Utilization for PyHealth Nov 23, 2025
Collaborator

@jhnwu3 jhnwu3 left a comment

Hey, do you know which Dask version you used? It would probably be good to pin it in pyproject.toml.

Will look deeper into the CI and other implementation details later; I have to go wander Japan for a bit 😂

Comment on lines +42 to +44
"dask[complete]~=2025.11.0",
"litdata~=0.2.58",
"xxhash~=3.6.0",
Collaborator Author

New dependencies are specified here.

Collaborator

Ah, I'm blind.

event_type: str
timestamp: datetime
attr_dict: Mapping[str, any] = field(default_factory=dict)
attr_dict: Mapping[str, Any] = field(default_factory=dict)
Collaborator Author

Small type hint fix here: `any` is a builtin function over iterables, while `Any` is the correct type hint.
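To illustrate the distinction, here is a trimmed sketch of the dataclass (the real `Event` also carries a `timestamp`): `typing.Any` is the static wildcard type, whereas the builtin `any` is a runtime function.

```python
from dataclasses import dataclass, field
from typing import Any, Mapping

@dataclass
class Event:
    event_type: str
    # `Any` (capital A) from typing, not the builtin `any`
    attr_dict: Mapping[str, Any] = field(default_factory=dict)

print(Event("labevents").attr_dict)  # → {}
print(any([False, True]))            # → True: `any` is a function, not a type
```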

"""

def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None:
def __init__(self, patient_id: str, data_source: pd.DataFrame) -> None:
Collaborator Author

Polars spawns its own worker processes. Since this code will be run inside litdata.optimize, which is already a multi-process environment, the nested multiprocessing can cause a hang. Thus, we change this to pandas.

Comment on lines 172 to +183

def get_events_py(self, **kwargs) -> List[Event]:
    """Type-safe wrapper for get_events."""
    return self.get_events(**kwargs, return_df=False)  # type: ignore

def get_events_df(self, **kwargs) -> pd.DataFrame:
    """DataFrame wrapper for get_events."""
    return self.get_events(**kwargs, return_df=True)  # type: ignore

Collaborator Author

Added for type-hinting purposes.
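A standalone sketch of the pattern (the real methods live on `Patient`; `Records` here is a hypothetical stand-in): a boolean-flag method has a union return type, so thin wrappers give callers a precise static type without `isinstance` checks at every call site.

```python
from typing import Any, Dict, List, Union
import pandas as pd

class Records:
    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        self._rows = rows

    def get_events(self, return_df: bool = False) -> Union[List[Dict[str, Any]], pd.DataFrame]:
        # The union return type forces callers to narrow the result themselves.
        return pd.DataFrame(self._rows) if return_df else self._rows

    def get_events_py(self, **kwargs) -> List[Dict[str, Any]]:
        """Type-safe wrapper: always a list."""
        return self.get_events(**kwargs, return_df=False)  # type: ignore

    def get_events_df(self, **kwargs) -> pd.DataFrame:
        """Type-safe wrapper: always a DataFrame."""
        return self.get_events(**kwargs, return_df=True)  # type: ignore

r = Records([{"event_type": "labevents"}])
print(type(r.get_events_df()).__name__)  # → DataFrame
```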

Comment on lines +190 to 194
labevents_df = patient.get_events_df(
    event_type="labevents",
    start=admission_time,
    end=admission_dischtime,
)
Collaborator Author

It returns a pandas DataFrame here; any task that uses get_events_df (or get_events(..., return_df=True)) will need to update its code to be compatible with pandas.
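For task authors, the migration is mostly mechanical; a hedged before/after sketch (the column names and values are illustrative, not from the PR):

```python
import pandas as pd

# Hypothetical lab-events frame, for illustration only.
labevents_df = pd.DataFrame({
    "itemid": ["50912", "50971", "50912"],
    "valuenum": [1.0, 4.1, 2.3],
})

# polars style (previous API):  labevents_df.filter(pl.col("itemid") == "50912")
# pandas style (new API): boolean-mask indexing instead
creatinine = labevents_df[labevents_df["itemid"] == "50912"]
print(len(creatinine))  # → 2
```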

output_schema: Dict[str, Union[str, Type]]

def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
def pre_filter(self, df: dd.DataFrame) -> dd.DataFrame:
Collaborator Author

It becomes a Dask DataFrame here; any task that overrides this needs to be updated.

Comment on lines +335 to 341
dataset.set_shuffle(shuffle)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=collate_fn_dict_with_padding,
)

Collaborator Author

The native PyTorch DataLoader does not support shuffle on an IterableDataset; shuffling must happen at the dataset level.
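A quick demonstration of the constraint: `shuffle=True` makes DataLoader build a random sampler, which needs length and random access that iterable-style datasets do not provide, so PyTorch rejects the combination outright.

```python
from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(4)

try:
    DataLoader(Stream(), shuffle=True)
    print("accepted")
except ValueError:
    # DataLoader raises ValueError for shuffle on iterable-style datasets
    print("shuffle rejected")  # → shuffle rejected
```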

Comment on lines +39 to +41
train_dataset = SampleSubset(dataset, train_index) # type: ignore
val_dataset = SampleSubset(dataset, val_index) # type: ignore
test_dataset = SampleSubset(dataset, test_index) # type: ignore
Collaborator Author

The torch-native Subset does not support iterable datasets, so we use a custom-defined iterable dataset here.

Comment on lines +198 to +205
class SampleSubset(IterableDataset):
"""A subset of the SampleDataset.
Args:
sample_dataset (SampleDataset): The original SampleDataset.
indices (List[int]): List of indices to include in the subset.
"""

Collaborator Author

Add a new SampleSubset to support creating subsets of a SampleDataset for the train/val/test split.
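A minimal sketch of what such a subset has to do, assuming the underlying dataset is index-addressable (the real class wraps a SampleDataset; a plain list stands in here):

```python
from torch.utils.data import IterableDataset

class SampleSubset(IterableDataset):
    """Iterate only the selected indices of an index-addressable dataset."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __iter__(self):
        # Yield samples in the order the split indices dictate.
        for i in self.indices:
            yield self.dataset[i]

    def __len__(self):
        return len(self.indices)

subset = SampleSubset(["a", "b", "c", "d"], [0, 2])
print(list(subset))  # → ['a', 'c']
```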

Comment on lines +216 to +220
def _build_subset_dataset(
self, base_dataset: StreamingDataset, indices: Sequence[int]
) -> Tuple[StreamingDataset, int]:
"""Create a StreamingDataset restricted to the provided indices."""
if len(base_dataset.subsampled_files) != len(base_dataset.region_of_interest):
Collaborator Author

Comment on lines +214 to +217
cache_dir: str | Path | None = None,
num_workers: int = 1,
mem_per_worker: str | int = "8GB",
compute: bool = True,
Collaborator Author

Any other dataset added later will need to accept these new args.

Comment on lines +220 to +229
super().__init__(
root=f"{str(ehr_root)},{str(note_root)},{str(cxr_root)}",
tables=(ehr_tables or []) + (note_tables or []) + (cxr_tables or []),
dataset_name=dataset_name,
cache_dir=cache_dir,
num_workers=num_workers,
mem_per_worker=mem_per_worker,
compute=False, # defer compute to later, we need to aggregate all sub-datasets first
dev=dev,
)
Collaborator Author

It is important to assign root, tables, dataset_name, and dev so that the correct default cache path is calculated.

Comment on lines +290 to 291
def load_data(self) -> dd.DataFrame:
"""
Collaborator Author

One should override load_data if custom logic is required, since there is no .global_event_df available. Holding a Dask DataFrame in a class field would cause a multi-process failure, making it impossible to use multiple workers to process the samples.

Comment on lines 96 to 123
def _pickle(datum: dict[str, Any]) -> dict[str, bytes]:
    return {k: pickle.dumps(v) for k, v in datum.items()}


def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]:
    return {k: pickle.loads(v) for k, v in datum.items()}


def _patient_bucket(patient_id: str, n_partitions: int) -> int:
    """Hash patient_id to a bucket number."""
    return int(xxhash.xxh64_intdigest(patient_id) % n_partitions)


def _transform_fn(
    input: tuple[int, str, BaseTask],
) -> Iterator[Dict[str, Any]]:
    (bucket_id, merged_cache, task) = input
    path = f"{merged_cache}/bucket={bucket_id}"
    # This is more efficient than reading patient by patient
    grouped = pd.read_parquet(path).groupby("patient_id")

    for patient_id, patient_df in grouped:
        patient = Patient(patient_id=str(patient_id), data_source=patient_df)
        for sample in task(patient):
            # Schema is too complex to be handled by LitData, so we pickle the sample here
            yield _pickle(sample)

Collaborator Author

It's important to define these functions at module level (outside of any class) to avoid pickling issues in a multi-process environment.

Comment on lines +197 to +210
id_str = json.dumps(
{
"root": self.root,
"tables": sorted(self.tables),
"dataset_name": self.dataset_name,
"dev": self.dev,
},
sort_keys=True,
)
cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str(
uuid.uuid5(uuid.NAMESPACE_DNS, id_str)
)
print(f"No cache_dir provided. Using default cache dir: {cache_dir}")
self._cache_dir = cache_dir
Collaborator Author

Computes the default cache dir. I think this should be unique enough?
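The scheme is deterministic: uuid5 hashes the namespace plus the canonical JSON string, so identical (root, tables, dataset_name, dev) inputs always map to the same cache directory, and `sorted(tables)` makes table order irrelevant.

```python
import json
import uuid

def cache_key(root, tables, dataset_name, dev):
    # Canonical JSON: sorted keys plus sorted tables give a stable string.
    id_str = json.dumps(
        {"root": root, "tables": sorted(tables), "dataset_name": dataset_name, "dev": dev},
        sort_keys=True,
    )
    return uuid.uuid5(uuid.NAMESPACE_DNS, id_str)

a = cache_key("/data/mimic", ["labevents", "admissions"], "mimic4", False)
b = cache_key("/data/mimic", ["admissions", "labevents"], "mimic4", False)
print(a == b)  # → True: table order does not change the cache dir
```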

Comment on lines +320 to +323
bucket = global_event_df["patient_id"].apply(
lambda pid: _patient_bucket(pid, n_partitions),
meta=("patient_id", "int"),
)
Collaborator Author

Split the dataframe into buckets based on patient id, enabling faster processing downstream.
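The bucketing is a plain hash-mod partition; the same idea sketched with stdlib hashlib standing in for xxhash (the PR uses xxh64 for speed, but the property that matters is determinism):

```python
import hashlib

def patient_bucket(patient_id: str, n_partitions: int) -> int:
    # Deterministic hash → every row for a patient lands in the same bucket,
    # so a single bucket file holds all of that patient's events.
    digest = hashlib.md5(patient_id.encode()).digest()
    return int.from_bytes(digest[:8], "big") % n_partitions

print(patient_bucket("p001", 16) == patient_bucket("p001", 16))  # → True
print(0 <= patient_bucket("p001", 16) < 16)                      # → True
```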

Comment on lines +485 to 497

path = self._joined_cache()
with open(f"{path}/index.json", "rb") as f:
n_partitions = json.load(f)["n_partitions"]
bucket = _patient_bucket(patient_id, n_partitions)
path = f"{path}/bucket={bucket}"
df = pd.read_parquet(path)
patient = Patient(
patient_id=patient_id,
data_source=df[df["patient_id"] == patient_id],
)
return patient

Collaborator Author

Read from the relevant bucket only; this is much faster for larger datasets.

@LogicFan LogicFan marked this pull request as ready for review November 24, 2025 00:02


Development

Successfully merging this pull request may close these issues.

Improve memory usage of BaseDataset.load_data()

2 participants