Skip to content

Commit 65fafbe

Browse files
committed
Add explanatory comments and type annotations; fix `token_per_repo_id` not being propagated through `TypedExamplesIterable` / `_apply_feature_types`
1 parent 453089f commit 65fafbe

File tree

4 files changed

+80
-37
lines changed

4 files changed

+80
-37
lines changed

src/datasets/features/audio.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def encode_example(self, value: Union[str, dict]) -> dict:
101101
f"An audio sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
102102
)
103103

104-
def decode_example(self, value: dict, token_per_repo_id=None) -> dict:
104+
def decode_example(
105+
self, value: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None
106+
) -> dict:
105107
"""Decode example audio file into audio data.
106108
107109
Args:
@@ -211,7 +213,9 @@ def path_to_bytes(path):
211213
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null())
212214
return array_cast(storage, self.pa_type)
213215

214-
def _decode_non_mp3_path_like(self, path, format=None, token_per_repo_id=None):
216+
def _decode_non_mp3_path_like(
217+
self, path, format=None, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None
218+
):
215219
try:
216220
import librosa
217221
except ImportError as err:

src/datasets/features/features.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,7 @@ def encode_nested_example(schema, obj, level=0):
12221222
return obj
12231223

12241224

1225-
def decode_nested_example(schema, obj, token_per_repo_id=None):
1225+
def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None):
12261226
"""Decode a nested example.
12271227
This is used since some features (in particular Audio and Image) have some logic during decoding.
12281228
@@ -1613,7 +1613,7 @@ def encode_batch(self, batch):
16131613
encoded_batch[key] = [encode_nested_example(self[key], obj) for obj in column]
16141614
return encoded_batch
16151615

1616-
def decode_example(self, example: dict, token_per_repo_id=None):
1616+
def decode_example(self, example: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None):
16171617
"""Decode example with custom feature decoding.
16181618
16191619
Args:

src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import fsspec
22
import torch
33

4-
from ...iterable_dataset import IterableDataset
4+
from ...iterable_dataset import IterableDataset, _apply_feature_types
55
from ...utils.logging import get_logger
66

77

@@ -46,7 +46,12 @@ def __iter__(self):
4646
)
4747
for shard_idx in shards_indices:
4848
for key, example in self._iter_shard(shard_idx):
49-
yield self._apply_feature_types(example)
49+
if self.features:
50+
yield _apply_feature_types(
51+
example, self.features, token_per_repo_id=self._token_per_repo_id
52+
)
53+
else:
54+
yield example
5055
logger.debug(
5156
f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards."
5257
)

src/datasets/iterable_dataset.py

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,17 @@ def shard_data_sources(self, shard_idx: int) -> "CyclingMultiSourcesExamplesIter
173173

174174

175175
class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable):
176+
"""
177+
VerticallyConcatenatedMultiSourcesExamplesIterable simply chains the input iterables.
178+
It doesn't require the examples iterables to always yield the same columns.
179+
Instead, this is handled by the `IterableDataset` class or `TypedExamplesIterable`.
180+
181+
For information, `IterableDataset` merges the features of all the datasets to concatenate into one.
182+
We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate.
183+
184+
Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None.
185+
This is done with `_apply_feature_types`.
186+
"""
176187
def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
177188
self.ex_iterables = ex_iterables
178189

@@ -210,6 +221,20 @@ def _check_column_names(column_names: List[str]):
210221

211222

212223
class HorizontallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable):
224+
"""
225+
HorizontallyConcatenatedMultiSourcesExamplesIterable merges examples together for the input list of iterables.
226+
It also checks that there are no duplicate columns (otherwise we don't know which one to keep).
227+
This check is done once when yielding the first example.
228+
229+
However it doesn't fill missing columns with None.
230+
Instead, this is handled by the `IterableDataset` class or `TypedExamplesIterable`.
231+
232+
For information, `IterableDataset` merges the features of all the datasets to concatenate into one.
233+
We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate.
234+
235+
Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None.
236+
This is done with `_apply_feature_types`.
237+
"""
213238
def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
214239
self.ex_iterables = ex_iterables
215240

@@ -565,36 +590,52 @@ def n_shards(self) -> int:
565590
return self.ex_iterable.n_shards
566591

567592

593+
def _apply_feature_types(
594+
example: dict, features: Features, token_per_repo_id: Dict[str, Union[str, bool, None]]
595+
) -> dict:
596+
example = dict(example)
597+
# add missing columns
598+
for column_name in features:
599+
if column_name not in example:
600+
example[column_name] = None
601+
# we encode the example for ClassLabel feature types for example
602+
encoded_example = features.encode_example(example)
603+
# Decode example for Audio feature, e.g.
604+
decoded_example = features.decode_example(encoded_example, token_per_repo_id=token_per_repo_id)
605+
return decoded_example
606+
607+
568608
class TypedExamplesIterable(_BaseExamplesIterable):
569-
def __init__(self, ex_iterable: _BaseExamplesIterable, features: Features):
609+
def __init__(
610+
self,
611+
ex_iterable: _BaseExamplesIterable,
612+
features: Features,
613+
token_per_repo_id: Dict[str, Union[str, bool, None]],
614+
):
570615
self.ex_iterable = ex_iterable
571616
self.features = features
617+
self.token_per_repo_id = token_per_repo_id
572618

573619
def __iter__(self):
620+
# Then for each example, `TypedExamplesIterable` automatically fills missing columns with None.
621+
# This is done with `_apply_feature_types`.
574622
for key, example in self.ex_iterable:
575-
example = dict(example)
576-
# add missing columns
577-
for column_name in self.features:
578-
if column_name not in example:
579-
example[column_name] = None
580-
# we encode the example for ClassLabel feature types for example
581-
encoded_example = self.features.encode_example(example)
582-
# Decode example for Audio feature, e.g.
583-
decoded_example = self.features.decode_example(encoded_example)
584-
yield key, decoded_example
623+
yield key, _apply_feature_types(example, self.features, token_per_repo_id=self.token_per_repo_id)
585624

586625
def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamplesIterable":
587626
"""Shuffle the wrapped examples iterable."""
588627
return TypedExamplesIterable(
589628
self.ex_iterable.shuffle_data_sources(generator),
590629
features=self.features,
630+
token_per_repo_id=self.token_per_repo_id,
591631
)
592632

593633
def shard_data_sources(self, shard_idx: int) -> "TypedExamplesIterable":
594634
"""Keep only the requested shard."""
595635
return TypedExamplesIterable(
596636
self.ex_iterable.shard_data_sources(shard_idx),
597637
features=self.features,
638+
token_per_repo_id=self.token_per_repo_id,
598639
)
599640

600641
@property
@@ -637,7 +678,7 @@ def __init__(
637678
self._format_type = format_type
638679
self._shuffling = shuffling
639680
self._epoch = 0
640-
self._token_per_repo_id = token_per_repo_id or {}
681+
self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
641682

642683
def _head(self, n=5):
643684
return _examples_to_batch([x for key, x in islice(self._iter(), n)])
@@ -671,24 +712,14 @@ def _iter_shard(self, shard_idx: int):
671712
ex_iterable = self._ex_iterable
672713
yield from ex_iterable.shard_data_sources(shard_idx)
673714

674-
def _apply_feature_types(self, example):
675-
if self.features:
676-
example = dict(example)
677-
# add missing columns
678-
for column_name in self.features:
679-
if column_name not in example:
680-
example[column_name] = None
681-
# we encode the example for ClassLabel feature types for example
682-
encoded_example = self.features.encode_example(example)
683-
# Decode example for Audio feature, e.g.
684-
decoded_example = self.features.decode_example(encoded_example, token_per_repo_id=self._token_per_repo_id)
685-
return decoded_example
686-
else:
687-
return example
688-
689715
def __iter__(self):
690716
for key, example in self._iter():
691-
yield self._apply_feature_types(example)
717+
if self.features:
718+
# `IterableDataset` automatically fills missing columns with None.
719+
# This is done with `_apply_feature_types`.
720+
yield _apply_feature_types(example, self.features, token_per_repo_id=self._token_per_repo_id)
721+
else:
722+
yield example
692723

693724
def with_format(
694725
self,
@@ -790,7 +821,7 @@ def map(
790821
info = self._info.copy()
791822
info.features = None
792823
ex_iterable = MappedExamplesIterable(
793-
TypedExamplesIterable(self._ex_iterable, self._info.features)
824+
TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id)
794825
if self._info.features is not None
795826
else self._ex_iterable,
796827
function=function,
@@ -859,7 +890,7 @@ def filter(
859890

860891
# We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here
861892
ex_iterable = FilteredExamplesIterable(
862-
TypedExamplesIterable(self._ex_iterable, self._info.features)
893+
TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id)
863894
if self._info.features is not None
864895
else self._ex_iterable,
865896
function=function,
@@ -1325,6 +1356,7 @@ def _concatenate_iterable_datasets(
13251356
else:
13261357
ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable(ex_iterables)
13271358
# Set new info - we update the features
1359+
# setting the features also ensures to fill missing columns with None
13281360
if info is None:
13291361
info = DatasetInfo.from_merge([d.info for d in dsets])
13301362
else:
@@ -1358,8 +1390,9 @@ def _interleave_iterable_datasets(
13581390
Output:
13591391
:class:`datasets.IterableDataset`
13601392
"""
1393+
# TODO(QL): merge the features as in _concatenate_iterable_datasets() and don't use TypedExamplesIterable
13611394
ex_iterables = [
1362-
TypedExamplesIterable(d._ex_iterable, d.features)
1395+
TypedExamplesIterable(d._ex_iterable, d.features, token_per_repo_id=d._token_per_repo_id)
13631396
if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None
13641397
else d._ex_iterable
13651398
for d in datasets
@@ -1373,6 +1406,7 @@ def _interleave_iterable_datasets(
13731406
ex_iterables, generator=generator, probabilities=probabilities
13741407
)
13751408
# Set new info - we reset the features
1409+
# TODO(QL): merge the features as in _concatenate_iterable_datasets() and use them here
13761410
if info is None:
13771411
info = DatasetInfo.from_merge([d.info for d in datasets])
13781412
info.features = None

0 commit comments

Comments (0)