@@ -130,6 +130,10 @@ def __iter__(self) -> Iterator[Tuple[Key, dict]]:
130130 def iter_arrow (self ) -> Optional [Callable [[], Iterator [Tuple [Key , pa .Table ]]]]:
131131 return None
132132
133+ @property
134+ def is_typed (self ) -> bool :
135+ return False
136+
133137 def shuffle_data_sources (self , generator : np .random .Generator ) -> "_BaseExamplesIterable" :
134138 """
135139 Either shuffle the shards/sources of the dataset, or propagate the shuffling to the underlying iterable.
@@ -393,6 +397,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, batch_size: Optional[int]
393397 def iter_arrow (self ):
394398 return self ._iter_arrow
395399
400+ @property
401+ def is_typed (self ):
402+ return self .ex_iterable .is_typed
403+
396404 def _init_state_dict (self ) -> dict :
397405 self ._state_dict = {
398406 "ex_iterable" : self .ex_iterable ._init_state_dict (),
@@ -518,6 +526,10 @@ def iter_arrow(self):
518526 if self .ex_iterable .iter_arrow :
519527 return self ._iter_arrow
520528
529+ @property
530+ def is_typed (self ):
531+ return self .ex_iterable .is_typed
532+
521533 def _init_state_dict (self ) -> dict :
522534 self ._state_dict = self .ex_iterable ._init_state_dict ()
523535 return self ._state_dict
@@ -550,6 +562,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int):
550562 self .offset = offset
551563 # TODO(QL): implement iter_arrow
552564
565+ @property
566+ def is_typed (self ):
567+ return self .ex_iterable .is_typed
568+
553569 def _init_state_dict (self ) -> dict :
554570 self ._state_dict = self .ex_iterable ._init_state_dict ()
555571 return self ._state_dict
@@ -593,6 +609,10 @@ def __init__(
593609 self .bool_strategy_func = np .all if (stopping_strategy == "all_exhausted" ) else np .any
594610 # TODO(QL): implement iter_arrow
595611
612+ @property
613+ def is_typed (self ):
614+ return self .ex_iterables [0 ].is_typed
615+
596616 def _get_indices_iterator (self ):
597617 # this is an infinite iterator to keep track of which iterator we want to pick examples from
598618 ex_iterable_idx = self ._state_dict ["ex_iterable_idx" ] if self ._state_dict else 0
@@ -687,6 +707,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
687707 super ().__init__ ()
688708 self .ex_iterables = ex_iterables
689709
710+ @property
711+ def is_typed (self ):
712+ return self .ex_iterables [0 ].is_typed
713+
690714 @property
691715 def iter_arrow (self ):
692716 if all (ex_iterable .iter_arrow is not None for ex_iterable in self .ex_iterables ):
@@ -767,6 +791,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
767791 self .ex_iterables = ex_iterables
768792 # TODO(QL): implement iter_arrow
769793
794+ @property
795+ def is_typed (self ):
796+ return self .ex_iterables [0 ].is_typed
797+
770798 def _init_state_dict (self ) -> dict :
771799 self ._state_dict = {"ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ]}
772800 return self ._state_dict
@@ -826,6 +854,10 @@ def __init__(
826854 self .probabilities = probabilities
827855 # TODO(QL): implement iter_arrow
828856
857+ @property
858+ def is_typed (self ):
859+ return self .ex_iterables [0 ].is_typed
860+
829861 def _get_indices_iterator (self ):
830862 rng = deepcopy (self .generator )
831863 num_sources = len (self .ex_iterables )
@@ -929,6 +961,10 @@ def iter_arrow(self):
929961 if self .formatting and self .formatting .format_type == "arrow" :
930962 return self ._iter_arrow
931963
964+ @property
965+ def is_typed (self ):
966+ return False
967+
932968 def _init_state_dict (self ) -> dict :
933969 self ._state_dict = {
934970 "ex_iterable" : self .ex_iterable ._init_state_dict (),
@@ -1185,6 +1221,10 @@ def iter_arrow(self):
11851221 if self .formatting and self .formatting .format_type == "arrow" :
11861222 return self ._iter_arrow
11871223
1224+ @property
1225+ def is_typed (self ):
1226+ return self .ex_iterable .is_typed
1227+
11881228 def _init_state_dict (self ) -> dict :
11891229 self ._state_dict = {
11901230 "ex_iterable" : self .ex_iterable ._init_state_dict (),
@@ -1365,6 +1405,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
13651405 self .generator = generator
13661406 # TODO(QL): implement iter_arrow
13671407
1408+ @property
1409+ def is_typed (self ):
1410+ return self .ex_iterable .is_typed
1411+
13681412 def _init_state_dict (self ) -> dict :
13691413 self ._state_dict = self .ex_iterable ._init_state_dict ()
13701414 self ._original_state_dict = self .state_dict ()
@@ -1435,6 +1479,10 @@ def __init__(
14351479 self .split_when_sharding = split_when_sharding
14361480 # TODO(QL): implement iter_arrow
14371481
1482+ @property
1483+ def is_typed (self ):
1484+ return self .ex_iterable .is_typed
1485+
14381486 def _init_state_dict (self ) -> dict :
14391487 self ._state_dict = {"skipped" : False , "ex_iterable" : self .ex_iterable ._init_state_dict ()}
14401488 return self ._state_dict
@@ -1498,6 +1546,10 @@ def __init__(
14981546 self .split_when_sharding = split_when_sharding
14991547 # TODO(QL): implement iter_arrow
15001548
1549+ @property
1550+ def is_typed (self ):
1551+ return self .ex_iterable .is_typed
1552+
15011553 def _init_state_dict (self ) -> dict :
15021554 self ._state_dict = {"num_taken" : 0 , "ex_iterable" : self .ex_iterable ._init_state_dict ()}
15031555 return self ._state_dict
@@ -1600,6 +1652,10 @@ def iter_arrow(self):
16001652 if self .ex_iterable .iter_arrow is not None :
16011653 return self ._iter_arrow
16021654
1655+ @property
1656+ def is_typed (self ):
1657+ return True
1658+
16031659 def _init_state_dict (self ) -> dict :
16041660 self ._state_dict = self .ex_iterable ._init_state_dict ()
16051661 return self ._state_dict
@@ -1914,7 +1970,7 @@ def _iter_pytorch(self):
19141970 return
19151971 else :
19161972 for key , example in ex_iterable :
1917- if self .features :
1973+ if self .features and not ex_iterable . is_typed :
19181974 # `IterableDataset` automatically fills missing columns with None.
19191975 # This is done with `_apply_feature_types_on_example`.
19201976 example = _apply_feature_types_on_example (
@@ -2010,7 +2066,7 @@ def __iter__(self):
20102066 return
20112067
20122068 for key , example in ex_iterable :
2013- if self .features :
2069+ if self .features and not ex_iterable . is_typed :
20142070 # `IterableDataset` automatically fills missing columns with None.
20152071 # This is done with `_apply_feature_types_on_example`.
20162072 example = _apply_feature_types_on_example (
@@ -2052,7 +2108,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
20522108 if drop_last_batch and len (examples ) < batch_size : # ignore last batch
20532109 return
20542110 batch = _examples_to_batch (examples )
2055- if self .features :
2111+ if self .features and not ex_iterable . is_typed :
20562112 # `IterableDataset` automatically fills missing columns with None.
20572113 # This is done with `_apply_feature_types_on_batch`.
20582114 batch = _apply_feature_types_on_batch (batch , self .features , token_per_repo_id = self ._token_per_repo_id )
@@ -2405,10 +2461,6 @@ def filter(
24052461 if isinstance (input_columns , str ):
24062462 input_columns = [input_columns ]
24072463
2408- # TODO(QL): keep the features (right now if we keep it it would call decode_example again on an already decoded example)
2409- info = copy .deepcopy (self ._info )
2410- info .features = None
2411-
24122464 # We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here
24132465 ex_iterable = FilteredExamplesIterable (
24142466 TypedExamplesIterable (self ._ex_iterable , self ._info .features , token_per_repo_id = self ._token_per_repo_id )
@@ -2424,7 +2476,7 @@ def filter(
24242476 )
24252477 return IterableDataset (
24262478 ex_iterable = ex_iterable ,
2427- info = info ,
2479+ info = self . _info ,
24282480 split = self ._split ,
24292481 formatting = self ._formatting ,
24302482 shuffling = copy .deepcopy (self ._shuffling ),
0 commit comments