Skip to content

Commit 08a7b38

Browse files
More rigorous shape inference in to_tf_dataset (#4763)
* More rigorous shape inference in to_tf_dataset
* Simplify the new shape inference
* Read length from Sequence features instead of just sampling batches
* make style
* Remove Sequence-specific code
1 parent c44e4f4 commit 08a7b38

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/datasets/arrow_dataset.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,7 @@ def _get_output_signature(
224224
collate_fn_args: dict,
225225
cols_to_retain: Optional[List[str]] = None,
226226
batch_size: Optional[int] = None,
227-
num_test_batches: int = 10,
227+
num_test_batches: int = 200,
228228
):
229229
"""Private method used by `to_tf_dataset()` to find the shapes and dtypes of samples from this dataset
230230
after being passed through the collate_fn. Tensorflow needs an exact signature for tf.numpy_function, so
@@ -253,11 +253,9 @@ def _get_output_signature(
253253

254254
if len(dataset) == 0:
255255
raise ValueError("Unable to get the output signature because the dataset is empty.")
256-
if batch_size is None:
257-
test_batch_size = min(len(dataset), 8)
258-
else:
256+
if batch_size is not None:
259257
batch_size = min(len(dataset), batch_size)
260-
test_batch_size = batch_size
258+
test_batch_size = min(len(dataset), 2)
261259

262260
test_batches = []
263261
for _ in range(num_test_batches):

0 commit comments

Comments
 (0)