[Memory] Better Memory Utilization for PyHealth #622
base: master
Conversation
Co-authored-by: John Wu <[email protected]>
jhnwu3 left a comment
Hey, do you know which dask version you used? It would probably be good to add this to the pyproject.toml.
Will look deeper into the CI and other implementation details later, have to go wander Japan for a bit 😂
"dask[complete]~=2025.11.0",
"litdata~=0.2.58",
"xxhash~=3.6.0",
New dependencies are specified here.
ah im blind.
  event_type: str
  timestamp: datetime
- attr_dict: Mapping[str, any] = field(default_factory=dict)
+ attr_dict: Mapping[str, Any] = field(default_factory=dict)
Small type hint fix here: lowercase `any` is the built-in function for iterables, while `Any` (from `typing`) is the correct type hint.
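For reference, a minimal sketch of the fixed annotation (the field names follow the diff above; the dataclass reconstruction itself is hypothetical):

```python
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Mapping

@dataclass
class Event:
    # `Any` here is typing.Any; lowercase `any` would be the builtin
    # any(iterable) function, which is not a valid type.
    event_type: str
    timestamp: datetime
    attr_dict: Mapping[str, Any] = field(default_factory=dict)

e = Event(event_type="labevents", timestamp=datetime(2025, 1, 1))
```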
  """

- def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None:
+ def __init__(self, patient_id: str, data_source: pd.DataFrame) -> None:
Polars spawns its own processes. Since this will be run inside litdata.optimize in a multi-process environment, the nested multiprocessing would cause a hang, so we change this to pandas.
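The general pattern, sketched with the stdlib (the worker function and record shape are hypothetical): whatever runs inside the fan-out workers should stay single-process, so hand workers plain objects like a pandas DataFrame rather than an engine that spawns its own pool.

```python
from concurrent.futures import ProcessPoolExecutor

def process_patient(record: dict) -> int:
    # Keep the worker body single-process: avoid libraries that spawn
    # their own thread/process pools in here, or the nested parallelism
    # can deadlock under the outer process pool.
    return len(record["events"])

def run(records: list) -> list:
    with ProcessPoolExecutor(max_workers=2) as pool:
        return list(pool.map(process_patient, records))
```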
def get_events_py(self, **kwargs) -> List[Event]:
    """Type-safe wrapper for get_events."""
    return self.get_events(**kwargs, return_df=False)  # type: ignore


def get_events_df(self, **kwargs) -> pd.DataFrame:
    """DataFrame wrapper for get_events."""
    return self.get_events(**kwargs, return_df=True)  # type: ignore
Added for type-hinting purposes.
labevents_df = patient.get_events_df(
    event_type="labevents",
    start=admission_time,
    end=admission_dischtime,
    return_df=True,
)
It returns a pandas dataframe here; any task that uses get_events_df (or get_events(..., return_df=True)) will need to update its code to be compatible with pandas.
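For task authors, the migration is mostly mechanical: polars `filter`/`pl.col` expressions become pandas boolean indexing. A sketch with hypothetical column values:

```python
import pandas as pd

# Hypothetical frame shaped like what get_events_df now returns.
events = pd.DataFrame({
    "patient_id": ["p1", "p1", "p2"],
    "event_type": ["labevents", "rx", "labevents"],
    "value": [1.0, 2.0, 3.0],
})

# polars: events.filter(pl.col("event_type") == "labevents")
# pandas equivalent:
labs = events[events["event_type"] == "labevents"]
```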
pyhealth/tasks/base_task.py
  output_schema: Dict[str, Union[str, Type]]

- def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
+ def pre_filter(self, df: dd.DataFrame) -> dd.DataFrame:
The input becomes a dask dataframe here; any task that overrides pre_filter needs to be updated.
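A sketch of what an updated override might look like (table/column names are hypothetical). Since dask.dataframe mirrors much of the pandas API for row filtering, the same code typically works on both; demonstrated here against pandas:

```python
import pandas as pd

def pre_filter(df):
    # Keep only the event types this task needs. This boolean-mask
    # style works identically on a dask DataFrame, since dask mirrors
    # this part of the pandas API.
    return df[df["event_type"].isin(["labevents", "admissions"])]
```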
dataset.set_shuffle(shuffle)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=collate_fn_dict_with_padding,
)
PyTorch's native DataLoader does not support shuffle on an IterableDataset; shuffling must happen at the dataset level.
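The usual workaround for streams is a bounded shuffle buffer inside the dataset's iterator. A stdlib-only sketch of the idea (buffer size and seed are illustrative, not what litdata does internally):

```python
import random

def shuffled_stream(iterable, buffer_size=8, seed=0):
    """Approximate shuffle for a stream: fill a bounded buffer, then
    emit a randomly chosen buffered item as each new item arrives."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))
```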
train_dataset = SampleSubset(dataset, train_index)  # type: ignore
val_dataset = SampleSubset(dataset, val_index)  # type: ignore
test_dataset = SampleSubset(dataset, test_index)  # type: ignore
torch's native Subset does not support iterable datasets, so we use a custom-defined iterable subset here.
class SampleSubset(IterableDataset):
    """A subset of the SampleDataset.

    Args:
        sample_dataset (SampleDataset): The original SampleDataset.
        indices (List[int]): List of indices to include in the subset.
    """
Added a new SampleSubset to support creating subsets of a SampleDataset for the train/val/test split.
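Conceptually, the subset just filters an iterable stream by position. A torch-free sketch of that idea (the real class extends torch's IterableDataset; this minimal version is for illustration only):

```python
class SampleSubset:
    """Minimal sketch: yield only samples whose position is in `indices`."""

    def __init__(self, sample_dataset, indices):
        self.sample_dataset = sample_dataset
        self.indices = set(indices)

    def __iter__(self):
        # Iterate the underlying stream once, skipping excluded positions;
        # this avoids random access, which iterable datasets lack.
        for i, sample in enumerate(self.sample_dataset):
            if i in self.indices:
                yield sample
```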
def _build_subset_dataset(
    self, base_dataset: StreamingDataset, indices: Sequence[int]
) -> Tuple[StreamingDataset, int]:
    """Create a StreamingDataset restricted to the provided indices."""
    if len(base_dataset.subsampled_files) != len(base_dataset.region_of_interest):
Adapted from litdata's official train_test_split: https://github.com/Lightning-AI/litData/blob/main/src/litdata/utilities/train_test_split.py
cache_dir: str | Path | None = None,
num_workers: int = 1,
mem_per_worker: str | int = "8GB",
compute: bool = True,
Any other dataset will need to add these new args.
super().__init__(
    root=f"{str(ehr_root)},{str(note_root)},{str(cxr_root)}",
    tables=(ehr_tables or []) + (note_tables or []) + (cxr_tables or []),
    dataset_name=dataset_name,
    cache_dir=cache_dir,
    num_workers=num_workers,
    mem_per_worker=mem_per_worker,
    compute=False,  # defer compute to later, we need to aggregate all sub-datasets first
    dev=dev,
)
It is important to assign root, tables, dataset_name, and dev so that the correct default cache path is calculated.
def load_data(self) -> dd.DataFrame:
    """
Subclasses should override load_data if they require custom logic, since there is no .global_event_df available anymore. Keeping a dask dataframe in a class field would cause a multi-process failure, making it impossible to use multiple workers to process the samples.
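A sketch of the intended override pattern (dataset and columns are hypothetical, with pandas standing in for dask): build the frame inside load_data and return it, rather than caching it on the instance, so worker processes never have to pickle it.

```python
import pandas as pd

class MyCustomDataset:
    def load_data(self):
        # Construct the event frame here and return it. Do NOT store it
        # on self: if the instance gets pickled for worker processes,
        # a dataframe handle in a field would break multiprocessing.
        return pd.DataFrame({
            "patient_id": ["p1", "p2"],
            "event_type": ["labevents", "labevents"],
        })
```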
def _pickle(datum: dict[str, Any]) -> dict[str, bytes]:
    return {k: pickle.dumps(v) for k, v in datum.items()}


def _unpickle(datum: dict[str, bytes]) -> dict[str, Any]:
    return {k: pickle.loads(v) for k, v in datum.items()}


if path_exists(alt_path):
    logger.info(f"Original path does not exist. Using alternative: {alt_path}")
    return scan_file(alt_path)
raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")


def _patient_bucket(patient_id: str, n_partitions: int) -> int:
    """Hash patient_id to a bucket number."""
    bucket = int(xxhash.xxh64_intdigest(patient_id) % n_partitions)
    return bucket


def _transform_fn(
    input: tuple[int, str, BaseTask],
) -> Iterator[Dict[str, Any]]:
    (bucket_id, merged_cache, task) = input
    path = f"{merged_cache}/bucket={bucket_id}"
    # This is more efficient than reading patient by patient
    grouped = pd.read_parquet(path).groupby("patient_id")

    for patient_id, patient_df in grouped:
        patient = Patient(patient_id=str(patient_id), data_source=patient_df)
        for sample in task(patient):
            # Schema is too complex to be handled by LitData, so we pickle the sample here
            yield _pickle(sample)
It's important to define these functions at module level (outside of a class) to avoid pickling issues in a multi-process environment.
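The underlying constraint is picklability: multiprocessing ships work to workers via pickle, and pickle serializes functions by module-level reference. A quick stdlib illustration:

```python
import pickle

def module_level(x):
    # Defined at module top level, so pickle can reference it by name.
    return x + 1

def can_pickle(obj) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False

# Module-level functions pickle fine; lambdas (and locally defined
# closures) generally do not, which is what breaks worker fan-out.
```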
id_str = json.dumps(
    {
        "root": self.root,
        "tables": sorted(self.tables),
        "dataset_name": self.dataset_name,
        "dev": self.dev,
    },
    sort_keys=True,
)
cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / str(
    uuid.uuid5(uuid.NAMESPACE_DNS, id_str)
)
print(f"No cache_dir provided. Using default cache dir: {cache_dir}")
self._cache_dir = cache_dir
Computes the default cache dir. I think this should be unique enough?
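uuid5 is deterministic, so the same (root, tables, dataset_name, dev) tuple always maps to the same directory, while any change yields a different one. A sketch with a fixed base path standing in for platformdirs (the helper name is hypothetical):

```python
import json
import uuid
from pathlib import Path

def default_cache_dir(root, tables, dataset_name, dev, base="~/.cache/pyhealth"):
    # Same inputs -> same id_str -> same uuid5 -> same cache dir.
    id_str = json.dumps(
        {"root": root, "tables": sorted(tables),
         "dataset_name": dataset_name, "dev": dev},
        sort_keys=True,
    )
    return Path(base) / str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
```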
bucket = global_event_df["patient_id"].apply(
    lambda pid: _patient_bucket(pid, n_partitions),
    meta=("patient_id", "int"),
)
Splits the dataframe into buckets based on patient id, enabling faster processing downstream.
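The bucketing itself is just a stable hash modulo the partition count. A stdlib sketch with hashlib standing in for xxhash:

```python
import hashlib

def patient_bucket(patient_id: str, n_partitions: int) -> int:
    # hashlib stands in for xxhash.xxh64_intdigest here; the only real
    # requirement is a deterministic hash that is stable across
    # processes (unlike the builtin hash(), which is salted per run).
    digest = hashlib.sha1(patient_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % n_partitions
```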
path = self._joined_cache()
with open(f"{path}/index.json", "rb") as f:
    n_partitions = json.load(f)["n_partitions"]
bucket = _patient_bucket(patient_id, n_partitions)
path = f"{path}/bucket={bucket}"
df = pd.read_parquet(path)
patient = Patient(
    patient_id=patient_id,
    data_source=df[df["patient_id"] == patient_id],
)
return patient
Reads from the relevant bucket only, which should be much faster for larger datasets.
This should be able to close #332.