Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/cuml/cuml/model_selection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@
with CUDA-based data and cuML estimators, but all of the underlying code
is due to the scikit-learn developers."""

from cuml.model_selection._split import StratifiedKFold, train_test_split
from cuml.model_selection._split import (
KFold,
StratifiedKFold,
train_test_split,
)

__all__ = ["train_test_split", "GridSearchCV", "StratifiedKFold"]
__all__ = ["train_test_split", "KFold", "GridSearchCV", "StratifiedKFold"]


def __getattr__(name):
Expand Down
122 changes: 108 additions & 14 deletions python/cuml/cuml/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import cudf
Expand All @@ -27,6 +28,7 @@
determine_df_obj_type,
)
from cuml.internals.output_utils import output_to_df_obj_like
from cuml.internals.utils import check_random_seed


def _compute_stratify_split_indices(
Expand Down Expand Up @@ -419,7 +421,104 @@ def _process_df_objs(
return X_train, X_test


class StratifiedKFold:
class _KFoldBase(ABC):
    """Base class for k-fold cross-validation splitters.

    Validates and stores the configuration shared by the k-fold splitters.
    Subclasses implement :meth:`split`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Stored as ``self.shuffle`` for subclasses to consult when splitting.

    random_state : optional, default=None
        Stored as ``self.seed`` for subclasses to use when shuffling.

    Raises
    ------
    ValueError
        If ``n_splits`` is not an ``int`` or is smaller than 2.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        # Check the type first: comparing a non-numeric value (e.g. a str or
        # None) with ``<`` would raise TypeError before the intended
        # ValueError could be reported.
        if not isinstance(n_splits, int) or n_splits < 2:
            raise ValueError(
                f"n_splits {n_splits} is not a integer at least 2"
            )

        self.n_splits = n_splits
        self.shuffle = shuffle
        self.seed = random_state

    def get_n_splits(self, X=None, y=None):
        """Return the number of splitting iterations.

        Parameters
        ----------
        X : object, default=None
            Ignored; accepted for API compatibility.

        y : object, default=None
            Ignored; accepted for API compatibility.

        Returns
        -------
        n_splits : int
            The number of folds this splitter was configured with.
        """
        return self.n_splits

    @abstractmethod
    def split(self, x, y):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.

        Yields
        ------
        train_idx : CuPy ndarray
            The training set indices for that split.

        test_idx : CuPy ndarray
            The testing set indices for that split.
        """
        raise NotImplementedError()


class KFold(_KFoldBase):
    """K-Folds cross-validator.

    Provides train/test indices to split data in train/test sets. Split dataset into k
    consecutive folds (without shuffling by default).

    Each fold is then used once as a validation set while the k - 1 remaining folds form
    the training set.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Whether to shuffle the samples before splitting.

    random_state : int, CuPy RandomState or NumPy RandomState, optional
        If shuffle is true, seeds the generator. Unseeded by default. Pass an int for
        reproducible output across multiple function calls.

    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        super().__init__(
            n_splits=n_splits, shuffle=shuffle, random_state=random_state
        )

    def split(self, x, y=None):
        """Generate indices to split data into training and test set.

        Only the number of rows of ``x`` (and the length of ``y``, when given)
        is used; the data itself is never read.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Training data. Must expose ``shape``.

        y : array-like of shape (n_samples,), default=None
            Optional target; only checked to have the same length as ``x``.

        Yields
        ------
        train_idx : CuPy ndarray
            The training set indices for that split.

        test_idx : CuPy ndarray
            The testing set indices for that split.

        Raises
        ------
        ValueError
            If ``x`` and ``y`` differ in length, or if there are fewer samples
            than folds.
        """
        n_samples = x.shape[0]
        if y is not None and n_samples != len(y):
            raise ValueError("Expecting same length of x and y")
        if n_samples < self.n_splits:
            raise ValueError(
                f"n_splits: {self.n_splits} must be smaller than the number of samples: {n_samples}."
            )

        indices = cp.arange(n_samples)

        if self.shuffle:
            cp.random.RandomState(check_random_seed(self.seed)).shuffle(indices)

        # Fold boundaries are computed on the host as plain Python ints so the
        # slices below never depend on device-side scalars. The first
        # ``n_larger`` folds each receive one extra sample.
        base_size, n_larger = divmod(n_samples, self.n_splits)

        start = 0
        for fold in range(self.n_splits):
            stop = start + base_size + (1 if fold < n_larger else 0)
            test = indices[start:stop]
            # Everything outside [start, stop) forms the training set, in the
            # same order as it appears in ``indices``.
            train = cp.concatenate([indices[:start], indices[stop:]])
            yield train, test
            start = stop


class StratifiedKFold(_KFoldBase):
"""
A cudf based implementation of Stratified K-Folds cross-validator.

Expand All @@ -431,10 +530,13 @@ class StratifiedKFold:
----------
n_splits : int, default=5
Number of folds. Must be at least 2.

shuffle : boolean, default=False
Whether to shuffle each class's samples before splitting.
random_state : int (default=None)
Random seed

random_state : int, CuPy RandomState or NumPy RandomState optional
If shuffle is true, seeds the generator. Unseeded by default. Pass an int for
reproducible output across multiple function calls.

Examples
--------
Expand All @@ -455,17 +557,9 @@ class StratifiedKFold:
"""

def __init__(self, n_splits=5, shuffle=False, random_state=None):
if n_splits < 2 or not isinstance(n_splits, int):
raise ValueError(
f"n_splits {n_splits} is not a integer at least 2"
)

if random_state is not None and not isinstance(random_state, int):
raise ValueError(f"random_state {random_state} is not an integer")

self.n_splits = n_splits
self.shuffle = shuffle
self.seed = random_state
super().__init__(
n_splits=n_splits, shuffle=shuffle, random_state=random_state
)

def get_n_splits(self, X=None, y=None):
return self.n_splits
Expand Down
52 changes: 51 additions & 1 deletion python/cuml/tests/test_stratified_kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

import cudf
import cupy as cp
import numpy as np
import pytest

from cuml.model_selection import StratifiedKFold
from cuml.datasets import make_regression
from cuml.model_selection import KFold, StratifiedKFold


def get_x_y(n_samples, n_classes):
Expand Down Expand Up @@ -63,3 +65,51 @@ def test_invalid_folds(n_splits):
kf = StratifiedKFold(n_splits=n_splits)
for train_index, test_index in kf.split(X, y):
break


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("n_splits", [5, 10])
@pytest.mark.parametrize(
    "random_state",
    [
        1,
        np.random.RandomState(1),
        cp.random.RandomState(1),
    ],
)
def test_kfold(shuffle, n_splits, random_state) -> None:
    """Check that each KFold split partitions all sample indices.

    For every split the train and test indices together must cover
    ``0..n_samples-1`` exactly once, each test fold must have one of the two
    allowed sizes, and the test folds must add up to ``n_samples``.
    """
    n_samples, n_features = 256, 16
    X, y = make_regression(n_samples, n_features, random_state=1)
    splitter = KFold(
        n_splits=n_splits, shuffle=shuffle, random_state=random_state
    )
    # Smallest permitted fold; larger folds hold exactly one more sample.
    fold_size = X.shape[0] // n_splits
    n_test_total = 0

    for train_idx, test_idx in splitter.split(X, y):
        n_test_total += test_idx.size

        n_train, n_test = train_idx.shape[0], test_idx.shape[0]
        assert n_train + n_test == n_samples
        assert n_test in (fold_size, fold_size + 1)
        assert cp.all(train_idx >= 0)
        assert cp.all(test_idx >= 0)

        indices = cp.concatenate([train_idx, test_idx])
        assert len(indices.shape) == 1
        assert indices.size == n_samples

        uniques = cp.unique(indices)
        assert uniques.size == n_samples, indices
        cp.testing.assert_allclose(cp.sort(uniques), cp.arange(n_samples))

    assert n_test_total == n_samples


# Since the kfold only uses the shape of the input, not the actual data, we only have a
# small test for dataframe.
def test_kfold_dataframe() -> None:
    """Smoke-test KFold on the inputs produced by ``get_x_y``."""
    n_samples = 4096
    X, y = get_x_y(n_samples, 2)
    splitter = KFold(n_splits=5, shuffle=True)
    for train_idx, test_idx in splitter.split(X, y):
        n_train, n_test = train_idx.shape[0], test_idx.shape[0]
        # Train and test together must account for every sample.
        assert n_train + n_test == n_samples
Loading