Skip to content

Commit 040b9be

Browse files
committed
Skip NaiveBayes estimators as they break other estimators
1 parent bb7854f commit 040b9be

1 file changed

Lines changed: 50 additions & 12 deletions

File tree

python/cuml/cuml/tests/test_sklearn_compatibility.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,12 @@
1616
from functools import partial
1717

1818
import pytest
19-
from sklearn.kernel_ridge import KernelRidge
20-
from sklearn.naive_bayes import (
21-
BernoulliNB,
22-
CategoricalNB,
23-
ComplementNB,
24-
GaussianNB,
25-
MultinomialNB,
26-
)
2719
from sklearn.utils import estimator_checks
2820

2921
from cuml.cluster import DBSCAN, HDBSCAN, KMeans
3022
from cuml.decomposition import PCA, IncrementalPCA, TruncatedSVD
3123
from cuml.ensemble import RandomForestClassifier, RandomForestRegressor
24+
from cuml.kernel_ridge import KernelRidge
3225
from cuml.linear_model import (
3326
ElasticNet,
3427
Lasso,
@@ -37,6 +30,13 @@
3730
Ridge,
3831
)
3932
from cuml.manifold import TSNE, UMAP
33+
from cuml.naive_bayes import (
34+
BernoulliNB,
35+
CategoricalNB,
36+
ComplementNB,
37+
GaussianNB,
38+
MultinomialNB,
39+
)
4040
from cuml.neighbors import (
4141
KernelDensity,
4242
KNeighborsClassifier,
@@ -74,6 +74,34 @@
7474
"check_fit1d": "KMeans does not raise ValueError for 1D input",
7575
"check_fit2d_predict1d": "KMeans does not handle 1D prediction input gracefully",
7676
},
77+
KernelRidge: {
78+
"check_estimator_tags_renamed": "No support for modern tags infrastructure",
79+
"check_no_attributes_set_in_init": "KernelRidge sets attributes during init",
80+
"check_dont_overwrite_parameters": "KernelRidge overwrites parameters during fit",
81+
"check_estimators_unfitted": "KernelRidge does not raise NotFittedError before fit",
82+
"check_do_not_raise_errors_in_init_or_set_params": "KernelRidge raises errors in init or set_params",
83+
"check_n_features_in_after_fitting": "KernelRidge does not check n_features_in consistency",
84+
"check_estimators_dtypes": "KernelRidge does not handle dtypes properly",
85+
"check_sample_weights_pandas_series": "KernelRidge does not handle pandas Series sample weights",
86+
"check_sample_weights_not_an_array": "KernelRidge does not handle non-array sample weights",
87+
"check_complex_data": "KernelRidge does not handle complex data",
88+
"check_dtype_object": "KernelRidge does not handle object dtype",
89+
"check_estimators_empty_data_messages": "KernelRidge does not handle empty data",
90+
"check_estimators_nan_inf": "KernelRidge does not check for NaN and inf",
91+
"check_estimator_sparse_tag": "KernelRidge does not support sparse data",
92+
"check_estimator_sparse_array": "KernelRidge does not handle sparse arrays gracefully",
93+
"check_estimator_sparse_matrix": "KernelRidge does not handle sparse matrices gracefully",
94+
"check_regressors_train": "KernelRidge does not handle list inputs",
95+
"check_regressors_train(readonly_memmap=True)": "KernelRidge does not handle readonly memmap",
96+
"check_regressors_train(readonly_memmap=True,X_dtype=float32)": "KernelRidge does not handle readonly memmap with float32",
97+
"check_regressor_data_not_an_array": "KernelRidge does not handle non-array data",
98+
"check_supervised_y_2d": "KernelRidge does not handle 2D y",
99+
"check_supervised_y_no_nan": "KernelRidge does not check for NaN in y",
100+
"check_parameters_default_constructible": "KernelRidge parameters are mutated on init",
101+
"check_fit1d": "KernelRidge does not raise ValueError for 1D input",
102+
"check_fit2d_predict1d": "KernelRidge does not handle 1D prediction input gracefully",
103+
"check_requires_y_none": "KernelRidge does not handle y=None",
104+
},
77105
LogisticRegression: {
78106
"check_estimator_tags_renamed": "No support for modern tags infrastructure",
79107
"check_no_attributes_set_in_init": "LogisticRegression sets attributes during init",
@@ -323,7 +351,7 @@
323351
LinearSVR: {
324352
"check_estimator_tags_renamed": "No support for modern tags infrastructure",
325353
"check_no_attributes_set_in_init": "LinearSVR sets attributes during init",
326-
"check_dont_overwrite_parameters": "LinearSVR overwrites parameters during fit",
354+
# "check_dont_overwrite_parameters": "LinearSVR overwrites parameters during fit",
327355
"check_estimators_unfitted": "LinearSVR does not raise NotFittedError before fit",
328356
"check_do_not_raise_errors_in_init_or_set_params": "LinearSVR raises errors in init or set_params",
329357
"check_n_features_in_after_fitting": "LinearSVR does not check n_features_in consistency",
@@ -415,7 +443,7 @@
415443
"check_estimators_nan_inf": "SVR does not check for NaN and inf",
416444
"check_estimator_sparse_tag": "SVR does not support sparse data",
417445
"check_estimator_sparse_array": "SVR does not handle sparse arrays gracefully",
418-
"check_estimator_sparse_matrix": "SVR does not handle sparse matrices gracefully",
446+
# "check_estimator_sparse_matrix": "SVR does not handle sparse matrices gracefully",
419447
"check_regressors_train": "SVR does not handle list inputs",
420448
"check_regressors_train(readonly_memmap=True)": "SVR does not handle readonly memmap",
421449
"check_regressors_train(readonly_memmap=True,X_dtype=float32)": "SVR does not handle readonly memmap with float32",
@@ -514,9 +542,9 @@
514542
"check_dtype_object": "UMAP does not handle object dtype",
515543
"check_estimators_nan_inf": "UMAP does not check for NaN and inf",
516544
"check_estimator_sparse_tag": "UMAP does not support sparse data",
517-
"check_estimator_sparse_matrix": "UMAP does not handle sparse matrices gracefully",
545+
# "check_estimator_sparse_matrix": "UMAP does not handle sparse matrices gracefully",
518546
"check_transformer_data_not_an_array": "UMAP does not handle non-array data",
519-
"check_transformers_unfitted": "UMAP does not raise error when transform called before fit",
547+
# "check_transformers_unfitted": "UMAP does not raise error when transform called before fit",
520548
"check_parameters_default_constructible": "UMAP parameters are mutated on init",
521549
"check_fit_check_is_fitted": "UMAP passes check_is_fitted before being fit",
522550
},
@@ -697,6 +725,16 @@ def test_sklearn_compatible_estimator(estimator, check):
697725
# Check that all estimators pass the "common estimator" checks
698726
# provided by scikit-learn
699727

728+
# These estimators lead to additional MemoryErrors in the other
729+
# estimators. As a result they are currently skipped.
730+
if isinstance(
731+
estimator,
732+
(GaussianNB, ComplementNB, CategoricalNB, BernoulliNB, MultinomialNB),
733+
):
734+
pytest.skip(
735+
"Estimator leads to additional MemoryErrors in other estimators"
736+
)
737+
700738
check_name = _check_name(check)
701739

702740
if check_name == "check_estimators_pickle" and isinstance(

0 commit comments

Comments
 (0)