Skip to content

Commit 15c286e

Browse files
committed
infer features
1 parent 590354e commit 15c286e

File tree

1 file changed

+39
-17
lines changed

1 file changed

+39
-17
lines changed

src/datasets/iterable_dataset.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,8 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
176176
self.ex_iterables = ex_iterables
177177

178178
def __iter__(self):
179-
_example_from_previous_iterable = None
180179
for ex_iterable in self.ex_iterables:
181-
for example_idx, (key, example) in enumerate(ex_iterable):
182-
if example_idx == 0 and _example_from_previous_iterable is not None:
183-
if sorted(example) != sorted(_example_from_previous_iterable):
184-
raise ValueError(
185-
f"The examples iterables must have the same columns but one has {sorted(_example_from_previous_iterable)} and the next has {sorted(example)}."
186-
)
187-
yield key, example
188-
_example_from_previous_iterable = example
180+
yield from ex_iterable
189181

190182
def shuffle_data_sources(
191183
self, generator: np.random.Generator
@@ -578,6 +570,11 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, features: Features):
578570

579571
def __iter__(self):
580572
for key, example in self.ex_iterable:
573+
example = dict(example)
574+
# add missing columns
575+
for column_name in self.features:
576+
if column_name not in example:
577+
example[column_name] = None
581578
# we encode the example for ClassLabel feature types for example
582579
encoded_example = self.features.encode_example(example)
583580
# Decode example for Audio feature, e.g.
@@ -674,6 +671,11 @@ def _iter_shard(self, shard_idx: int):
674671

675672
def _apply_feature_types(self, example):
676673
if self.features:
674+
example = dict(example)
675+
# add missing columns
676+
for column_name in self.features:
677+
if column_name not in example:
678+
example[column_name] = None
677679
# we encode the example for ClassLabel feature types for example
678680
encoded_example = self.features.encode_example(example)
679681
# Decode example for Audio feature, e.g.
@@ -1214,6 +1216,25 @@ def cast(
12141216
shuffling=copy.deepcopy(self._shuffling),
12151217
token_per_repo_id=self._token_per_repo_id,
12161218
)
1219+
1220+
def _resolve_features(self):
1221+
if self.features is not None:
1222+
return self
1223+
elif isinstance(self._ex_iterable, TypedExamplesIterable):
1224+
features = self._ex_iterable.features
1225+
else:
1226+
features = _infer_features_from_batch(self._head())
1227+
info = self.info.copy()
1228+
info.features = features
1229+
return iterable_dataset(
1230+
ex_iterable=self._ex_iterable,
1231+
info=info,
1232+
split=self._split,
1233+
format_type=self._format_type,
1234+
shuffling=copy.deepcopy(self._shuffling),
1235+
token_per_repo_id=self._token_per_repo_id,
1236+
)
1237+
12171238

12181239

12191240
def iterable_dataset(
@@ -1265,20 +1286,21 @@ def _concatenate_iterable_datasets(
12651286
>>> ds3 = _concatenate_iterable_datasets([ds1, ds2])
12661287
```
12671288
"""
1268-
ex_iterables = [
1269-
TypedExamplesIterable(d._ex_iterable, d.features)
1270-
if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None
1271-
else d._ex_iterable
1272-
for d in dsets
1273-
]
1289+
dsets = [d._resolve_features() for d in dsets]
1290+
features = Features()
1291+
for dset in dsets:
1292+
features.update(dset.features)
1293+
ex_iterables = [d._ex_iterable for d in dsets]
12741294
if axis == 0:
12751295
ex_iterable = VerticallyConcatenatedMultiSourcesExamplesIterable(ex_iterables)
12761296
else:
12771297
ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable(ex_iterables)
1278-
# Set new info - we reset the features
1298+
# Set new info - we update the features
12791299
if info is None:
12801300
info = DatasetInfo.from_merge([d.info for d in dsets])
1281-
info.features = None
1301+
else:
1302+
info = info.copy()
1303+
info.features = features
12821304
# Get all the auth tokens per repository - in case the datasets come from different private repositories
12831305
token_per_repo_id = {repo_id: token for dataset in dsets for repo_id, token in dataset._token_per_repo_id.items()}
12841306
# Return new daset

0 commit comments

Comments (0)