Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions python/cuml/cuml/internals/api_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,27 +55,32 @@ def _find_arg(sig, arg_name, default_position):

# Check for default name in input args
if arg_name in sig.parameters:
return arg_name, params.index(arg_name)
param = sig.parameters[arg_name]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't remember off the top of my head, but the refactor of decorators you are doing removes this auto inspection of signature, no? Just wondering since it'd be nice to not have magic behavior of this.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain a bit more? The code here looks more or less the same, so maybe you mean something that effects something somewhere else because of this change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dantegd I am a big confused by your question as well. Yes, the infrastructure revision will eliminate much of the magic and guess work, but here I am just inspecting the signature similar like before, but account for a provided default argument similar to how @betatim structured this in the earlier PR.

return arg_name, params.index(arg_name), param.default
# Otherwise use argument in list by position
elif arg_name is ...:
index = int(_has_self(sig)) + default_position
return params[index], index
param = params[index]
return param, index, sig.parameters[param].default
else:
raise ValueError(f"Unable to find parameter '{arg_name}'.")


def _get_value(args, kwargs, name, index):
def _get_value(args, kwargs, name, index, default_value):
"""Determine value for a given set of args, kwargs, name and index."""
try:
return kwargs[name]
except KeyError:
try:
return args[index]
except IndexError:
raise IndexError(
f"Specified arg idx: {index}, and argument name: {name}, "
"were not found in args or kwargs."
)
if default_value is not inspect._empty:
return default_value
else:
raise IndexError(
f"Specified arg idx: {index}, and argument name: {name}, "
"were not found in args or kwargs."
)


def _make_decorator_function(
Expand Down Expand Up @@ -166,7 +171,7 @@ def wrapper(*args, **kwargs):
if self_val is None:
assert input_val is not None
out_type = iu.determine_array_type(input_val)
elif input_val is None:
elif input_val is None or input_val is inspect._empty:
out_type = self_val.output_type
if out_type == "input":
out_type = self_val._input_type
Expand Down
6 changes: 5 additions & 1 deletion python/cuml/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,7 @@ import math
import cuml.internals
from cuml.internals.base import UniversalBase
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals import api_base_return_generic
from cuml.internals.array import CumlArray
from cuml.internals.array_sparse import SparseCumlArray
from cuml.common.doc_utils import generate_docstring
Expand Down Expand Up @@ -691,6 +692,9 @@ class NearestNeighbors(UniversalBase,
if out_type in {'cupy', 'numpy', 'numba'}:
I_ndarr = I_ndarr[:, 1:]
D_ndarr = D_ndarr[:, 1:]
elif out_type == "cuml":
I_ndarr = CumlArray.from_input(I_ndarr[:, 1:], force_contiguous=True)
D_ndarr = CumlArray.from_input(D_ndarr[:, 1:], force_contiguous=True)
else:
I_ndarr.drop(I_ndarr.columns[0], axis=1)
D_ndarr.drop(D_ndarr.columns[0], axis=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ def test_proxy_facade():
assert original_value == proxy_value


def test_defaults_args_only_methods():
# Check that estimator methods that take no arguments work
# These are slightly weird because basically everything else takes
# a X as input.
X = np.random.rand(1000, 3)
y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000)

nn = NearestNeighbors(metric="chebyshev", n_neighbors=3)
nn.fit(X[:, 0].reshape((-1, 1)), y)
nn.kneighbors()


def test_kernel_ridge():
rng = np.random.RandomState(42)

Expand Down