From 3b5fffb539dc2d99dc1b5a3297fdd4a0345b1a57 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Thu, 10 Oct 2024 03:17:38 -0400 Subject: [PATCH 1/8] Check that there is no PackedDataset while building ConcatDataset --- torchtune/datasets/_concat.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index 2a76602697..cfe645d815 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -10,6 +10,8 @@ from torchtune import utils +from torchtune.datasets._packed import PackedDataset + log = utils.get_logger("DEBUG") @@ -34,7 +36,9 @@ class ConcatDataset(Dataset): Args: datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance of a class derived from :class:`~torch.utils.data.Dataset`. - + Raises: + ValueError: if instanse of `PackedDataset` is in `datasets` + Examples: >>> dataset1 = MyCustomDataset(params1) >>> dataset2 = MyCustomDataset(params2) @@ -64,6 +68,13 @@ class ConcatDataset(Dataset): def __init__(self, datasets: List[Dataset]): self._datasets: List[Dataset] = datasets + + for dataset in self._datasets: + if isinstance(dataset, PackedDataset): + raise ValueError( + "ConcatDataset can't proceed instances of PackedDataset." + ) + self._len: int = sum(len(dataset) for dataset in datasets) self._indexes: List[Tuple[int, int, int]] = [] From 935daac2d698e1282bd3e90c8c91c6837f0198fd Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Thu, 10 Oct 2024 03:18:35 -0400 Subject: [PATCH 2/8] unit test for new check in test_concat_dataset.py --- .../torchtune/datasets/test_concat_dataset.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/torchtune/datasets/test_concat_dataset.py b/tests/torchtune/datasets/test_concat_dataset.py index 32ecc3bc5a..352d20e372 100644 --- a/tests/torchtune/datasets/test_concat_dataset.py +++ b/tests/torchtune/datasets/test_concat_dataset.py @@ -6,7 +6,25 @@ import pytest from datasets import Dataset +from torch.utils.data import Dataset as TorchDataset from torchtune.datasets._concat import ConcatDataset +from torchtune.datasets._packed import PackedDataset + + +class DummyDataset(TorchDataset): + def __init__(self, sample_size): + self.sample_size = sample_size + + def __getitem__(self, index): + if index >= 1000: + raise IndexError() + return { + "tokens": [index] * self.sample_size, + "labels": [index] * self.sample_size, + } + + def __len__(self): + return 1000 class TestConcatDataset: @@ -20,6 +38,16 @@ def datasets(self): ds6 = Dataset.from_list([{"data": f"ds6_{i}"} for i in range(42)]) return [ds1, ds2, ds3, ds4, ds5, ds6] + @pytest.fixture + def torch_datasets(self): + ds1 = DummyDataset(4) + ds2 = DummyDataset(8) + ds3 = DummyDataset(15) + ds4 = DummyDataset(16) + ds5 = DummyDataset(23) + ds6 = DummyDataset(42) + return [ds1, ds2, ds3, ds4, ds5, ds6] + def test_length(self, datasets): """Test the correct computation of total length""" multi_dataset = ConcatDataset(datasets) @@ -51,3 +79,14 @@ def test_invalid_index_type(self, datasets): with pytest.raises(TypeError): multi_dataset["invalid_type"] # Non-integer index + + def test_packed_dataset(self, torch_datasets): + torch_datasets[0] = PackedDataset( + torch_datasets[0], + max_seq_len=25, + max_packs=5, + split_across_pack=True, + ) + + with pytest.raises(ValueError): + concated_dataset = ConcatDataset(torch_datasets) From 
77273be159c03bbdf648141ecb1d097b3e55dea4 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:15:43 -0400 Subject: [PATCH 3/8] fix lint test_concat_dataset.py From fe8545a19de1391b9541efbbff2d5c1c1add9568 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:17:02 -0400 Subject: [PATCH 4/8] fix lint _concat.py --- torchtune/datasets/_concat.py | 171 +++++++++++++++++----------------- 1 file changed, 84 insertions(+), 87 deletions(-) diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index cfe645d815..352d20e372 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -4,92 +4,89 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Tuple - -from torch.utils.data import Dataset - -from torchtune import utils - +import pytest +from datasets import Dataset +from torch.utils.data import Dataset as TorchDataset +from torchtune.datasets._concat import ConcatDataset from torchtune.datasets._packed import PackedDataset -log = utils.get_logger("DEBUG") - - -class ConcatDataset(Dataset): - """ - A dataset class for concatenating multiple sub-datasets into a single dataset. This class enables the - unified handling of different datasets as if they were a single dataset, simplifying tasks such as - training models on multiple sources of data simultaneously. - - The class internally manages the aggregation of different datasets and allows transparent indexing across them. - However, it requires all constituent datasets to be fully loaded into memory, which might not be optimal for - very large datasets. - - Upon initialization, this class computes the cumulative length of all datasets and maintains an internal mapping - of indices to the respective datasets. This approach allows the :class:`~torchtune.datasets.ConcatDataset` - to delegate data retrieval to the appropriate sub-dataset transparently when a particular index is accessed. - - Note: - Using this class with very large datasets can lead to high memory consumption, as it requires all datasets to - be loaded into memory. For large-scale scenarios, consider other strategies that might stream data on demand. - - Args: - datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance of a class - derived from :class:`~torch.utils.data.Dataset`. 
- Raises: - ValueError: if instanse of `PackedDataset` is in `datasets` - - Examples: - >>> dataset1 = MyCustomDataset(params1) - >>> dataset2 = MyCustomDataset(params2) - >>> concat_dataset = ConcatDataset([dataset1, dataset2]) - >>> print(len(concat_dataset)) # Total length of both datasets - >>> data_point = concat_dataset[1500] # Accesses an element from the appropriate dataset - - This can also be accomplished by passing in a list of datasets to the YAML config:: - - dataset: - - _component_: torchtune.datasets.instruct_dataset - source: vicgalle/alpaca-gpt4 - template: torchtune.data.AlpacaInstructTemplate - split: train - train_on_input: True - - _component_: torchtune.datasets.instruct_dataset - source: samsum - template: torchtune.data.SummarizeTemplate - column_map: {"output": "summary"} - output: summary - split: train - train_on_input: False - - This class primarily focuses on providing a unified interface to access elements from multiple datasets, - enhancing the flexibility in handling diverse data sources for training machine learning models. - """ - - def __init__(self, datasets: List[Dataset]): - self._datasets: List[Dataset] = datasets - - for dataset in self._datasets: - if isinstance(dataset, PackedDataset): - raise ValueError( - "ConcatDataset can't proceed instances of PackedDataset." - ) - - self._len: int = sum(len(dataset) for dataset in datasets) - self._indexes: List[Tuple[int, int, int]] = [] - - # Calculate distribution of indexes in all datasets - cumulative_index = 0 - for idx, dataset in enumerate(datasets): - next_cumulative_index = cumulative_index + len(dataset) - self._indexes.append((cumulative_index, next_cumulative_index, idx)) - cumulative_index = next_cumulative_index - - def __getitem__(self, index: int) -> Tuple[List[int], List[int]]: - for start, stop, dataset_index in self._indexes: - if start <= index < stop: - dataset = self._datasets[dataset_index] - return dataset[index - start] - - def __len__(self) -> int: - return self._len + +class DummyDataset(TorchDataset): + def __init__(self, sample_size): + self.sample_size = sample_size + + def __getitem__(self, index): + if index >= 1000: + raise IndexError() + return { + "tokens": [index] * self.sample_size, + "labels": [index] * self.sample_size, + } + + def __len__(self): + return 1000 + + +class TestConcatDataset: + @pytest.fixture + def datasets(self): + ds1 = Dataset.from_list([{"data": f"ds1_{i}"} for i in range(4)]) + ds2 = Dataset.from_list([{"data": f"ds2_{i}"} for i in range(8)]) + ds3 = Dataset.from_list([{"data": f"ds3_{i}"} for i in range(15)]) + ds4 = Dataset.from_list([{"data": f"ds4_{i}"} for i in range(16)]) + ds5 = Dataset.from_list([{"data": f"ds5_{i}"} for i in range(23)]) + ds6 = Dataset.from_list([{"data": f"ds6_{i}"} for i in range(42)]) + return [ds1, ds2, ds3, ds4, ds5, ds6] + + @pytest.fixture + def torch_datasets(self): + ds1 = DummyDataset(4) + ds2 = DummyDataset(8) + ds3 = DummyDataset(15) + ds4 = DummyDataset(16) + ds5 = DummyDataset(23) + ds6 = DummyDataset(42) + return [ds1, ds2, ds3, ds4, ds5, ds6] + + def test_length(self, datasets): + """Test the correct computation of total length""" + multi_dataset = ConcatDataset(datasets) + + # sum of individual datasets lengths + expected_length = 4 + 8 + 15 + 16 + 23 + 42 # 108 + assert len(multi_dataset) == expected_length + + def test_getitem(self, datasets): + """Test item retrieval across dataset boundaries""" + multi_dataset = ConcatDataset(datasets) + + # Testing indices across different datasets + assert 
multi_dataset[-1] is None # Index out of range + assert multi_dataset[0] == {"data": "ds1_0"} + assert multi_dataset[3] == {"data": "ds1_3"} + assert multi_dataset[4] == {"data": "ds2_0"} + assert multi_dataset[10] == {"data": "ds2_6"} + assert multi_dataset[20] == {"data": "ds3_8"} + assert multi_dataset[35] == {"data": "ds4_8"} + assert multi_dataset[50] == {"data": "ds5_7"} + assert multi_dataset[70] == {"data": "ds6_4"} + assert multi_dataset[90] == {"data": "ds6_24"} + assert multi_dataset[108] is None # Index out of range + + def test_invalid_index_type(self, datasets): + """Test handling of invalid index types""" + multi_dataset = ConcatDataset(datasets) + + with pytest.raises(TypeError): + multi_dataset["invalid_type"] # Non-integer index + + def test_packed_dataset(self, torch_datasets): + torch_datasets[0] = PackedDataset( + torch_datasets[0], + max_seq_len=25, + max_packs=5, + split_across_pack=True, + ) + + with pytest.raises(ValueError): + concated_dataset = ConcatDataset(torch_datasets) From b3817437ca8cc503bbb07f552dd749d00c08d895 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:18:31 -0400 Subject: [PATCH 5/8] fix incorrect commit + lint _concat.py --- torchtune/datasets/_concat.py | 171 +++++++++++++++++----------------- 1 file changed, 87 insertions(+), 84 deletions(-) diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index 352d20e372..cfe645d815 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -4,89 +4,92 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import pytest -from datasets import Dataset -from torch.utils.data import Dataset as TorchDataset -from torchtune.datasets._concat import ConcatDataset -from torchtune.datasets._packed import PackedDataset +from typing import List, Tuple + +from torch.utils.data import Dataset +from torchtune import utils + +from torchtune.datasets._packed import PackedDataset -class DummyDataset(TorchDataset): - def __init__(self, sample_size): - self.sample_size = sample_size - - def __getitem__(self, index): - if index >= 1000: - raise IndexError() - return { - "tokens": [index] * self.sample_size, - "labels": [index] * self.sample_size, - } - - def __len__(self): - return 1000 - - -class TestConcatDataset: - @pytest.fixture - def datasets(self): - ds1 = Dataset.from_list([{"data": f"ds1_{i}"} for i in range(4)]) - ds2 = Dataset.from_list([{"data": f"ds2_{i}"} for i in range(8)]) - ds3 = Dataset.from_list([{"data": f"ds3_{i}"} for i in range(15)]) - ds4 = Dataset.from_list([{"data": f"ds4_{i}"} for i in range(16)]) - ds5 = Dataset.from_list([{"data": f"ds5_{i}"} for i in range(23)]) - ds6 = Dataset.from_list([{"data": f"ds6_{i}"} for i in range(42)]) - return [ds1, ds2, ds3, ds4, ds5, ds6] - - @pytest.fixture - def torch_datasets(self): - ds1 = DummyDataset(4) - ds2 = DummyDataset(8) - ds3 = DummyDataset(15) - ds4 = DummyDataset(16) - ds5 = DummyDataset(23) - ds6 = DummyDataset(42) - return [ds1, ds2, ds3, ds4, ds5, ds6] - - def test_length(self, datasets): - """Test the correct computation of total length""" - multi_dataset = ConcatDataset(datasets) - - # sum of individual datasets lengths - expected_length = 4 + 8 + 15 + 16 + 23 + 42 # 108 - assert len(multi_dataset) == expected_length - - def test_getitem(self, datasets): - """Test item retrieval across dataset boundaries""" - multi_dataset = ConcatDataset(datasets) - - # Testing 
indices across different datasets - assert multi_dataset[-1] is None # Index out of range - assert multi_dataset[0] == {"data": "ds1_0"} - assert multi_dataset[3] == {"data": "ds1_3"} - assert multi_dataset[4] == {"data": "ds2_0"} - assert multi_dataset[10] == {"data": "ds2_6"} - assert multi_dataset[20] == {"data": "ds3_8"} - assert multi_dataset[35] == {"data": "ds4_8"} - assert multi_dataset[50] == {"data": "ds5_7"} - assert multi_dataset[70] == {"data": "ds6_4"} - assert multi_dataset[90] == {"data": "ds6_24"} - assert multi_dataset[108] is None # Index out of range - - def test_invalid_index_type(self, datasets): - """Test handling of invalid index types""" - multi_dataset = ConcatDataset(datasets) - - with pytest.raises(TypeError): - multi_dataset["invalid_type"] # Non-integer index - - def test_packed_dataset(self, torch_datasets): - torch_datasets[0] = PackedDataset( - torch_datasets[0], - max_seq_len=25, - max_packs=5, - split_across_pack=True, - ) - - with pytest.raises(ValueError): - concated_dataset = ConcatDataset(torch_datasets) +log = utils.get_logger("DEBUG") + + +class ConcatDataset(Dataset): + """ + A dataset class for concatenating multiple sub-datasets into a single dataset. This class enables the + unified handling of different datasets as if they were a single dataset, simplifying tasks such as + training models on multiple sources of data simultaneously. + + The class internally manages the aggregation of different datasets and allows transparent indexing across them. + However, it requires all constituent datasets to be fully loaded into memory, which might not be optimal for + very large datasets. + + Upon initialization, this class computes the cumulative length of all datasets and maintains an internal mapping + of indices to the respective datasets. This approach allows the :class:`~torchtune.datasets.ConcatDataset` + to delegate data retrieval to the appropriate sub-dataset transparently when a particular index is accessed. + + Note: + Using this class with very large datasets can lead to high memory consumption, as it requires all datasets to + be loaded into memory. For large-scale scenarios, consider other strategies that might stream data on demand. + + Args: + datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance of a class + derived from :class:`~torch.utils.data.Dataset`. + Raises: + ValueError: if instanse of `PackedDataset` is in `datasets` + + Examples: + >>> dataset1 = MyCustomDataset(params1) + >>> dataset2 = MyCustomDataset(params2) + >>> concat_dataset = ConcatDataset([dataset1, dataset2]) + >>> print(len(concat_dataset)) # Total length of both datasets + >>> data_point = concat_dataset[1500] # Accesses an element from the appropriate dataset + + This can also be accomplished by passing in a list of datasets to the YAML config:: + + dataset: + - _component_: torchtune.datasets.instruct_dataset + source: vicgalle/alpaca-gpt4 + template: torchtune.data.AlpacaInstructTemplate + split: train + train_on_input: True + - _component_: torchtune.datasets.instruct_dataset + source: samsum + template: torchtune.data.SummarizeTemplate + column_map: {"output": "summary"} + output: summary + split: train + train_on_input: False + + This class primarily focuses on providing a unified interface to access elements from multiple datasets, + enhancing the flexibility in handling diverse data sources for training machine learning models. 
+ """ + + def __init__(self, datasets: List[Dataset]): + self._datasets: List[Dataset] = datasets + + for dataset in self._datasets: + if isinstance(dataset, PackedDataset): + raise ValueError( + "ConcatDataset can't proceed instances of PackedDataset." + ) + + self._len: int = sum(len(dataset) for dataset in datasets) + self._indexes: List[Tuple[int, int, int]] = [] + + # Calculate distribution of indexes in all datasets + cumulative_index = 0 + for idx, dataset in enumerate(datasets): + next_cumulative_index = cumulative_index + len(dataset) + self._indexes.append((cumulative_index, next_cumulative_index, idx)) + cumulative_index = next_cumulative_index + + def __getitem__(self, index: int) -> Tuple[List[int], List[int]]: + for start, stop, dataset_index in self._indexes: + if start <= index < stop: + dataset = self._datasets[dataset_index] + return dataset[index - start] + + def __len__(self) -> int: + return self._len From 0a63bdb994dffa0f780d48eb363e4632fd720c01 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Fri, 11 Oct 2024 09:02:09 +0300 Subject: [PATCH 6/8] Fix typo _concat.py --- torchtune/datasets/_concat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index cfe645d815..581bdf7ae5 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -38,7 +38,6 @@ class ConcatDataset(Dataset): derived from :class:`~torch.utils.data.Dataset`. Raises: ValueError: if instanse of `PackedDataset` is in `datasets` - Examples: >>> dataset1 = MyCustomDataset(params1) >>> dataset2 = MyCustomDataset(params2) @@ -72,7 +71,7 @@ def __init__(self, datasets: List[Dataset]): for dataset in self._datasets: if isinstance(dataset, PackedDataset): raise ValueError( - "ConcatDataset can't proceed instances of PackedDataset." + "ConcatDataset can't process instances of PackedDataset." ) self._len: int = sum(len(dataset) for dataset in datasets) From 3c24a538c5a623bd460aad30f814dfea4085c286 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Fri, 11 Oct 2024 09:04:45 +0300 Subject: [PATCH 7/8] Trying to fix lint in _concat.py --- torchtune/datasets/_concat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index 581bdf7ae5..a139441ed9 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -36,8 +36,10 @@ class ConcatDataset(Dataset): Args: datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance of a class derived from :class:`~torch.utils.data.Dataset`. + Raises: ValueError: if instanse of `PackedDataset` is in `datasets` + Examples: >>> dataset1 = MyCustomDataset(params1) >>> dataset2 = MyCustomDataset(params2) From bfa7d46a0ba3dd55c49cb66fdaed60b663dafe98 Mon Sep 17 00:00:00 2001 From: Mark <75219117+krammnic@users.noreply.github.com> Date: Fri, 11 Oct 2024 05:45:36 -0400 Subject: [PATCH 8/8] fix lint with flake8 _concat.py --- torchtune/datasets/_concat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtune/datasets/_concat.py b/torchtune/datasets/_concat.py index a139441ed9..304a605641 100644 --- a/torchtune/datasets/_concat.py +++ b/torchtune/datasets/_concat.py @@ -36,10 +36,10 @@ class ConcatDataset(Dataset): Args: datasets (List[Dataset]): A list of datasets to concatenate. Each dataset must be an instance of a class derived from :class:`~torch.utils.data.Dataset`. 
- + Raises: ValueError: if instanse of `PackedDataset` is in `datasets` - + Examples: >>> dataset1 = MyCustomDataset(params1) >>> dataset2 = MyCustomDataset(params2)
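For reference, a minimal usage sketch of the behavior this series introduces (illustrative only, not part of any patch above): after these changes, constructing a ConcatDataset from a list that contains a PackedDataset fails fast with a ValueError, as exercised by the new unit test. The DummyDataset below mirrors the test fixture from test_concat_dataset.py; the 100-sample length and the sample sizes are arbitrary stand-ins.

    from torch.utils.data import Dataset as TorchDataset

    from torchtune.datasets._concat import ConcatDataset
    from torchtune.datasets._packed import PackedDataset


    class DummyDataset(TorchDataset):
        # Toy map-style dataset, same shape as the new test fixture: every
        # sample is a dict of fixed-length token/label lists.
        def __init__(self, sample_size, length=100):
            self.sample_size = sample_size
            self.length = length

        def __getitem__(self, index):
            if index >= self.length:
                raise IndexError()
            return {
                "tokens": [index] * self.sample_size,
                "labels": [index] * self.sample_size,
            }

        def __len__(self):
            return self.length


    # Packing one of the members before concatenation is now rejected at
    # construction time instead of failing later during iteration.
    packed = PackedDataset(
        DummyDataset(8),
        max_seq_len=25,
        max_packs=5,
        split_across_pack=True,
    )

    try:
        ConcatDataset([packed, DummyDataset(4)])
    except ValueError as err:
        print(err)  # ConcatDataset can't process instances of PackedDataset.

    # Concatenating the unpacked datasets still works as before.
    combined = ConcatDataset([DummyDataset(8), DummyDataset(4)])
    print(len(combined))  # 200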