Support dual_gap_ on ElasticNet & Lasso#6714

Closed
jcrist wants to merge 3 commits into rapidsai:branch-25.06 from jcrist:dual-gap

Conversation

@jcrist
Member

@jcrist jcrist commented May 13, 2025

This is to improve sklearn compatibility. Both ElasticNet and Lasso in sklearn compute the dual gap as part of the coordinate descent solver, and store it as dual_gap_ on the fit model.

For now, we compute the dual_gap_ from the final fit state and store it the same way. Two notes:

  • Our CD solver doesn't use the dual gap as part of its stopping criteria, while sklearn's does. This means that in practice we stop fitting earlier in the optimization, resulting in a larger dual gap for the same tolerance. We could (and maybe should) update our solver to better match sklearn's.
  • Our CD solver computes much of what's needed for calculating the dual gap as part of the fit. Doing this with cupy after the fit repeats work. However, for problems of a meaningful size, computing the dual gap afterwards is cheap compared to the cost of the fit. I think doing the easier thing and leaving things in Python makes sense for now.

Fixes #6467.
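
For reference, the quantity in question is the duality gap sklearn's coordinate descent reports, in the scaled formulation where the objective is 0.5·||y − Xw||² + α·||w||₁ + 0.5·β·||w||². A plain-numpy sketch of that formula (illustrative only; this is not the actual code added by this PR, and the function name is made up):

```python
import numpy as np

def enet_dual_gap(X, y, w, alpha, beta):
    """Duality gap for 0.5*||y - Xw||^2 + alpha*||w||_1 + 0.5*beta*||w||^2.

    Mirrors the gap check in sklearn's coordinate descent: scale the
    residual so the dual point is feasible, then evaluate primal - dual.
    """
    R = y - X @ w                      # residual
    XtA = X.T @ R - beta * w           # gradient-like term
    dual_norm_XtA = np.max(np.abs(XtA))
    R_norm2 = R @ R
    w_norm2 = w @ w

    if dual_norm_XtA > alpha:
        # Shrink the residual so the dual variable is feasible
        const = alpha / dual_norm_XtA
        A_norm2 = R_norm2 * const**2
        gap = 0.5 * (R_norm2 + A_norm2)
    else:
        const = 1.0
        gap = R_norm2

    gap += (alpha * np.abs(w).sum()
            - const * (R @ y)
            + 0.5 * beta * (1 + const**2) * w_norm2)
    return gap
```

The gap is zero at the optimum (e.g. w = 0 when alpha ≥ max|Xᵀy| and beta = 0) and positive for suboptimal coefficients, which is what makes it usable as a stopping criterion.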

@jcrist jcrist self-assigned this May 13, 2025
@jcrist jcrist requested a review from a team as a code owner May 13, 2025 19:09
@jcrist jcrist added the Cython / Python Cython or Python issue label May 13, 2025
@jcrist jcrist requested review from bdice and cjnolet May 13, 2025 19:09
@jcrist jcrist added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change cuml-accel Issues related to cuml.accel labels May 13, 2025
Contributor

@csadorf csadorf left a comment


LGTM overall, but it seems like CI is picking up some real regressions.

Also, can you create an issue for the necessary follow-up work to address the extra work for dual gap computation?

jcrist added 2 commits May 15, 2025 07:40
@jcrist
Member Author

jcrist commented May 15, 2025

I believe I've fixed the regressions.

I was trying to benchmark to see the effect this would have, but ran into errors on larger problems (see #6736).

A few general observations:

  • ElasticNet.fit is fast; sklearn runs problems of the size I'm able to run here much more slowly.
  • Computing the dual_gap_ after fitting scales with the size of the input data, but not with the number of iterations taken or the tolerance.
  • There's a slowdown, but is it important when we're going from 0.3 s to 0.5 s for a single fit call? Hard to say. I do suspect the relative overhead will decrease as problem size increases, but currently #6736 ("ElasticNet.fit causes cublas error on large enough data") prevents running ElasticNet on larger problems, so we can't benchmark that (and it also means no current users could be running problems of that size).

Anyway, here's the benchmark:

from cuml.datasets import make_regression
from cuml.linear_model import ElasticNet
from time import perf_counter

def bench(N, M, n_runs=5):
    model = ElasticNet()
    X, y = make_regression(n_samples=N, n_features=M, random_state=42)

    # Average wall-clock time over n_runs fit calls
    times = []
    for _ in range(n_runs):
        start = perf_counter()
        model.fit(X, y)
        stop = perf_counter()
        times.append(stop - start)

    duration = sum(times) / n_runs
    print(f"shape = ({N}, {M}): {duration:.2} s")


for N in [100_000, 1_000_000]:
    for M in [500, 2000]:
        bench(N, M)

And the results:

# Before
shape = (100000, 500): 0.055 s
shape = (100000, 2000): 0.15 s
shape = (1000000, 500): 0.065 s
shape = (1000000, 2000): 0.32 s

# After
shape = (100000, 500): 0.08 s
shape = (100000, 2000): 0.17 s
shape = (1000000, 500): 0.12 s
shape = (1000000, 2000): 0.53 s
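
The before/after difference is consistent with the post-fit gap computation being a few dense passes over the data. A rough CPU-side sketch that times just that extra work (numpy standing in for cupy; the function name and sizes are illustrative):

```python
from time import perf_counter

import numpy as np

rng = np.random.default_rng(42)

def time_gap_terms(N, M):
    """Time the dense ops that dominate a post-fit dual-gap computation."""
    X = rng.standard_normal((N, M)).astype(np.float32)
    y = rng.standard_normal(N).astype(np.float32)
    w = rng.standard_normal(M).astype(np.float32)

    start = perf_counter()
    R = y - X @ w            # residual: one pass over X
    XtA = X.T @ R            # correlation with residual: a second pass
    _ = np.max(np.abs(XtA))  # remaining terms are O(N) or O(M)
    _ = R @ R
    _ = np.abs(w).sum()
    return perf_counter() - start

# Cost depends only on (N, M), not on how many CD iterations the fit took.
print(f"{time_gap_terms(100_000, 500):.3f} s")
```

This matches the observation above that the overhead scales with the input size but not with the iteration count or tolerance.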

Right now I feel like merging this still makes sense, but we might want to prioritize a followup later to fix the bug and move computation into the solver itself.

@jcrist
Member Author

jcrist commented May 15, 2025

Dask tests are failing due to #6737. It also looks like there are 2 xpassing tests in the sklearn test suite. If we think this is still worth merging, I'll fix and push again.

Contributor

@csadorf csadorf left a comment


Looking at the benchmark results, I'll have to take another look at this, since we might be introducing a significant performance regression.

@jcrist
Member Author

jcrist commented May 19, 2025

Closing this since we'll want to handle this in the solver itself. I've opened #6759 to track that.

@jcrist jcrist closed this May 19, 2025