Skip to content

Make get_param_names a class method on single GPU estimators to match Scikit-learn closer #6101

Merged
rapids-bot[bot] merged 11 commits intorapidsai:branch-24.12from
dantegd:2412-fix-classmethod
Nov 12, 2024
Merged

Make get_param_names a class method on single GPU estimators to match Scikit-learn closer #6101
rapids-bot[bot] merged 11 commits intorapidsai:branch-24.12from
dantegd:2412-fix-classmethod

Conversation

@dantegd
Copy link
Copy Markdown
Member

@dantegd dantegd commented Oct 7, 2024

Small difference between our estimators and Scikit-learn is that get_param_names are a classmethod in sklearn, and not in ours. This can make a few corner cases fail for using our estimators when Scikit-learn like estimators are expected. This PR fixes that.

Note: This will not include dask-based estimators for the time being since they depend on introspection at object creation time.

@dantegd dantegd requested a review from a team as a code owner October 7, 2024 22:43
@dantegd dantegd requested review from cjnolet and divyegala October 7, 2024 22:43
@github-actions github-actions Bot added the Cython / Python Cython or Python issue label Oct 7, 2024
Copy link
Copy Markdown
Member

@divyegala divyegala left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why were none of the Cython files changed?

@dantegd
Copy link
Copy Markdown
Member Author

dantegd commented Oct 8, 2024

lol because I wrote a script to do this and only checked .py files instead of .pyx files too

@dantegd dantegd changed the title Make get_param_names a class method to match Scikit-learn closer Make get_param_names a class method on single GPU estimators to match Scikit-learn closer Oct 8, 2024
@dantegd dantegd added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Oct 8, 2024
@betatim
Copy link
Copy Markdown
Member

betatim commented Oct 14, 2024

Do you want to make them private at the same time? On the one hand, then we'd be 100% the same. On the other hand, the fact that they are private in scikit-learn makes me wonder if this matters (as they aren't part of the public API)?

@dantegd
Copy link
Copy Markdown
Member Author

dantegd commented Oct 16, 2024

@betatim good point, I think matching as close as possible (i.e. entirely) is a good idea

### Implementing `_get_param_names()`

To support cloning, estimators need to implement the function `get_param_names()`. The returned value should be a list of strings of all estimator attributes that are necessary to duplicate the estimator. This method is used in `Base.get_params()` which will collect the collect the estimator param values from this list and pass this dictionary to a new estimator constructor. Therefore, all strings returned by `get_param_names()` should be arguments in `__init__()` otherwise an invalid argument exception will be raised. Most estimators implement `get_param_names()` similar to:
To support cloning, estimators need to implement the function `_get_param_names()`. The returned value should be a list of strings of all estimator attributes that are necessary to duplicate the estimator. This method is used in `Base.get_params()` which will collect the collect the estimator param values from this list and pass this dictionary to a new estimator constructor. Therefore, all strings returned by `_get_param_names()` should be arguments in `__init__()` otherwise an invalid argument exception will be raised. Most estimators implement `_get_param_names()` similar to:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
To support cloning, estimators need to implement the function `_get_param_names()`. The returned value should be a list of strings of all estimator attributes that are necessary to duplicate the estimator. This method is used in `Base.get_params()` which will collect the collect the estimator param values from this list and pass this dictionary to a new estimator constructor. Therefore, all strings returned by `_get_param_names()` should be arguments in `__init__()` otherwise an invalid argument exception will be raised. Most estimators implement `_get_param_names()` similar to:
To support cloning, estimators need to implement the function `_get_param_names()`. The returned value should be a list of strings of all estimator attributes that are necessary to duplicate the estimator. This method is used in `Base.get_params()` which will collect the estimator param values from this list and pass this dictionary to a new estimator constructor. Therefore, all strings returned by `_get_param_names()` should be arguments in `__init__()` otherwise an invalid argument exception will be raised. Most estimators implement `_get_param_names()` similar to:

@betatim
Copy link
Copy Markdown
Member

betatim commented Nov 11, 2024

It is a big rename diff :D

Looks good from a quick check. One thing I noticed: some are classmethods (the majority) but some aren't. Oversight? If on purpose it is maybe worth adding a comment to the ones that aren't to help people from the future understand why they are different.

@dantegd
Copy link
Copy Markdown
Member Author

dantegd commented Nov 11, 2024

@betatim the ones that haven't changed to be class methods are the dask-based estimators, currently they depend on some runtime behavior, I would suggest we do those on a follow up

@betatim
Copy link
Copy Markdown
Member

betatim commented Nov 12, 2024

Ok. Maybe they don't need a comment then.

Time to merge?

@dantegd
Copy link
Copy Markdown
Member Author

dantegd commented Nov 12, 2024

/merge

@rapids-bot rapids-bot Bot merged commit 8e195fb into rapidsai:branch-24.12 Nov 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Cython / Python Cython or Python issue improvement Improvement / enhancement to an existing function non-breaking Non-breaking change

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants