diff --git a/tests/torchtune/datasets/test_concat_dataset.py b/tests/torchtune/datasets/test_concat_dataset.py
index 32ecc3bc5a..352d20e372 100644
--- a/tests/torchtune/datasets/test_concat_dataset.py
+++ b/tests/torchtune/datasets/test_concat_dataset.py
@@ -6,7 +6,25 @@
 import pytest
 
 from datasets import Dataset
+from torch.utils.data import Dataset as TorchDataset
 from torchtune.datasets._concat import ConcatDataset
+from torchtune.datasets._packed import PackedDataset
+
+
+class DummyDataset(TorchDataset):
+    def __init__(self, sample_size):
+        self.sample_size = sample_size
+
+    def __getitem__(self, index):
+        if index >= 1000:
+            raise IndexError()
+        return {
+            "tokens": [index] * self.sample_size,
+            "labels": [index] * self.sample_size,
+        }
+
+    def __len__(self):
+        return 1000
 
 
 class TestConcatDataset:
@@ -20,6 +38,16 @@ def datasets(self):
         ds6 = Dataset.from_list([{"data": f"ds6_{i}"} for i in range(42)])
         return [ds1, ds2, ds3, ds4, ds5, ds6]
 
+    @pytest.fixture
+    def torch_datasets(self):
+        ds1 = DummyDataset(4)
+        ds2 = DummyDataset(8)
+        ds3 = DummyDataset(15)
+        ds4 = DummyDataset(16)
+        ds5 = DummyDataset(23)
+        ds6 = DummyDataset(42)
+        return [ds1, ds2, ds3, ds4, ds5, ds6]
+
     def test_length(self, datasets):
         """Test the correct computation of total length"""
         multi_dataset = ConcatDataset(datasets)
@@ -51,3 +79,14 @@ def test_invalid_index_type(self, datasets):
 
         with pytest.raises(TypeError):
             multi_dataset["invalid_type"]  # Non-integer index
+
+    def test_packed_dataset(self, torch_datasets):
+        torch_datasets[0] = PackedDataset(
+            torch_datasets[0],
+            max_seq_len=25,
+            max_packs=5,
+            split_across_pack=True,
+        )
+
+        with pytest.raises(ValueError):
+            ConcatDataset(torch_datasets)
diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py
index 2a76602697..304a605641 100644
--- a/torchtune/datasets/_concat.py
+++ b/torchtune/datasets/_concat.py
@@ -10,6 +10,8 @@
 
 from torchtune import utils
 
+from torchtune.datasets._packed import PackedDataset
+
 log = utils.get_logger("DEBUG")
 
 
@@ -35,6 +37,9 @@ class ConcatDataset(Dataset):
         datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance
             of a class derived from :class:`~torch.utils.data.Dataset`.
 
+    Raises:
+        ValueError: if an instance of `PackedDataset` is in `datasets`
+
     Examples:
         >>> dataset1 = MyCustomDataset(params1)
         >>> dataset2 = MyCustomDataset(params2)
@@ -64,6 +69,13 @@ class ConcatDataset(Dataset):
 
     def __init__(self, datasets: List[Dataset]):
         self._datasets: List[Dataset] = datasets
+
+        for dataset in self._datasets:
+            if isinstance(dataset, PackedDataset):
+                raise ValueError(
+                    "ConcatDataset can't process instances of PackedDataset."
+                )
+
         self._len: int = sum(len(dataset) for dataset in datasets)
 
         self._indexes: List[Tuple[int, int, int]] = []
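For reference, a minimal sketch of how the new guard surfaces to a caller (not part of the diff). TinyDataset is a hypothetical stand-in for any map-style dataset; the PackedDataset arguments simply mirror the ones used in test_packed_dataset above.

```python
# Sketch only: TinyDataset is hypothetical; PackedDataset arguments mirror the
# ones used in test_packed_dataset in the diff above.
from torch.utils.data import Dataset as TorchDataset

from torchtune.datasets._concat import ConcatDataset
from torchtune.datasets._packed import PackedDataset


class TinyDataset(TorchDataset):
    def __getitem__(self, index):
        if index >= 10:
            raise IndexError()
        return {"tokens": [index] * 4, "labels": [index] * 4}

    def __len__(self):
        return 10


datasets = [TinyDataset(), TinyDataset()]
# Wrapping any element in PackedDataset before concatenation now raises
# ValueError inside ConcatDataset.__init__.
datasets[0] = PackedDataset(
    datasets[0], max_seq_len=25, max_packs=5, split_across_pack=True
)

try:
    ConcatDataset(datasets)
except ValueError as err:
    print(err)  # ConcatDataset can't process instances of PackedDataset.
```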