Skip to content

Commit 16a121d

Browse files
alex-hh and lhoestq authored
Preserve features in iterable dataset.filter (#7209)
* add is_typed property to example iterables to prevent applying decode_examples multiple times
* Update src/datasets/iterable_dataset.py

Co-authored-by: Quentin Lhoest <[email protected]>

---------

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent afe875a commit 16a121d

File tree

1 file changed

+60
-8
lines changed

1 file changed

+60
-8
lines changed

src/datasets/iterable_dataset.py

Lines changed: 60 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,10 @@ def __iter__(self) -> Iterator[Tuple[Key, dict]]:
130130
def iter_arrow(self) -> Optional[Callable[[], Iterator[Tuple[Key, pa.Table]]]]:
131131
return None
132132

133+
@property
134+
def is_typed(self) -> bool:
135+
return False
136+
133137
def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamplesIterable":
134138
"""
135139
Either shuffle the shards/sources of the dataset, or propagate the shuffling to the underlying iterable.
@@ -393,6 +397,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, batch_size: Optional[int]
393397
def iter_arrow(self):
394398
return self._iter_arrow
395399

400+
@property
401+
def is_typed(self):
402+
return self.ex_iterable.is_typed
403+
396404
def _init_state_dict(self) -> dict:
397405
self._state_dict = {
398406
"ex_iterable": self.ex_iterable._init_state_dict(),
@@ -518,6 +526,10 @@ def iter_arrow(self):
518526
if self.ex_iterable.iter_arrow:
519527
return self._iter_arrow
520528

529+
@property
530+
def is_typed(self):
531+
return self.ex_iterable.is_typed
532+
521533
def _init_state_dict(self) -> dict:
522534
self._state_dict = self.ex_iterable._init_state_dict()
523535
return self._state_dict
@@ -550,6 +562,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int):
550562
self.offset = offset
551563
# TODO(QL): implement iter_arrow
552564

565+
@property
566+
def is_typed(self):
567+
return self.ex_iterable.is_typed
568+
553569
def _init_state_dict(self) -> dict:
554570
self._state_dict = self.ex_iterable._init_state_dict()
555571
return self._state_dict
@@ -593,6 +609,10 @@ def __init__(
593609
self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any
594610
# TODO(QL): implement iter_arrow
595611

612+
@property
613+
def is_typed(self):
614+
return self.ex_iterables[0].is_typed
615+
596616
def _get_indices_iterator(self):
597617
# this is an infinite iterator to keep track of which iterator we want to pick examples from
598618
ex_iterable_idx = self._state_dict["ex_iterable_idx"] if self._state_dict else 0
@@ -687,6 +707,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
687707
super().__init__()
688708
self.ex_iterables = ex_iterables
689709

710+
@property
711+
def is_typed(self):
712+
return self.ex_iterables[0].is_typed
713+
690714
@property
691715
def iter_arrow(self):
692716
if all(ex_iterable.iter_arrow is not None for ex_iterable in self.ex_iterables):
@@ -767,6 +791,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
767791
self.ex_iterables = ex_iterables
768792
# TODO(QL): implement iter_arrow
769793

794+
@property
795+
def is_typed(self):
796+
return self.ex_iterables[0].is_typed
797+
770798
def _init_state_dict(self) -> dict:
771799
self._state_dict = {"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables]}
772800
return self._state_dict
@@ -826,6 +854,10 @@ def __init__(
826854
self.probabilities = probabilities
827855
# TODO(QL): implement iter_arrow
828856

857+
@property
858+
def is_typed(self):
859+
return self.ex_iterables[0].is_typed
860+
829861
def _get_indices_iterator(self):
830862
rng = deepcopy(self.generator)
831863
num_sources = len(self.ex_iterables)
@@ -929,6 +961,10 @@ def iter_arrow(self):
929961
if self.formatting and self.formatting.format_type == "arrow":
930962
return self._iter_arrow
931963

964+
@property
965+
def is_typed(self):
966+
return False
967+
932968
def _init_state_dict(self) -> dict:
933969
self._state_dict = {
934970
"ex_iterable": self.ex_iterable._init_state_dict(),
@@ -1185,6 +1221,10 @@ def iter_arrow(self):
11851221
if self.formatting and self.formatting.format_type == "arrow":
11861222
return self._iter_arrow
11871223

1224+
@property
1225+
def is_typed(self):
1226+
return self.ex_iterable.is_typed
1227+
11881228
def _init_state_dict(self) -> dict:
11891229
self._state_dict = {
11901230
"ex_iterable": self.ex_iterable._init_state_dict(),
@@ -1365,6 +1405,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
13651405
self.generator = generator
13661406
# TODO(QL): implement iter_arrow
13671407

1408+
@property
1409+
def is_typed(self):
1410+
return self.ex_iterable.is_typed
1411+
13681412
def _init_state_dict(self) -> dict:
13691413
self._state_dict = self.ex_iterable._init_state_dict()
13701414
self._original_state_dict = self.state_dict()
@@ -1435,6 +1479,10 @@ def __init__(
14351479
self.split_when_sharding = split_when_sharding
14361480
# TODO(QL): implement iter_arrow
14371481

1482+
@property
1483+
def is_typed(self):
1484+
return self.ex_iterable.is_typed
1485+
14381486
def _init_state_dict(self) -> dict:
14391487
self._state_dict = {"skipped": False, "ex_iterable": self.ex_iterable._init_state_dict()}
14401488
return self._state_dict
@@ -1498,6 +1546,10 @@ def __init__(
14981546
self.split_when_sharding = split_when_sharding
14991547
# TODO(QL): implement iter_arrow
15001548

1549+
@property
1550+
def is_typed(self):
1551+
return self.ex_iterable.is_typed
1552+
15011553
def _init_state_dict(self) -> dict:
15021554
self._state_dict = {"num_taken": 0, "ex_iterable": self.ex_iterable._init_state_dict()}
15031555
return self._state_dict
@@ -1600,6 +1652,10 @@ def iter_arrow(self):
16001652
if self.ex_iterable.iter_arrow is not None:
16011653
return self._iter_arrow
16021654

1655+
@property
1656+
def is_typed(self):
1657+
return True
1658+
16031659
def _init_state_dict(self) -> dict:
16041660
self._state_dict = self.ex_iterable._init_state_dict()
16051661
return self._state_dict
@@ -1914,7 +1970,7 @@ def _iter_pytorch(self):
19141970
return
19151971
else:
19161972
for key, example in ex_iterable:
1917-
if self.features:
1973+
if self.features and not ex_iterable.is_typed:
19181974
# `IterableDataset` automatically fills missing columns with None.
19191975
# This is done with `_apply_feature_types_on_example`.
19201976
example = _apply_feature_types_on_example(
@@ -2010,7 +2066,7 @@ def __iter__(self):
20102066
return
20112067

20122068
for key, example in ex_iterable:
2013-
if self.features:
2069+
if self.features and not ex_iterable.is_typed:
20142070
# `IterableDataset` automatically fills missing columns with None.
20152071
# This is done with `_apply_feature_types_on_example`.
20162072
example = _apply_feature_types_on_example(
@@ -2052,7 +2108,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
20522108
if drop_last_batch and len(examples) < batch_size: # ignore last batch
20532109
return
20542110
batch = _examples_to_batch(examples)
2055-
if self.features:
2111+
if self.features and not ex_iterable.is_typed:
20562112
# `IterableDataset` automatically fills missing columns with None.
20572113
# This is done with `_apply_feature_types_on_batch`.
20582114
batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id)
@@ -2405,10 +2461,6 @@ def filter(
24052461
if isinstance(input_columns, str):
24062462
input_columns = [input_columns]
24072463

2408-
# TODO(QL): keep the features (right now if we keep it it would call decode_example again on an already decoded example)
2409-
info = copy.deepcopy(self._info)
2410-
info.features = None
2411-
24122464
# We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here
24132465
ex_iterable = FilteredExamplesIterable(
24142466
TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id)
@@ -2424,7 +2476,7 @@ def filter(
24242476
)
24252477
return IterableDataset(
24262478
ex_iterable=ex_iterable,
2427-
info=info,
2479+
info=self._info,
24282480
split=self._split,
24292481
formatting=self._formatting,
24302482
shuffling=copy.deepcopy(self._shuffling),

0 commit comments

Comments
 (0)