@@ -176,16 +176,8 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
176176 self .ex_iterables = ex_iterables
177177
178178 def __iter__ (self ):
179- _example_from_previous_iterable = None
180179 for ex_iterable in self .ex_iterables :
181- for example_idx , (key , example ) in enumerate (ex_iterable ):
182- if example_idx == 0 and _example_from_previous_iterable is not None :
183- if sorted (example ) != sorted (_example_from_previous_iterable ):
184- raise ValueError (
185- f"The examples iterables must have the same columns but one has { sorted (_example_from_previous_iterable )} and the next has { sorted (example )} ."
186- )
187- yield key , example
188- _example_from_previous_iterable = example
180+ yield from ex_iterable
189181
190182 def shuffle_data_sources (
191183 self , generator : np .random .Generator
@@ -578,6 +570,11 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, features: Features):
578570
579571 def __iter__ (self ):
580572 for key , example in self .ex_iterable :
573+ example = dict (example )
574+ # add missing columns
575+ for column_name in self .features :
576+ if column_name not in example :
577+ example [column_name ] = None
581578 # we encode the example for ClassLabel feature types for example
582579 encoded_example = self .features .encode_example (example )
583580 # Decode example for Audio feature, e.g.
@@ -674,6 +671,11 @@ def _iter_shard(self, shard_idx: int):
674671
675672 def _apply_feature_types (self , example ):
676673 if self .features :
674+ example = dict (example )
675+ # add missing columns
676+ for column_name in self .features :
677+ if column_name not in example :
678+ example [column_name ] = None
677679 # we encode the example for ClassLabel feature types for example
678680 encoded_example = self .features .encode_example (example )
679681 # Decode example for Audio feature, e.g.
@@ -1214,6 +1216,25 @@ def cast(
12141216 shuffling = copy .deepcopy (self ._shuffling ),
12151217 token_per_repo_id = self ._token_per_repo_id ,
12161218 )
1219+
1220+ def _resolve_features (self ):
1221+ if self .features is not None :
1222+ return self
1223+ elif isinstance (self ._ex_iterable , TypedExamplesIterable ):
1224+ features = self ._ex_iterable .features
1225+ else :
1226+ features = _infer_features_from_batch (self ._head ())
1227+ info = self .info .copy ()
1228+ info .features = features
1229+ return iterable_dataset (
1230+ ex_iterable = self ._ex_iterable ,
1231+ info = info ,
1232+ split = self ._split ,
1233+ format_type = self ._format_type ,
1234+ shuffling = copy .deepcopy (self ._shuffling ),
1235+ token_per_repo_id = self ._token_per_repo_id ,
1236+ )
1237+
12171238
12181239
12191240def iterable_dataset (
@@ -1265,20 +1286,21 @@ def _concatenate_iterable_datasets(
12651286 >>> ds3 = _concatenate_iterable_datasets([ds1, ds2])
12661287 ```
12671288 """
1268- ex_iterables = [
1269- TypedExamplesIterable (d ._ex_iterable , d .features )
1270- if not isinstance (d ._ex_iterable , TypedExamplesIterable ) and d .features is not None
1271- else d ._ex_iterable
1272- for d in dsets
1273- ]
1289+ dsets = [d ._resolve_features () for d in dsets ]
1290+ features = Features ()
1291+ for dset in dsets :
1292+ features .update (dset .features )
1293+ ex_iterables = [d ._ex_iterable for d in dsets ]
12741294 if axis == 0 :
12751295 ex_iterable = VerticallyConcatenatedMultiSourcesExamplesIterable (ex_iterables )
12761296 else :
12771297 ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable (ex_iterables )
1278- # Set new info - we reset the features
1298+ # Set new info - we update the features
12791299 if info is None :
12801300 info = DatasetInfo .from_merge ([d .info for d in dsets ])
1281- info .features = None
1301+ else :
1302+ info = info .copy ()
1303+ info .features = features
12821304 # Get all the auth tokens per repository - in case the datasets come from different private repositories
12831305 token_per_repo_id = {repo_id : token for dataset in dsets for repo_id , token in dataset ._token_per_repo_id .items ()}
12841306 # Return new dataset