From 2444d31422b4c4c32da9f743412ccdf096cf2ea4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 11 Aug 2025 13:46:35 -0700 Subject: [PATCH] Update arrow_dataset.py --- src/datasets/arrow_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 70bd671930a..3e66605eed3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4647,13 +4647,13 @@ def train_test_split( This method is similar to scikit-learn `train_test_split`. Args: - test_size (`numpy.random.Generator`, *optional*): + test_size (`Union[float, int, None]`, *optional*): Size of the test split If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the test split. If `int`, represents the absolute number of test samples. If `None`, the value is set to the complement of the train size. If `train_size` is also `None`, it will be set to `0.25`. - train_size (`numpy.random.Generator`, *optional*): + train_size (`Union[float, int, None]`, *optional*): Size of the train split If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the train split. If `int`, represents the absolute number of train samples.