
Support the combination of multiple datasets with the dataloader#99

Open
hanaol wants to merge 1 commit into main from hanaol/dataloader-multi-dataset

Conversation

@hanaol (Collaborator) commented Mar 13, 2026

Problem
The current dataloader only supports a single dataset. As a result, it is not possible to combine data coming from different paths.

Solution
This PR extends the dataloader to support multiple datasets. The root, split_file, and val_frac parameters retain the same meaning as before. In addition, it introduces a new parameter, dataset_id, which serves as an identifier for each dataset and will be useful for future development of multi-head models.

Example

  datasets:
    - root: $DATASET1_PATH/mp_filelist.txt
      split_file: $DATASET1_PATH/split.json
      val_frac: 0.005
    - root: $DATASET2_PATH/mp_filelist.txt
      split_file: $DATASET2_PATH/split.json
      val_frac: 0.005

@hanaol hanaol requested a review from forklady42 March 13, 2026 23:39
@forklady42 (Collaborator) commented:

@claude please review this PR

(I'll review afterwards)


claude bot commented Mar 16, 2026

Claude finished @forklady42's task in 1m 58s


PR Review: Support the combination of multiple datasets with the dataloader


Overall this is a well-structured approach to multi-dataset support. The DatasetSpec dataclass, AddDatasetID wrapper, and ConcatDataset usage are all clean patterns. A few issues worth addressing before merging:

Bug: collate_fn fallback is incorrect for dict-based batches

collate.py:10-11 — The fallback triggers when default_collate raises a RuntimeError (e.g. mismatched tensor shapes). At that point, batch is a list of dicts like {"data": ..., "label": ..., "index": ..., "Dataset_ID": ...}. But zip(*batch) unpacks dict keys, not values — so x, y, index, dataset_id would receive ("data","data",...), ("label","label",...), etc. rather than actual tensors.

This was a pre-existing issue (the old fallback had the same structure when RhoData.__getitem__ still returned a dict), but the new dataset_id field makes it more prominent. The fallback path also returns a positional 4-tuple while the normal default_collate path returns a dict — inconsistent return types will break any downstream consumer.
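A minimal sketch of the failure mode and a keyed fallback (pure-Python stand-ins for tensors; the field names mirror the batch dict above, but the fallback shape is illustrative, not the project's actual code):

```python
def broken_fallback(batch):
    # Iterating a dict yields its KEYS, so zip(*batch) pairs up key names
    # across samples instead of unpacking the values.
    x, y, index, dataset_id = zip(*batch)
    return x, y, index, dataset_id

def keyed_fallback(batch):
    # Collate per key so the return type stays a dict, matching what
    # default_collate returns on the happy path.
    return {key: [sample[key] for sample in batch] for key in batch[0]}

batch = [
    {"data": [1, 2], "label": 0, "index": 0, "Dataset_ID": 0},
    {"data": [3, 4, 5], "label": 1, "index": 1, "Dataset_ID": 1},
]

print(broken_fallback(batch)[0])      # ('data', 'data') -- key names, not data
print(keyed_fallback(batch)["data"])  # [[1, 2], [3, 4, 5]]
```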


Performance: setup() always processes all datasets regardless of stage

dataset.py:105-138 — setup() iterates all specs and calls split_data() on every dataset unconditionally, even when stage == "test" (in which case train_parts and val_parts are built and then discarded). For large multi-dataset configurations this loads and splits unnecessary data.

# Current: all parts built regardless of stage
for spec in self.specs:
    ...
    train_parts.append(...)   # wasted when stage == "test"
    val_parts.append(...)     # wasted when stage == "test"

Consider gating which parts are built based on stage.
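A hedged sketch of what stage-gating could look like; the class and helper names here are placeholders standing in for the real datamodule, not the project's actual API:

```python
class StageGatedDataModule:
    """Toy stand-in; split_data/load_test are stubs for the real loaders."""

    def __init__(self, specs):
        self.specs = specs
        self.calls = []  # records which (potentially expensive) loads ran

    def split_data(self, spec):  # placeholder for the real split_data()
        self.calls.append(("split", spec))
        return f"train-{spec}", f"val-{spec}"

    def load_test(self, spec):   # placeholder test-set loader
        self.calls.append(("test", spec))
        return f"test-{spec}"

    def setup(self, stage=None):
        train_parts, val_parts, test_parts = [], [], []
        for spec in self.specs:
            if stage in (None, "fit"):   # only split when fitting
                train, val = self.split_data(spec)
                train_parts.append(train)
                val_parts.append(val)
            if stage in (None, "test"):  # only load test data when testing
                test_parts.append(self.load_test(spec))
        return train_parts, val_parts, test_parts

dm = StageGatedDataModule(["ds1", "ds2"])
dm.setup(stage="test")
print(dm.calls)  # only ("test", ...) entries; no wasted train/val splits
```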


Design note: dataset_id is silently ignored by the model

The "Dataset_ID" key is added to every batch sample (dataset.py:39) but lightning.py:_loss_calculation only accesses batch["data"] and batch["label"]. For training and validation, the dataset_id is a no-op today. The PR description flags this as intentional ("useful for future development of multi-head models"), so this is fine — but worth a comment in _loss_calculation or the docstring so future developers know to extend it there.
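For future reference, a hypothetical sketch of how a loss calculation might one day consume Dataset_ID to route samples through per-dataset losses; everything below is invented for illustration and does not exist in the PR today:

```python
def loss_per_dataset(batch, loss_fns):
    """Route each sample's loss through a per-dataset function.

    batch: dict with "data", "label", "Dataset_ID" sequences (as in this PR).
    loss_fns: hypothetical mapping of dataset_id -> loss callable.
    """
    total = 0.0
    for x, y, ds_id in zip(batch["data"], batch["label"], batch["Dataset_ID"]):
        total += loss_fns[ds_id](x, y)
    return total / len(batch["data"])

batch = {"data": [1.0, 2.0], "label": [1.0, 4.0], "Dataset_ID": [0, 1]}
loss_fns = {0: lambda x, y: abs(x - y), 1: lambda x, y: (x - y) ** 2}
print(loss_per_dataset(batch, loss_fns))  # (0.0 + 4.0) / 2 = 2.0
```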

Minor: defensive None check on a non-optional field

dataset.py:82 — DatasetSpec.root is typed as str (no default, required). The if s.root is None guard is dead code under normal construction. Either change the type to str | None = None (to make it optional), or remove the guard and let the dataclass raise a TypeError naturally.
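Both options sketched below; the field name follows the review, and Optional[str] is used instead of str | None for pre-3.10 compatibility:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DatasetSpecOptional:
    # Option 1: make root genuinely optional, so the None guard is live code
    root: Optional[str] = None

@dataclass
class DatasetSpecRequired:
    # Option 2: keep root required and drop the guard; the dataclass
    # raises TypeError by itself when root is missing
    root: str

assert DatasetSpecOptional().root is None
try:
    DatasetSpecRequired()  # missing root
except TypeError:
    print("missing root raises TypeError naturally")
```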

Missing tests

No tests were added for the new multi-dataset path (DatasetSpec, AddDatasetID, multi-spec setup()). Given the collate fallback bug above, a test covering mismatched-shape batches across datasets would be valuable.
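A sketch of what such a regression test could assert (collate_fn here is a pure-Python placeholder for the project's implementation in collate.py; the key names follow the batch dict described above):

```python
def collate_fn(batch):  # placeholder; the real one lives in collate.py
    return {key: [sample[key] for sample in batch] for key in batch[0]}

def test_mismatched_shapes_keep_dict_interface():
    # Samples with mismatched "data" shapes, as if from two different datasets
    batch = [
        {"data": [1, 2], "label": 0, "index": 0, "Dataset_ID": 0},
        {"data": [3, 4, 5], "label": 1, "index": 1, "Dataset_ID": 1},
    ]
    out = collate_fn(batch)
    # The fallback must keep the dict interface and preserve per-key values
    assert set(out) == {"data", "label", "index", "Dataset_ID"}
    assert out["Dataset_ID"] == [0, 1]
    assert out["data"] == [[1, 2], [3, 4, 5]]

test_mismatched_shapes_keep_dict_interface()
```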


@forklady42 (Collaborator) left a comment


Minor comment about clarifying the config interface. Also worth reviewing Claude's comments.

datasets:
  - root: /scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/chg_datasets/dataset_2/mp_filelist.txt
    split_file: /scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/chg_datasets/dataset_2/split.json
    val_frac: 0.005

val_frac isn't used when split_file is provided. This is at least worth a comment so no one is misled.
