diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 43301d23041..dbdff64953b 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4868,7 +4868,7 @@ def train_test_split( try: train_indices, test_indices = next( stratified_shuffle_split_generate_indices( - self.with_format("numpy")[stratify_by_column], n_train, n_test, rng=generator + np.asarray(self.with_format("numpy")[stratify_by_column]), n_train, n_test, rng=generator ) ) except Exception as error: