Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions python/cuml/cuml/accel/tests/scikit-learn/xfail-list.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -530,14 +530,9 @@
- "sklearn.linear_model.tests.test_coordinate_descent::test_enet_sample_weight_consistency[42-None-False-0.01-True]"
- "sklearn.linear_model.tests.test_coordinate_descent::test_enet_sample_weight_does_not_overwrite_sample_weight[False]"
- "sklearn.linear_model.tests.test_coordinate_descent::test_enet_sample_weight_does_not_overwrite_sample_weight[True]"
- "sklearn.linear_model.tests.test_coordinate_descent::test_enet_toy"
- "sklearn.linear_model.tests.test_coordinate_descent::test_lassoCV_does_not_set_precompute[False-False]"
- "sklearn.linear_model.tests.test_coordinate_descent::test_lassoCV_does_not_set_precompute[auto-False]"
- "sklearn.linear_model.tests.test_coordinate_descent::test_lasso_alpha_warning"
- "sklearn.linear_model.tests.test_coordinate_descent::test_lasso_dual_gap"
- "sklearn.linear_model.tests.test_coordinate_descent::test_lasso_readonly_data"
- "sklearn.linear_model.tests.test_coordinate_descent::test_lasso_toy"
- "sklearn.linear_model.tests.test_coordinate_descent::test_lasso_zero"
- "sklearn.linear_model.tests.test_coordinate_descent::test_warm_start_convergence"
- "sklearn.linear_model.tests.test_coordinate_descent::test_warm_start_convergence_with_regularizer_decrement"
- "sklearn.linear_model.tests.test_coordinate_descent::test_elasticnet_precompute_gram_weighted_samples"
Expand Down
99 changes: 97 additions & 2 deletions python/cuml/cuml/linear_model/elastic_net.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

from inspect import signature

import cupy as cp
import numpy as np

from cuml.common import input_to_cuml_array
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.common.doc_utils import generate_docstring
Expand Down Expand Up @@ -147,6 +150,8 @@ class ElasticNet(Base,
The estimated coefficients for the linear regression model.
intercept_ : array
The independent term. If `fit_intercept` is False, will be 0.
dual_gap_ : float
Given param alpha, the dual gaps at the end of the optimization.

Notes
-----
Expand Down Expand Up @@ -213,13 +218,15 @@ class ElasticNet(Base,
return {
"intercept_": to_gpu(model.intercept_, order="F"),
"coef_": to_gpu(model.coef_, order="F"),
"dual_gap_": model.dual_gap_,
**super()._attrs_from_cpu(model),
}

def _attrs_to_cpu(self, model):
    """Gather fitted GPU attributes as host-side values for a CPU estimator.

    Moves ``intercept_`` and ``coef_`` off the device and forwards the
    scalar ``dual_gap_`` as-is, merging in whatever the base class
    contributes (base-class entries win on any key collision).
    """
    attrs = {
        "intercept_": to_cpu(self.intercept_),
        "coef_": to_cpu(self.coef_),
        "dual_gap_": self.dual_gap_,
    }
    attrs.update(super()._attrs_to_cpu(model))
    return attrs

Expand Down Expand Up @@ -309,6 +316,72 @@ class ElasticNet(Base,
msg = "l1_ratio value has to be between 0.0 and 1.0"
raise ValueError(msg.format(l1_ratio))

def _compute_dual_gap(
    self,
    X_m: CumlArray,
    y_m: CumlArray,
    sample_weight_m: CumlArray | None = None,
) -> float:
    """Compute `dual_gap_` by its objective formulation.

    Mirrors the duality-gap computation in scikit-learn's coordinate
    descent solver so that the reported ``dual_gap_`` attribute is
    comparable to scikit-learn's, even though the fitted coefficients
    themselves may differ between the two solvers.

    Parameters
    ----------
    X_m : CumlArray
        Training data used in the preceding ``fit`` call.
    y_m : CumlArray
        Training targets used in the preceding ``fit`` call.
    sample_weight_m : CumlArray or None
        Optional per-sample weights; when given, X and y are rescaled
        by the (normalized) square-root weights as scikit-learn does.

    Returns
    -------
    float
        The dual gap of the elastic-net objective, normalized by the
        number of samples.
    """
    # TODO: this should all probably be migrated into the CD solver itself,
    # much of this is repeated work. In addition, the stopping criteria
    # used in cuML's CD solver doesn't take into account the dual gap
    # (while sklearn's does). Either way, the formulation below matches
    # what sklearn does to compute the gap, we just won't have the same
    # coefficients found in cuml as are found in sklearn.
    X = X_m.to_output("cupy")
    # Promote a single-feature 1-D input to a 2-D column matrix.
    if X.ndim == 1:
        X = X[:, None]

    y = y_m.to_output("cupy")
    # Flatten a single-target 2-D y to 1-D so the dot products below
    # produce scalars.
    if y.ndim == 2:
        y = y.ravel()

    if sample_weight_m is not None:
        sample_weight = sample_weight_m.to_output("cupy")
    else:
        sample_weight = None

    w = self.coef_.to_output("cupy")

    n_samples = len(y)

    # Center X and y by their (weighted) means, matching how the solver
    # handles fit_intercept.
    # NOTE(review): `X -= X_mean` / `X *= ...` below mutate the cupy
    # arrays in place; if `to_output("cupy")` returns a view of the
    # caller's data rather than a copy, this clobbers the input — TODO
    # confirm CumlArray.to_output's copy semantics.
    if self.fit_intercept:
        X_mean = cp.average(X, axis=0, weights=sample_weight)
        y_mean = cp.average(y, axis=0, weights=sample_weight)
        X -= X_mean
        y -= y_mean

    if sample_weight is not None:
        # Normalize weights to sum to n_samples, then fold their square
        # roots into X and y (weighted least squares as a rescaling).
        sample_weight = sample_weight * (n_samples / cp.sum(sample_weight))
        sample_weight_sqrt = cp.sqrt(sample_weight)
        X *= sample_weight_sqrt[:, None]
        y *= sample_weight_sqrt

    # Scale the l1/l2 penalties to the sum-of-squares (unaveraged)
    # objective used by the gap formula.
    alpha = self.alpha * self.l1_ratio * n_samples
    beta = self.alpha * (1.0 - self.l1_ratio) * n_samples

    # Residual and the gradient-like quantity X^T R - beta * w, whose
    # infinity norm decides whether the dual point must be rescaled to
    # be feasible.
    R = y - X @ w
    XtA = cp.dot(X.T, R) - beta * w
    dual_norm_XtA = cp.max(cp.abs(XtA))
    R_norm2 = cp.dot(R, R)
    w_norm2 = cp.dot(w, w)
    if (dual_norm_XtA > alpha):
        # Dual point infeasible: shrink it by const to project onto the
        # feasible set before evaluating the gap.
        const = alpha / dual_norm_XtA
        A_norm2 = R_norm2 * (const ** 2)
        gap = 0.5 * (R_norm2 + A_norm2)
    else:
        const = 1.0
        gap = R_norm2

    l1_norm = cp.sum(cp.abs(w))

    gap += (alpha * l1_norm
            - const * cp.dot(R.T, y)
            + 0.5 * beta * (1 + const ** 2) * (w_norm2))

    # Report the gap on the per-sample (averaged) scale.
    return float(gap) / n_samples

@generate_docstring()
@warn_legacy_device_interop
def fit(self, X, y, convert_dtype=True,
Expand All @@ -317,8 +390,28 @@ class ElasticNet(Base,
Fit the model with X and y.

"""
X_m, _, self.n_features_in_, self.dtype = input_to_cuml_array(X)
y_m, _, _, _ = input_to_cuml_array(y)
X_m, n_rows, self.n_features_in_, self.dtype = input_to_cuml_array(
X,
convert_to_dtype=(np.float32 if convert_dtype else None),
check_dtype=[np.float32, np.float64],
)
y_m, _, _, _ = input_to_cuml_array(
y,
convert_to_dtype=(self.dtype if convert_dtype else None),
check_dtype=self.dtype,
check_rows=n_rows,
)
if sample_weight is not None:
sample_weight_m, _, _, _ = input_to_cuml_array(
sample_weight,
check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype else None),
check_rows=n_rows,
check_cols=1,
)
else:
sample_weight_m = None

if hasattr(X_m, 'index'):
self.feature_names_in_ = X_m.index

Expand All @@ -345,6 +438,8 @@ class ElasticNet(Base,
self.coef_ = self.solver_model.coef_
self.intercept_ = self.solver_model.intercept_

self.dual_gap_ = self._compute_dual_gap(X_m, y_m, sample_weight_m)

return self

def set_params(self, **params):
Expand Down
2 changes: 2 additions & 0 deletions python/cuml/cuml/linear_model/lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class Lasso(ElasticNet):
The estimated coefficients for the linear regression model.
intercept_ : array
The independent term. If `fit_intercept` is False, will be 0.
dual_gap_ : float
Given param alpha, the dual gaps at the end of the optimization.

Notes
-----
Expand Down
25 changes: 25 additions & 0 deletions python/cuml/cuml/tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
from scipy.sparse import csr_matrix
from sklearn.datasets import load_breast_cancer, load_digits
from sklearn.linear_model import ElasticNet as skElasticNet
from sklearn.linear_model import Lasso as skLasso
from sklearn.linear_model import LinearRegression as skLinearRegression
from sklearn.linear_model import LogisticRegression as skLog
from sklearn.linear_model import Ridge as skRidge
from sklearn.model_selection import train_test_split

from cuml import ElasticNet as cuElasticNet
from cuml import Lasso as cuLasso
from cuml import LinearRegression as cuLinearRegression
from cuml import LogisticRegression as cuLog
from cuml import Ridge as cuRidge
Expand Down Expand Up @@ -1185,3 +1187,26 @@ def test_elasticnet_model(datatype, solver, nrows, column_info, ntargets):
total_tol=1e-0,
with_sign=True,
)


@pytest.mark.parametrize("cls", ["elasticnet", "lasso"])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("weighted", [False, True])
def test_dual_gap(cls, fit_intercept, weighted):
X, y = make_regression(random_state=42)
if cls == "elasticnet":
model = cuElasticNet(fit_intercept=fit_intercept, tol=1e-4)
sk_model = skElasticNet(fit_intercept=fit_intercept)
else:
model = cuLasso(fit_intercept=fit_intercept, tol=1e-4)
sk_model = skLasso(fit_intercept=fit_intercept)

if weighted:
sample_weight = np.random.default_rng(42).uniform(size=len(y))
else:
sample_weight = None

model.fit(X, y, sample_weight=sample_weight)
sk_model.fit(X, y, sample_weight=sample_weight)

np.testing.assert_allclose(model.dual_gap_, sk_model.dual_gap_)
Loading