Skip to content

Commit 226c677

Browse files
authored
[maintenance] prepare next release for sklearn 1.8 API breaking changes (#2578)
* Update validation.py * Update dispatcher.py * Update dispatcher.py * Update dispatcher.py * Update run_test.sh * Update _pairwise.py * Update _pairwise.py * Update _pairwise.py * Update _pairwise.py * Update dispatcher.py * Update run_test.sh * Update dispatcher.py * Update _dlpack.py * Update _common.py * Update logistic_regression.py * Update logistic_path.py * Update _common.py * attempt to fix * attempt to fix * attempt to fix * add logistic regression to testing * formatting * force value of cv * ugly solution * remove vestigial newline * isort fixes * move imports to respect sklearn versions * Update setup_sklearn.sh * Update __init__.py * Update __init__.py * Update __init__.py * Update test_model_builders.py * formatting * try to fix logreg testing * try to fix issues with testing mb * fix logic again * try to fix daal4py again * Update test_logreg.py * Update test_logreg.py * formatting * Update deselected_tests.yaml * Update logistic_regression.py * Update setup_sklearn.sh * changes as requested from reviews * make suggestion from review * fix again * formatting fixg * false flag on pyright * Update _common.py * Update test_patching.py * Update test_patching.py * fix para
1 parent fada56a commit 226c677

File tree

12 files changed

+348
-326
lines changed

12 files changed

+348
-326
lines changed

.ci/scripts/select_sklearn_tests.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@ def parse_tests_tree(entry, prefix=""):
4141
"covariance/tests": "test_covariance.py",
4242
"decomposition/tests": ["test_pca.py", "test_incremental_pca.py"],
4343
"ensemble/tests": "test_forest.py",
44-
"linear_model/tests": ["test_base.py", "test_coordinate_descent.py", "test_ridge.py"],
44+
"linear_model/tests": [
45+
"test_base.py",
46+
"test_coordinate_descent.py",
47+
"test_logistic.py",
48+
"test_ridge.py",
49+
],
4550
"manifold/tests": "test_t_sne.py",
4651
"metrics/tests": ["test_pairwise.py", "test_ranking.py"],
4752
"model_selection/tests": ["test_split.py", "test_validation.py"],

daal4py/mb/__init__.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,26 @@ def convert_model(model) -> "GBTDAALModel | LogisticDAALModel":
5353
offers faster prediction methods.
5454
"""
5555
if isinstance(model, LogisticRegression):
56+
# The multi_class keyword is removed in scikit-learn 1.8, and OvR functionality
57+
# has been replaced by other estimators. Therefore checking for linear classifiers
58+
# only dependent on the solver.
5659
if model.classes_.shape[0] > 2:
57-
if (model.multi_class == "ovr") or (
58-
model.multi_class == "auto" and model.solver == "liblinear"
59-
):
60-
raise TypeError(
61-
"Supplied 'model' object is a linear classifier, but not multinomial logistic"
62-
" (hint: pass multi_class='multinomial' to 'LogisticRegression')."
63-
)
64-
elif (model.classes_.shape[0] == 2) and (model.multi_class == "multinomial"):
60+
if not hasattr(model, "multi_class"):
61+
if model.solver == "liblinear":
62+
raise TypeError(
63+
"Supplied 'model' object is a linear classifier, but not multinomial logistic"
64+
)
65+
else:
66+
if (model.multi_class == "ovr") or (
67+
model.multi_class == "auto" and model.solver == "liblinear"
68+
):
69+
raise TypeError(
70+
"Supplied 'model' object is a linear classifier, but not multinomial logistic"
71+
" (hint: pass multi_class='multinomial' to 'LogisticRegression')."
72+
)
73+
elif (model.classes_.shape[0] == 2) and (
74+
getattr(model, "multi_class", "auto") == "multinomial"
75+
):
6576
raise TypeError(
6677
"Supplied 'model' object is not a logistic regressor "
6778
"(hint: pass multi_class='auto' to 'LogisticRegression')."

daal4py/sklearn/linear_model/logistic_path.py

Lines changed: 57 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@
2020
import scipy.optimize as optimize
2121
import scipy.sparse as sparse
2222
import sklearn.linear_model._logistic as logistic_module
23-
from sklearn.linear_model._sag import sag_solver
24-
from sklearn.utils import (
25-
check_array,
26-
check_consistent_length,
27-
check_random_state,
28-
compute_class_weight,
23+
from sklearn.linear_model._logistic import _LOGISTIC_SOLVER_CONVERGENCE_MSG
24+
from sklearn.linear_model._logistic import (
25+
LogisticRegression as LogisticRegression_original,
2926
)
27+
from sklearn.linear_model._logistic import _check_solver
28+
from sklearn.utils import check_array, check_consistent_length, check_random_state
3029
from sklearn.utils.optimize import _check_optimize_result, _newton_cg
31-
from sklearn.utils.validation import _check_sample_weight, check_is_fitted
30+
from sklearn.utils.validation import check_is_fitted
3231

3332
import daal4py as d4p
3433

@@ -44,35 +43,6 @@
4443
_daal4py_loss_and_grad,
4544
)
4645

47-
if sklearn_check_version("1.1"):
48-
from sklearn._loss.loss import HalfBinomialLoss, HalfMultinomialLoss
49-
from sklearn.linear_model._linear_loss import LinearModelLoss
50-
from sklearn.linear_model._logistic import _LOGISTIC_SOLVER_CONVERGENCE_MSG
51-
from sklearn.linear_model._logistic import (
52-
LogisticRegression as LogisticRegression_original,
53-
)
54-
from sklearn.linear_model._logistic import (
55-
_check_multi_class,
56-
_check_solver,
57-
_fit_liblinear,
58-
)
59-
else:
60-
from sklearn.linear_model._logistic import _LOGISTIC_SOLVER_CONVERGENCE_MSG
61-
from sklearn.linear_model._logistic import (
62-
LogisticRegression as LogisticRegression_original,
63-
)
64-
from sklearn.linear_model._logistic import (
65-
_check_multi_class,
66-
_check_solver,
67-
_fit_liblinear,
68-
_logistic_grad_hess,
69-
_logistic_loss,
70-
_logistic_loss_and_grad,
71-
_multinomial_grad_hess,
72-
_multinomial_loss,
73-
_multinomial_loss_grad,
74-
)
75-
7646
if sklearn_check_version("1.7.1"):
7747
from sklearn.utils.fixes import _get_additional_lbfgs_options_dict
7848
else:
@@ -86,6 +56,25 @@ def _get_additional_lbfgs_options_dict(k, v):
8656
from sklearn.preprocessing import LabelBinarizer, LabelEncoder
8757

8858

59+
# Compatibility patch for sklearn 1.8, related to https://github.com/scikit-learn/scikit-learn/pull/32073
60+
# where the multi_class keyword is deprecated; this helper is therefore kept here as a local copy.
61+
def _check_multi_class(multi_class, solver, n_classes):
62+
"""Computes the multi class type, either "multinomial" or "ovr".
63+
For `n_classes` > 2 and a solver that supports it, returns "multinomial".
64+
For all other cases, in particular binary classification, return "ovr".
65+
"""
66+
if multi_class == "auto":
67+
if solver in ("liblinear",):
68+
multi_class = "ovr"
69+
elif n_classes > 2:
70+
multi_class = "multinomial"
71+
else:
72+
multi_class = "ovr"
73+
if multi_class == "multinomial" and solver in ("liblinear",):
74+
raise ValueError("Solver %s does not support a multinomial backend." % solver)
75+
return multi_class
76+
77+
8978
# Code adapted from sklearn.linear_model.logistic version 0.21
9079
def __logistic_regression_path(
9180
X,
@@ -110,46 +99,6 @@ def __logistic_regression_path(
11099
l1_ratio=None,
111100
n_threads=1,
112101
):
113-
_patching_status = PatchingConditionsChain(
114-
"sklearn.linear_model.LogisticRegression.fit"
115-
)
116-
_dal_ready = _patching_status.and_conditions(
117-
[
118-
(
119-
solver in ["lbfgs", "newton-cg"],
120-
f"'{solver}' solver is not supported. "
121-
"Only 'lbfgs' and 'newton-cg' solvers are supported.",
122-
),
123-
(not sparse.issparse(X), "X is sparse. Sparse input is not supported."),
124-
(sample_weight is None, "Sample weights are not supported."),
125-
(class_weight is None, "Class weights are not supported."),
126-
]
127-
)
128-
if not _dal_ready:
129-
_patching_status.write_log()
130-
return lr_path_original(
131-
X,
132-
y,
133-
pos_class=pos_class,
134-
Cs=Cs,
135-
fit_intercept=fit_intercept,
136-
max_iter=max_iter,
137-
tol=tol,
138-
verbose=verbose,
139-
solver=solver,
140-
coef=coef,
141-
class_weight=class_weight,
142-
dual=dual,
143-
penalty=penalty,
144-
intercept_scaling=intercept_scaling,
145-
multi_class=multi_class,
146-
random_state=random_state,
147-
check_input=check_input,
148-
max_squared_sum=max_squared_sum,
149-
sample_weight=sample_weight,
150-
l1_ratio=l1_ratio,
151-
**({"n_threads": n_threads} if sklearn_check_version("1.1") else {}),
152-
)
153102

154103
# Comment 2025-08-04: this file might have dead code paths from unsupported solvers.
155104
# It appears to have initially been a copy-paste of scikit-learn with a few additions
@@ -269,7 +218,6 @@ def __logistic_regression_path(
269218
func = _daal4py_loss_
270219
grad = _daal4py_grad_
271220
hess = _daal4py_grad_hess_
272-
warm_start_sag = {"coef": w0.T}
273221
else:
274222
target = y_bin
275223
if solver == "lbfgs":
@@ -280,7 +228,6 @@ def __logistic_regression_path(
280228
func = _daal4py_loss_
281229
grad = _daal4py_grad_
282230
hess = _daal4py_grad_hess_
283-
warm_start_sag = {"coef": np.expand_dims(w0, axis=1)}
284231

285232
coefs = list()
286233
n_iter = np.zeros(len(Cs), dtype=np.int32)
@@ -385,8 +332,6 @@ def _func_(x, *args):
385332
for i, ci in enumerate(coefs):
386333
coefs[i] = np.delete(ci, 0, axis=-1)
387334

388-
_patching_status.write_log()
389-
390335
return np.array(coefs), np.array(Cs), n_iter
391336

392337

@@ -427,20 +372,21 @@ def daal4py_predict(self, X, resultsToEvaluate):
427372
f"sklearn.linear_model.LogisticRegression.{_function_name}"
428373
)
429374
if _function_name != "predict":
375+
multi_class = getattr(self, "multi_class", "auto")
430376
_patching_status.and_conditions(
431377
[
432378
(
433379
self.classes_.size == 2
434-
or logistic_module._check_multi_class(
435-
self.multi_class if self.multi_class != "deprecated" else "auto",
380+
or _check_multi_class(
381+
multi_class if multi_class != "deprecated" else "auto",
436382
self.solver,
437383
self.classes_.size,
438384
)
439385
!= "ovr",
440386
f"selected multiclass option is not supported for n_classes > 2.",
441387
),
442388
(
443-
not (self.classes_.size == 2 and self.multi_class == "multinomial"),
389+
not (self.classes_.size == 2 and multi_class == "multinomial"),
444390
"multi_class='multinomial' not supported with binary data",
445391
),
446392
],
@@ -502,52 +448,35 @@ def daal4py_predict(self, X, resultsToEvaluate):
502448
return LogisticRegression_original.predict_log_proba(self, X)
503449

504450

505-
def logistic_regression_path(
506-
X,
507-
y,
508-
pos_class=None,
509-
Cs=10,
510-
fit_intercept=True,
511-
max_iter=100,
512-
tol=1e-4,
513-
verbose=0,
514-
solver="lbfgs",
515-
coef=None,
516-
class_weight=None,
517-
dual=False,
518-
penalty="l2",
519-
intercept_scaling=1.0,
520-
multi_class="auto",
521-
random_state=None,
522-
check_input=True,
523-
max_squared_sum=None,
524-
sample_weight=None,
525-
l1_ratio=None,
526-
n_threads=1,
527-
):
528-
return __logistic_regression_path(
529-
X,
530-
y,
531-
pos_class=pos_class,
532-
Cs=Cs,
533-
fit_intercept=fit_intercept,
534-
max_iter=max_iter,
535-
tol=tol,
536-
verbose=verbose,
537-
solver=solver,
538-
coef=coef,
539-
class_weight=class_weight,
540-
dual=dual,
541-
penalty=penalty,
542-
intercept_scaling=intercept_scaling,
543-
multi_class=multi_class,
544-
random_state=random_state,
545-
check_input=check_input,
546-
max_squared_sum=max_squared_sum,
547-
sample_weight=sample_weight,
548-
l1_ratio=l1_ratio,
549-
n_threads=n_threads,
451+
def logistic_regression_path(*args, **kwargs):
452+
453+
_patching_status = PatchingConditionsChain(
454+
"sklearn.linear_model.LogisticRegression.fit"
455+
)
456+
_dal_ready = _patching_status.and_conditions(
457+
[
458+
(
459+
kwargs["solver"] in ["lbfgs", "newton-cg"],
460+
f"'{kwargs['solver']}' solver is not supported. "
461+
"Only 'lbfgs' and 'newton-cg' solvers are supported.",
462+
),
463+
(not sparse.issparse(args[0]), "X is sparse. Sparse input is not supported."),
464+
(kwargs["sample_weight"] is None, "Sample weights are not supported."),
465+
(kwargs["class_weight"] is None, "Class weights are not supported."),
466+
]
550467
)
468+
if not _dal_ready:
469+
_patching_status.write_log()
470+
return lr_path_original(*args, **kwargs)
471+
472+
if sklearn_check_version("1.8"):
473+
kwargs.pop("classes", None)
474+
res = __logistic_regression_path(*(args[:2]), **kwargs)
475+
else:
476+
res = __logistic_regression_path(*args, **kwargs)
477+
478+
_patching_status.write_log()
479+
return res
551480

552481

553482
@control_n_jobs(

daal4py/sklearn/monkeypatch/tests/_models_info.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors
3131
from sklearn.svm import SVC
3232

33-
from daal4py.sklearn._utils import daal_check_version
33+
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
3434

3535
MODELS_INFO = [
3636
{
@@ -84,7 +84,10 @@
8484
"dataset": "classifier",
8585
},
8686
{
87-
"model": LogisticRegression(max_iter=100, multi_class="multinomial"),
87+
"model": LogisticRegression(
88+
max_iter=100,
89+
**({} if sklearn_check_version("1.8") else {"multi_class": "multinomial"})
90+
),
8891
"methods": [
8992
"decision_function",
9093
"predict",

daal4py/sklearn/utils/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,9 @@ def _daal_check_array(
280280
array_converted : object
281281
The converted and validated array.
282282
"""
283-
if force_all_finite not in (True, False, "allow-nan"):
283+
if force_all_finite not in (True, False, "allow-nan", None):
284284
raise ValueError(
285-
'force_all_finite should be a bool or "allow-nan"'
285+
'force_all_finite should be a bool, None, or "allow-nan"'
286286
". Got {!r} instead".format(force_all_finite)
287287
)
288288

deselected_tests.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ public:
457457
# single-threaded computation is used which is a frequent case for CI
458458
- decomposition/tests/test_pca.py::test_pca_dtype_preservation[full] <1.5
459459

460+
# Convergence failure due to different settings and cost function implementation
461+
- linear_model/tests/test_logistic.py::test_logistic_regression_path_convergence_fail
462+
460463
# --------------------------------------------------------
461464
# The following tests currently fail with GPU offloading
462465
gpu:

onedal/datatypes/_dlpack.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,14 @@ def dlpack_to_numpy(obj):
4545
raise TypeError(f"cannot move {type(obj)} to cpu")
4646

4747
# convert to numpy
48-
if hasattr(obj, "__array__"):
49-
# ``copy`` param for the ``asarray`` is not set.
50-
# The object is copied only if needed
51-
obj = np.asarray(obj)
52-
else:
48+
try:
49+
# Some frameworks implement an __array__ method just to
50+
# throw a RuntimeError when used (array_api_strict, dpctl),
51+
# or a TypeError (array_api-strict) rather than an AttributeError
52+
# therefore a try catch is necessary (logic is essentially a
53+
# getattr call + some)
54+
obj = obj.__array__()
55+
except (AttributeError, RuntimeError, TypeError):
5356
# requires numpy 1.23
5457
try:
5558
obj = np.from_dlpack(obj)

0 commit comments

Comments
 (0)