-
Notifications
You must be signed in to change notification settings - Fork 623
Multiple CPU interop fixes for serialization and cloning #6223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
3d0760e
ca66732
497d181
dce1539
0d797aa
0ef4895
a00af9d
4e70f4c
de3e234
f14a14b
5a02fd0
e0cd0d5
d206bb8
e70c3fb
3785c4e
fd17f09
e7a35a1
742404e
fa41369
4fb4982
b2c31b2
5a5bf59
8eeb5b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |
|
|
||
| import pytest | ||
| import numpy as np | ||
| import cupy as cp | ||
| from sklearn import clone | ||
|
betatim marked this conversation as resolved.
Outdated
|
||
| from sklearn.datasets import make_classification, make_regression, make_blobs | ||
| from sklearn.linear_model import ( | ||
| LinearRegression, | ||
|
|
@@ -172,6 +172,59 @@ def test_proxy_facade(): | |
|
|
||
| assert original_value == proxy_value | ||
|
|
||
|
|
||
| def test_proxy_clone(): | ||
| # Test that cloning a proxy estimator preserves parameters, even those we | ||
| # translate for the cuml class | ||
| pca = PCA(n_components=42, svd_solver="arpack") | ||
| pca_clone = clone(pca) | ||
|
|
||
| assert pca.get_params() == pca_clone.get_params() | ||
|
|
||
|
|
||
| def test_proxy_params(): | ||
| # Test that parameters match between constructor and get_params() | ||
| # Mix of default and non-default values | ||
| pca = PCA( | ||
| n_components=5, | ||
| copy=False, | ||
| # Pass in an argument and set it to its default value | ||
| whiten=False, | ||
| ) | ||
|
|
||
| params = pca.get_params() | ||
| assert params["n_components"] == 5 | ||
| assert params["copy"] is False | ||
| assert params["whiten"] is False | ||
| # A parameter we never touched, should be the default | ||
| assert params["tol"] == 0.0 | ||
|
|
||
| # Check that get_params doesn't return any unexpected parameters | ||
| expected_params = set( | ||
| [ | ||
| "n_components", | ||
| "copy", | ||
| "whiten", | ||
| "tol", | ||
| "svd_solver", | ||
| "n_oversamples", | ||
| "random_state", | ||
| "iterated_power", | ||
| "power_iteration_normalizer", | ||
| ] | ||
| ) | ||
| assert set(params.keys()) == expected_params | ||
|
|
||
|
|
||
| def test_roundtrip(): | ||
| import cuml | ||
| from sklearn import cluster | ||
|
betatim marked this conversation as resolved.
Outdated
|
||
|
|
||
| km = cluster.KMeans(n_clusters=13) | ||
| ckm = cuml.KMeans.from_sklearn(km) | ||
|
|
||
| assert ckm.n_clusters == 13 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lucky number 13 :D
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should merge this PR. It improves things and fixes several things. We can keep improving the |
||
|
|
||
|
|
||
| def test_defaults_args_only_methods(): | ||
| # Check that estimator methods that take no arguments work | ||
|
|
@@ -186,6 +239,7 @@ def test_defaults_args_only_methods(): | |
|
|
||
|
|
||
| def test_kernel_ridge(): | ||
| import cupy as cp | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why move this here? Maybe we should leave a comment for people from the future to explain why it can't be imported at the top of the file (or move it back if this was just for debugging) |
||
| rng = np.random.RandomState(42) | ||
|
|
||
| X = 5 * rng.rand(10000, 1) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.