diff --git a/python/cuml/cuml/ensemble/randomforest_common.pyx b/python/cuml/cuml/ensemble/randomforest_common.pyx index 57ebe1237c..b959a97bde 100644 --- a/python/cuml/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/cuml/ensemble/randomforest_common.pyx @@ -192,6 +192,10 @@ class BaseRandomForestModel(UniversalBase): self.treelite_serialized_model = None self._cpu_model_class_lock = threading.RLock() + def __len__(self): + """Return the number of estimators in the ensemble.""" + return self.n_estimators + def _get_max_feat_val(self) -> float: if isinstance(self.max_features, int): return self.max_features/self.n_cols diff --git a/python/cuml/cuml/tests/test_random_forest.py b/python/cuml/cuml/tests/test_random_forest.py index 76ffe70224..80b5a6648f 100644 --- a/python/cuml/cuml/tests/test_random_forest.py +++ b/python/cuml/cuml/tests/test_random_forest.py @@ -1447,3 +1447,10 @@ def test_rf_predict_returns_int(): clf = cuml.ensemble.RandomForestClassifier().fit(X, y) pred = clf.predict(X) assert pred.dtype == np.int64 + + +def test_ensemble_estimator_length(): + X, y = make_classification() + clf = cuml.ensemble.RandomForestClassifier(n_estimators=3) + clf.fit(X, y) + assert len(clf) == 3