Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/setup_sklearn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ if [ "$sklearn_version" == "main" ]; then
# install sklearn build dependencies
pip install threadpoolctl joblib scipy
# install sklearn from main branch of git repo
pip install git+https://github.com/scikit-learn/scikit-learn.git@main
pip install git+https://github.com/snath-xoc/scikit-learn.git@deprecate_porb_NuSVC
else
sed -i.bak -E "s/scikit-learn==[0-9a-zA-Z.]*/scikit-learn==${sklearn_version}.*/" requirements-test.txt
fi
20 changes: 13 additions & 7 deletions sklearnex/svm/nusvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@
# limitations under the License.
# ==============================================================================

import warnings
from functools import wraps

import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.metrics import accuracy_score
from sklearn.svm import NuSVC as _sklearn_NuSVC
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import (
_deprecate_positional_args,
check_array,
check_is_fitted,
)
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
Expand Down Expand Up @@ -56,7 +53,7 @@ def __init__(
gamma="scale",
coef0=0.0,
shrinking=True,
probability=False,
probability="deprecated" if sklearn_check_version("1.8") else False,
tol=1e-3,
cache_size=200,
class_weight=None,
Expand Down Expand Up @@ -218,7 +215,16 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
self._onedal_estimator = onedal_NuSVC(**onedal_params)
self._onedal_estimator.fit(X, y, weights, queue=queue)

if self.probability:
probability = self.probability
if sklearn_check_version("1.8") and self.probability != "deprecated":
warnings.warn(
"parameter `probability` will be deprecated in version 1.8, "
"use `CalibratedClassifierCV(NuSVC(), ensemble=False)` "
"instead of `NuSVC(probability=True)`",
FutureWarning,
)
probability = False
if probability:
self._fit_proba(
X,
y,
Expand Down
20 changes: 13 additions & 7 deletions sklearnex/svm/svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================

import warnings
from functools import wraps

import numpy as np
Expand All @@ -22,11 +23,7 @@
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC as _sklearn_SVC
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import (
_deprecate_positional_args,
check_array,
check_is_fitted,
)
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
Expand Down Expand Up @@ -58,7 +55,7 @@ def __init__(
gamma="scale",
coef0=0.0,
shrinking=True,
probability=False,
probability="deprecated" if sklearn_check_version("1.8") else False,
tol=1e-3,
cache_size=200,
class_weight=None,
Expand Down Expand Up @@ -248,7 +245,16 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
self._onedal_estimator = onedal_SVC(**onedal_params)
self._onedal_estimator.fit(X, y, weights, queue=queue)

if self.probability:
probability = self.probability
if sklearn_check_version("1.8") and self.probability != "deprecated":
warnings.warn(
"parameter `probability` will be deprecated in version 1.8, "
"use `CalibratedClassifierCV(SVC(), ensemble=False)` "
"instead of `SVC(probability=True)`",
FutureWarning,
)
probability = False
if probability:
self._fit_proba(
X,
y,
Expand Down
Loading