Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/cuml/cuml/kernel_ridge/kernel_ridge.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import warnings
from cuml.internals.safe_imports import gpu_only_import_from
from cuml.internals.safe_imports import gpu_only_import
from cupyx import lapack, geterr, seterr
from cuml.internals.array import CumlArray
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.base import UniversalBase
from cuml.internals.api_decorators import (
device_interop_preparation,
enable_device_interop,
api_base_return_array,
)
from cuml.internals.mixins import RegressorMixin
from cuml.common.doc_utils import generate_docstring
Expand Down Expand Up @@ -293,6 +295,7 @@ class KernelRidge(UniversalBase, RegressorMixin):
self.X_fit_ = X_m
return self

@api_base_return_array()
@enable_device_interop
def predict(self, X):
"""
Expand All @@ -315,4 +318,4 @@ class KernelRidge(UniversalBase, RegressorMixin):
X, check_dtype=[np.float32, np.float64])

K = self._get_kernel(X_m, self.X_fit_)
return cp.dot(cp.asarray(K), cp.asarray(self.dual_coef_))
return CumlArray(cp.dot(cp.asarray(K), cp.asarray(self.dual_coef_)))
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.kernel_ridge import KernelRidge
from sklearn.manifold import TSNE
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import (
NearestNeighbors,
KNeighborsClassifier,
Expand Down Expand Up @@ -191,17 +190,6 @@ def test_kernel_ridge():
X = 5 * rng.rand(10000, 1)
y = np.sin(X).ravel()

kr = GridSearchCV(
KernelRidge(kernel="rbf", gamma=0.1),
param_grid={
"alpha": [1e0, 0.1, 1e-2, 1e-3],
"gamma": np.logspace(-2, 2, 5),
},
)
kr = KernelRidge(kernel="rbf", gamma=0.1)
kr.fit(X, y)

y_pred = kr.predict(X)

assert not isinstance(
y_pred, cp.ndarray
), f"y_pred should be a np.ndarray, but is a {type(y_pred)}"
kr.predict(X)
20 changes: 19 additions & 1 deletion python/cuml/cuml/tests/test_kernel_ridge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2022-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
import math
import pytest
from sklearn.metrics.pairwise import pairwise_kernels as skl_pairwise_kernels
import cuml
from cuml.metrics import pairwise_kernels, PAIRWISE_KERNEL_FUNCTIONS
from cuml import KernelRidge as cuKernelRidge
from cuml.internals.safe_imports import cpu_only_import
Expand Down Expand Up @@ -304,6 +305,23 @@ def test_estimator(kernel_arg, arrays, gamma, degree, coef0):
)


def test_predict_output_type():
rng = np.random.RandomState(42)

X = 5 * rng.rand(10000, 1)
y = np.sin(X).ravel()

kr = cuKernelRidge(kernel="rbf", gamma=0.1)
kr.fit(X, y)

res = kr.predict(X)
assert isinstance(res, np.ndarray)

with cuml.using_output_type("cupy"):
res = kr.predict(X)
assert isinstance(res, cp.ndarray)


def test_precomputed():
rs = np.random.RandomState(23)
X = rs.normal(size=(10, 10))
Expand Down