Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/cuml/cuml/internals/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def from_input(
False (default) to not check at all.

check_rows: boolean (default: False)
Set to an int `i` to check that input X has `i` columns. Set to
Set to an int `i` to check that input X has `i` rows. Set to
False (default) to not check at all.

fail_on_order: boolean (default: False)
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/cuml/internals/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def input_to_cuml_array(
(default) to not check at all.

check_rows: boolean (default: False)
Set to an int `i` to check that input X has `i` columns. Set to False
Set to an int `i` to check that input X has `i` rows. Set to False
(default) to not check at all.

fail_on_order: boolean (default: False)
Expand Down
2 changes: 2 additions & 0 deletions python/cuml/cuml/naive_bayes/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ def fit(self, X, y, sample_weight=None) -> "GaussianNB":
Weights applied to individual samples (1. for unweighted).
Currently sample weight is ignored.
"""
self.fit_called_ = False
return self._partial_fit(
X,
y,
Expand Down Expand Up @@ -406,6 +407,7 @@ def _partial_fit(
y = input_to_cupy_array(
y,
convert_to_dtype=(expected_y_dtype if convert_dtype else False),
check_rows=X.shape[0],
check_dtype=expected_y_dtype,
).array

Expand Down
35 changes: 34 additions & 1 deletion python/cuml/tests/test_sklearn_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,39 @@
"check_fit1d": "HDBSCAN does not raise ValueError for 1D input",
"check_fit2d_predict1d": "HDBSCAN does not handle 1D prediction input gracefully",
},
GaussianNB: {
"check_estimator_tags_renamed": "No support for modern tags infrastructure",
"check_no_attributes_set_in_init": "GaussianNB sets attributes during init",
"check_dont_overwrite_parameters": "GaussianNB overwrites parameters during fit",
"check_estimators_unfitted": "GaussianNB does not raise NotFittedError before fit",
"check_do_not_raise_errors_in_init_or_set_params": "GaussianNB raises errors in init or set_params",
"check_n_features_in_after_fitting": "GaussianNB does not check n_features_in consistency",
"check_estimators_dtypes": "GaussianNB does not handle dtypes properly",
"check_sample_weights_pandas_series": "GaussianNB does not handle pandas Series sample weights",
"check_sample_weights_not_an_array": "GaussianNB does not handle non-array sample weights",
"check_sample_weights_shape": "GaussianNB does not validate sample weights shape",
"check_sample_weight_equivalence_on_dense_data": "GaussianNB sample weight equivalence not implemented",
"check_complex_data": "GaussianNB does not handle complex data",
"check_dtype_object": "GaussianNB does not handle object dtype",
"check_estimators_empty_data_messages": "GaussianNB does not handle empty data",
"check_estimators_nan_inf": "GaussianNB does not check for NaN and inf",
"check_estimator_sparse_tag": "GaussianNB does not support sparse data",
"check_estimator_sparse_array": "GaussianNB does not handle sparse arrays gracefully",
"check_classifier_data_not_an_array": "GaussianNB does not handle non-array data",
"check_classifiers_classes": "GaussianNB does not handle string data properly",
"check_estimators_partial_fit_n_features": "GaussianNB does not check n_features consistency in partial_fit",
"check_classifiers_train": "GaussianNB does not handle list inputs",
"check_classifiers_train(readonly_memmap=True)": "GaussianNB does not handle readonly memmap",
"check_classifiers_train(readonly_memmap=True,X_dtype=float32)": "GaussianNB does not handle readonly memmap with float32",
"check_classifiers_regression_target": "GaussianNB does not handle regression targets",
"check_supervised_y_no_nan": "GaussianNB does not check for NaN in y",
"check_supervised_y_2d": "GaussianNB does not handle 2D y",
"check_parameters_default_constructible": "GaussianNB parameters are mutated on init",
"check_fit_check_is_fitted": "GaussianNB passes check_is_fitted before being fit",
"check_fit1d": "GaussianNB does not raise ValueError for 1D input",
"check_fit2d_predict1d": "GaussianNB does not handle 1D prediction input gracefully",
"check_requires_y_none": "GaussianNB does not handle y=None",
},
}


Expand Down Expand Up @@ -729,7 +762,7 @@ def test_sklearn_compatible_estimator(estimator, check):
# estimators. As a result they are currently skipped.
if isinstance(
estimator,
(GaussianNB, ComplementNB, CategoricalNB, BernoulliNB, MultinomialNB),
(ComplementNB, CategoricalNB, BernoulliNB, MultinomialNB),
):
pytest.skip(
"Estimator leads to additional MemoryErrors in other estimators (gh-7100)"
Expand Down