Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/cuml/cuml/naive_bayes/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ def fit(self, X, y, sample_weight=None) -> "GaussianNB":
Weights applied to individual samples (1. for unweighted).
Currently sample weight is ignored.
"""
self.fit_called_ = False
return self._partial_fit(
X,
y,
Expand Down Expand Up @@ -409,6 +410,12 @@ def _partial_fit(
check_dtype=expected_y_dtype,
).array

if X.shape[0] != y.shape[0]:
raise ValueError(
"X and y must have the same number of samples. "
f"Got {X.shape[0]} and {y.shape[0]}."
)

Comment thread
betatim marked this conversation as resolved.
Outdated
if _classes is not None:
_classes, *_ = input_to_cuml_array(
_classes,
Expand Down
35 changes: 34 additions & 1 deletion python/cuml/tests/test_sklearn_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,39 @@
"check_fit1d": "HDBSCAN does not raise ValueError for 1D input",
"check_fit2d_predict1d": "HDBSCAN does not handle 1D prediction input gracefully",
},
GaussianNB: {
"check_estimator_tags_renamed": "No support for modern tags infrastructure",
"check_no_attributes_set_in_init": "GaussianNB sets attributes during init",
"check_dont_overwrite_parameters": "GaussianNB overwrites parameters during fit",
"check_estimators_unfitted": "GaussianNB does not raise NotFittedError before fit",
"check_do_not_raise_errors_in_init_or_set_params": "GaussianNB raises errors in init or set_params",
"check_n_features_in_after_fitting": "GaussianNB does not check n_features_in consistency",
"check_estimators_dtypes": "GaussianNB does not handle dtypes properly",
"check_sample_weights_pandas_series": "GaussianNB does not handle pandas Series sample weights",
"check_sample_weights_not_an_array": "GaussianNB does not handle non-array sample weights",
"check_sample_weights_shape": "GaussianNB does not validate sample weights shape",
"check_sample_weight_equivalence_on_dense_data": "GaussianNB sample weight equivalence not implemented",
"check_complex_data": "GaussianNB does not handle complex data",
"check_dtype_object": "GaussianNB does not handle object dtype",
"check_estimators_empty_data_messages": "GaussianNB does not handle empty data",
"check_estimators_nan_inf": "GaussianNB does not check for NaN and inf",
"check_estimator_sparse_tag": "GaussianNB does not support sparse data",
"check_estimator_sparse_array": "GaussianNB does not handle sparse arrays gracefully",
"check_classifier_data_not_an_array": "GaussianNB does not handle non-array data",
"check_classifiers_classes": "GaussianNB does not handle string data properly",
"check_estimators_partial_fit_n_features": "GaussianNB does not check n_features consistency in partial_fit",
"check_classifiers_train": "GaussianNB does not handle list inputs",
"check_classifiers_train(readonly_memmap=True)": "GaussianNB does not handle readonly memmap",
"check_classifiers_train(readonly_memmap=True,X_dtype=float32)": "GaussianNB does not handle readonly memmap with float32",
"check_classifiers_regression_target": "GaussianNB does not handle regression targets",
"check_supervised_y_no_nan": "GaussianNB does not check for NaN in y",
"check_supervised_y_2d": "GaussianNB does not handle 2D y",
"check_parameters_default_constructible": "GaussianNB parameters are mutated on init",
"check_fit_check_is_fitted": "GaussianNB passes check_is_fitted before being fit",
"check_fit1d": "GaussianNB does not raise ValueError for 1D input",
"check_fit2d_predict1d": "GaussianNB does not handle 1D prediction input gracefully",
"check_requires_y_none": "GaussianNB does not handle y=None",
},
}


Expand Down Expand Up @@ -729,7 +762,7 @@ def test_sklearn_compatible_estimator(estimator, check):
# estimators. As a result they are currently skipped.
if isinstance(
estimator,
- (GaussianNB, ComplementNB, CategoricalNB, BernoulliNB, MultinomialNB),
+ (ComplementNB, CategoricalNB, BernoulliNB, MultinomialNB),
):
pytest.skip(
"Estimator leads to additional MemoryErrors in other estimators (gh-7100)"
Expand Down