Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ class Base(TagsMixin,
self._input_type = None
self._input_mem_type = None
self.target_dtype = None
self.n_features_in_ = None

nvtx_benchmark = os.getenv('NVTX_BENCHMARK')
if nvtx_benchmark and nvtx_benchmark.lower() == 'true':
Expand Down Expand Up @@ -479,7 +478,12 @@ class Base(TagsMixin,
if isinstance(X, int):
self.n_features_in_ = X
else:
self.n_features_in_ = X.shape[1]
shape = X.shape
# dataframes can have only one dimension
if len(shape) == 1:
self.n_features_in_ = 1
else:
self.n_features_in_ = shape[1]

def _more_tags(self):
# 'preserves_dtype' tag's Scikit definition currently only applies to
Expand Down
6 changes: 3 additions & 3 deletions python/cuml/cuml/tests/test_cuml_descr_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_input(self):

# === Standard Functions ===
def fit(self, X, convert_dtype=True) -> "DummyTestEstimator":

self._set_base_attributes(output_type=X, n_features=X)
return self

def predict(self, X, convert_dtype=True) -> CumlArray:
Expand Down Expand Up @@ -228,11 +228,11 @@ def calc_n_features(shape):
# When cudf and shape[1] is used, a series is created which will
# remove the last shape
if input_type == "cudf" and shape[1] == 1:
return None
return 1

return shape[1]

return None
return 1

assert est._input_type == input_type
assert est.target_dtype is None
Expand Down
Loading