Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions docs/source/cuml-accel/limitations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,8 @@ Ridge

``Ridge`` will fall back to CPU in the following cases:

- If ``positive=True``.
- If ``solver="lbfgs"``.
- If ``positive=True`` or ``solver="lbfgs"``.
- If ``X`` is sparse.
- If ``X`` has more columns than rows.
- If ``y`` is multioutput.

Additionally, the following fitted attributes are currently not computed:

Expand Down
37 changes: 24 additions & 13 deletions python/cuml/cuml/accel/_wrappers/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
# SPDX-License-Identifier: Apache-2.0
#

import sklearn
from packaging.version import Version

import cuml.linear_model
from cuml.accel.estimator_proxy import ProxyBase
from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.interop import UnsupportedOnGPU
from cuml.internals.array import CumlArray
from cuml.internals.memory_utils import using_output_type

__all__ = (
"LinearRegression",
Expand All @@ -17,6 +20,9 @@
)


SKLEARN_16 = Version(sklearn.__version__) >= Version("1.6.0")


class LinearRegression(ProxyBase):
_gpu_class = cuml.linear_model.LinearRegression
_not_implemented_attributes = frozenset(("rank_", "singular_"))
Expand All @@ -40,17 +46,22 @@ class Ridge(ProxyBase):
_not_implemented_attributes = frozenset(("n_iter_",))

def _gpu_fit(self, X, y, sample_weight=None):
X = input_to_cuml_array(X, convert_to_mem_type=False)[0]
y = input_to_cuml_array(y, convert_to_mem_type=False)[0]
if len(y.shape) > 1:
raise UnsupportedOnGPU("Multioutput `y` is not supported")

if X.shape[0] < X.shape[1]:
raise UnsupportedOnGPU(
"`X` with more columns than rows is not supported"
)

return self._gpu.fit(X, y, sample_weight=sample_weight)
self._gpu.fit(X, y, sample_weight=sample_weight)

# XXX: sklearn 1.6 changed the shape of `coef_` when fit with a 1
# column 2D y. The sklearn 1.6+ behavior is what we implement in
# cuml.Ridge, here we adjust the shape of `coef_` after the fit to
# match the older behavior. This will also trickle down to change the
# output shape of `predict` to match the older behavior transparently.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is gross, but isolated to cuml.accel so I'm ok with it.

if not SKLEARN_16 and (y_shape := getattr(y, "shape", ())):
if len(y_shape) == 2 and y_shape[1] == 1:
with using_output_type("cupy"):
# Reshape coef_ to be a 2D array
self._gpu.coef_ = CumlArray(
data=self._gpu.coef_.reshape(1, -1)
)

return self


class Lasso(ProxyBase):
Expand Down
7 changes: 7 additions & 0 deletions python/cuml/cuml/internals/memory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,10 @@ def determine_array_memtype(X):
if isinstance(X, (pd.DataFrame, pd.Series)):
return MemoryType.host
return None


def cuda_ptr(X):
"""Returns a pointer to a backing device array, or None if not a device array"""
if (interface := getattr(X, "__cuda_array_interface__", None)) is not None:
return interface["data"][0]
return None
12 changes: 3 additions & 9 deletions python/cuml/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from cuml.internals.interop import (
to_cpu,
to_gpu,
)
from cuml.internals.memory_utils import cuda_ptr
from cuml.internals.mixins import FMajorInputTagMixin, RegressorMixin
from cuml.linear_model.base import (
LinearPredictMixin,
Expand Down Expand Up @@ -88,13 +89,6 @@ _divide_non_zero = cp.ElementwiseKernel(
)


def _cuda_ptr(X):
"""Returns a pointer to a backing device array, or None if not a device array"""
if (interface := getattr(X, "__cuda_array_interface__", None)) is not None:
return interface["data"][0]
return None


class LinearRegression(Base,
InteropMixin,
LinearPredictMixin,
Expand Down Expand Up @@ -354,8 +348,8 @@ class LinearRegression(Base,

cdef int algo = self._select_algo(X_m, y_m)

X_is_copy = _cuda_ptr(X) != X_m.ptr
y_is_copy = _cuda_ptr(y) != y_m.ptr
X_is_copy = cuda_ptr(X) != X_m.ptr
y_is_copy = cuda_ptr(y) != y_m.ptr

if y_m.ndim > 1 and y_m.shape[1] > 1:
# Fallback to cupy SVD implementation for multi-target problems
Expand Down
Loading