@@ -173,6 +173,17 @@ def shard_data_sources(self, shard_idx: int) -> "CyclingMultiSourcesExamplesIter
173173
174174
175175class VerticallyConcatenatedMultiSourcesExamplesIterable (_BaseExamplesIterable ):
176+ """
177+ VerticallyConcatenatedMultiSourcesExamplesIterable simply chains the input iterables.
178+ It doesn't require the examples iterables to always yield the same columns.
179+ Instead, this is handled by the `IterableDataset` class or `TypedExamplesIterable`.
180+
181+ For information, `IterableDataset` merges the features of all the datasets to concatenate into one.
182+ We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate.
183+
184+ Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None.
185+ This is done with `_apply_feature_types`.
186+ """
176187 def __init__ (self , ex_iterables : List [_BaseExamplesIterable ]):
177188 self .ex_iterables = ex_iterables
178189
@@ -210,6 +221,20 @@ def _check_column_names(column_names: List[str]):
210221
211222
212223class HorizontallyConcatenatedMultiSourcesExamplesIterable (_BaseExamplesIterable ):
224+ """
225+ HorizontallyConcatenatedMultiSourcesExamplesIterable merges examples together for the input list of iterables.
226+ It also checks that there are no duplicate columns (otherwise we don't know which one to keep).
227+ This check is done once when yielding the first example.
228+
229+ However it doesn't fill missing columns with None.
230+ Instead, this is handled by the `IterableDataset` class or `TypedExamplesIterable`.
231+
232+ For information, `IterableDataset` merges the features of all the datasets to concatenate into one.
233+ We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate.
234+
235+ Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None.
236+ This is done with `_apply_feature_types`.
237+ """
213238 def __init__ (self , ex_iterables : List [_BaseExamplesIterable ]):
214239 self .ex_iterables = ex_iterables
215240
@@ -565,36 +590,52 @@ def n_shards(self) -> int:
565590 return self .ex_iterable .n_shards
566591
567592
593+ def _apply_feature_types (
594+ example : dict , features : Features , token_per_repo_id : Dict [str , Union [str , bool , None ]]
595+ ) -> dict :
596+ example = dict (example )
597+ # add missing columns
598+ for column_name in features :
599+ if column_name not in example :
600+ example [column_name ] = None
601+ # we encode the example for ClassLabel feature types for example
602+ encoded_example = features .encode_example (example )
603+ # Decode example for Audio feature, e.g.
604+ decoded_example = features .decode_example (encoded_example , token_per_repo_id = token_per_repo_id )
605+ return decoded_example
606+
607+
568608class TypedExamplesIterable (_BaseExamplesIterable ):
569- def __init__ (self , ex_iterable : _BaseExamplesIterable , features : Features ):
609+ def __init__ (
610+ self ,
611+ ex_iterable : _BaseExamplesIterable ,
612+ features : Features ,
613+ token_per_repo_id : Dict [str , Union [str , bool , None ]],
614+ ):
570615 self .ex_iterable = ex_iterable
571616 self .features = features
617+ self .token_per_repo_id = token_per_repo_id
572618
573619 def __iter__ (self ):
620+ # Then for each example, `TypedExamplesIterable` automatically fills missing columns with None.
621+ # This is done with `_apply_feature_types`.
574622 for key , example in self .ex_iterable :
575- example = dict (example )
576- # add missing columns
577- for column_name in self .features :
578- if column_name not in example :
579- example [column_name ] = None
580- # we encode the example for ClassLabel feature types for example
581- encoded_example = self .features .encode_example (example )
582- # Decode example for Audio feature, e.g.
583- decoded_example = self .features .decode_example (encoded_example )
584- yield key , decoded_example
623+ yield key , _apply_feature_types (example , self .features , token_per_repo_id = self .token_per_repo_id )
585624
586625 def shuffle_data_sources (self , generator : np .random .Generator ) -> "TypedExamplesIterable" :
587626 """Shuffle the wrapped examples iterable."""
588627 return TypedExamplesIterable (
589628 self .ex_iterable .shuffle_data_sources (generator ),
590629 features = self .features ,
630+ token_per_repo_id = self .token_per_repo_id ,
591631 )
592632
593633 def shard_data_sources (self , shard_idx : int ) -> "TypedExamplesIterable" :
594634 """Keep only the requested shard."""
595635 return TypedExamplesIterable (
596636 self .ex_iterable .shard_data_sources (shard_idx ),
597637 features = self .features ,
638+ token_per_repo_id = self .token_per_repo_id ,
598639 )
599640
600641 @property
@@ -637,7 +678,7 @@ def __init__(
637678 self ._format_type = format_type
638679 self ._shuffling = shuffling
639680 self ._epoch = 0
640- self ._token_per_repo_id = token_per_repo_id or {}
681+ self ._token_per_repo_id : Dict [ str , Union [ str , bool , None ]] = token_per_repo_id or {}
641682
642683 def _head (self , n = 5 ):
643684 return _examples_to_batch ([x for key , x in islice (self ._iter (), n )])
@@ -671,24 +712,14 @@ def _iter_shard(self, shard_idx: int):
671712 ex_iterable = self ._ex_iterable
672713 yield from ex_iterable .shard_data_sources (shard_idx )
673714
674- def _apply_feature_types (self , example ):
675- if self .features :
676- example = dict (example )
677- # add missing columns
678- for column_name in self .features :
679- if column_name not in example :
680- example [column_name ] = None
681- # we encode the example for ClassLabel feature types for example
682- encoded_example = self .features .encode_example (example )
683- # Decode example for Audio feature, e.g.
684- decoded_example = self .features .decode_example (encoded_example , token_per_repo_id = self ._token_per_repo_id )
685- return decoded_example
686- else :
687- return example
688-
689715 def __iter__ (self ):
690716 for key , example in self ._iter ():
691- yield self ._apply_feature_types (example )
717+ if self .features :
718+ # `IterableDataset` automatically fills missing columns with None.
719+ # This is done with `_apply_feature_types`.
720+ yield _apply_feature_types (example , self .features , token_per_repo_id = self ._token_per_repo_id )
721+ else :
722+ yield example
692723
693724 def with_format (
694725 self ,
@@ -790,7 +821,7 @@ def map(
790821 info = self ._info .copy ()
791822 info .features = None
792823 ex_iterable = MappedExamplesIterable (
793- TypedExamplesIterable (self ._ex_iterable , self ._info .features )
824+ TypedExamplesIterable (self ._ex_iterable , self ._info .features , token_per_repo_id = self . _token_per_repo_id )
794825 if self ._info .features is not None
795826 else self ._ex_iterable ,
796827 function = function ,
@@ -859,7 +890,7 @@ def filter(
859890
860891 # We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here
861892 ex_iterable = FilteredExamplesIterable (
862- TypedExamplesIterable (self ._ex_iterable , self ._info .features )
893+ TypedExamplesIterable (self ._ex_iterable , self ._info .features , token_per_repo_id = self . _token_per_repo_id )
863894 if self ._info .features is not None
864895 else self ._ex_iterable ,
865896 function = function ,
@@ -1325,6 +1356,7 @@ def _concatenate_iterable_datasets(
13251356 else :
13261357 ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable (ex_iterables )
13271358 # Set new info - we update the features
1359+ # setting the features also ensures to fill missing columns with None
13281360 if info is None :
13291361 info = DatasetInfo .from_merge ([d .info for d in dsets ])
13301362 else :
@@ -1358,8 +1390,9 @@ def _interleave_iterable_datasets(
13581390 Output:
13591391 :class:`datasets.IterableDataset`
13601392 """
1393+ # TODO(QL): merge the features as in _concatenate_iterable_datasets() and don't use TypedExamplesIterable
13611394 ex_iterables = [
1362- TypedExamplesIterable (d ._ex_iterable , d .features )
1395+ TypedExamplesIterable (d ._ex_iterable , d .features , token_per_repo_id = d . _token_per_repo_id )
13631396 if not isinstance (d ._ex_iterable , TypedExamplesIterable ) and d .features is not None
13641397 else d ._ex_iterable
13651398 for d in datasets
@@ -1373,6 +1406,7 @@ def _interleave_iterable_datasets(
13731406 ex_iterables , generator = generator , probabilities = probabilities
13741407 )
13751408 # Set new info - we reset the features
1409+ # TODO(QL): merge the features as in _concatenate_iterable_datasets() and use them here
13761410 if info is None :
13771411 info = DatasetInfo .from_merge ([d .info for d in datasets ])
13781412 info .features = None
0 commit comments