
Fix overflow in coordinate descent #7399

Merged
rapids-bot[bot] merged 1 commit into rapidsai:main from jcrist:fix-overflow-in-cd
Oct 29, 2025

Conversation

@jcrist
Member

@jcrist jcrist commented Oct 28, 2025

Previously our coordinate descent solver would fail on problems where n_cols * n_rows > INT_MAX due to an int overflow.

There were two locations where this was happening:

  • Calculating the offset into the input matrix
  • Within the computation of the L2Norm

The former is a quick local fix. The latter I also fixed locally by switching from an int to an int64_t in the template call. However, I'm not sure if that's the best fix, or if it'd be better to handle this upstream within the template itself to avoid overflow of the index types. This was easy to do, so I did it here for now.

I've checked, and with this change we can solve very large coordinate descent problems, with the dimension limit now being INT_MAX in both rows and columns. Going larger than that would require the 64-bit cuBLAS API, but I have no need for that now.

Additionally, on the Python side, if a user tries to pass a larger value they'll get a nicer Python-side error rather than a cuBLAS error code (and a potentially corrupted handle).

Fixes #6736.

@jcrist jcrist self-assigned this Oct 28, 2025
@jcrist jcrist requested a review from a team as a code owner October 28, 2025 15:44
@jcrist jcrist requested a review from viclafargue October 28, 2025 15:44
@jcrist jcrist added the bug (Something isn't working) and non-breaking (Non-breaking change) labels Oct 28, 2025
@jcrist
Member Author

jcrist commented Oct 28, 2025

I didn't add a test since that would require a non-negligible amount of GPU memory (~10 GiB for the inputs, plus whatever the solver overhead is).

The following quick check now runs fine, though it used to result in cuBLAS errors:

Script:

import cupy as cp
from cuml.datasets import make_regression
from cuml.linear_model import ElasticNet

N = 1_000_000
M = 2200  # N * M exceeds max int32
X, y = make_regression(n_samples=N, n_features=M, random_state=42)
X = cp.asfortranarray(X)
weights = cp.random.uniform(0, 1, size=y.shape, dtype=y.dtype)

for sample_weight in [None, weights]:
    model = ElasticNet().fit(X, y, sample_weight=sample_weight)
    print(model.score(X, y))

Output:

0.8838382363319397
0.8566173315048218

Contributor

@viclafargue viclafargue left a comment


Thanks for spotting this! I would strongly suggest updating the function signature to make sure that n_rows is int64_t, and also making ci an int64_t too. This would prevent issues if future code updates use these as multiplicative operands. More importantly, the RAFT operations are templated and may be vulnerable to integer overflows (especially the ones involving both rows and columns, see here). Using int64_t would solve this too. Additionally, there might be similar patterns in the multi-GPU version of CD (see here or here).

@jcrist
Member Author

jcrist commented Oct 28, 2025

I would strongly suggest updating the function signature to make sure that n_rows is int64_t, and also making ci an int64_t too.

I'm not sure I follow. Even with this PR, n_rows and n_cols can only be at most INT_MAX due to constraints in cuBLAS and other operations. I don't think we should use a larger integer type in these method signatures than we can actually support. Internally, of course, it's fine.

More importantly the RAFT operations are templated, and may possibly be vulnerable to integer overflows (especially the ones involving both rows and columns, see here). Using int64_t would solve this too.

Yes, that's the problem in colNorm that I fixed here. I didn't find a problem in matrixVectorBinaryMult that you linked, but of course we could also use int64_t there too if that'd make you more comfortable.

Additionally there might be similar patterns in the multi-GPU version of CD (see here or here).

I don't want to touch the multi-gpu versions here, please keep this PR limited to just the single GPU code.


Are you suggesting we do something like

int64_t n_rows64 = n_rows;
int64_t n_cols64 = n_cols;

then use the 64-bit versions everywhere within fitCD (along with updating the type of ci)? If so, I'm happy to make that change.

@viclafargue
Contributor

viclafargue commented Oct 28, 2025

I'm not sure I follow. Even with this PR, n_rows and n_cols can only be at most INT_MAX due to constraints in cuBLAS and other operations. I don't think we should use a larger integer type in these method signatures than we can actually support. Internally, of course, it's fine.

Yes, we won't support a very large number of rows, but we want to make sure the issue you're fixing here won't reappear if we multiply the number of rows by something else.

Are you suggesting we do something like
int64_t n_rows64 = n_rows;
int64_t n_cols64 = n_cols;

Could work yes.

I didn't find a problem in matrixVectorBinaryMult that you linked, but of course we could also use int64_t there too if that'd make you more comfortable.

If there isn't any issue with matrixVectorBinaryMult, then it's fine. But having the inputs to RAFT operations as int64_t would ensure that we never have multiplications fail internally within RAFT.

@jcrist jcrist force-pushed the fix-overflow-in-cd branch from 1866fc9 to 5678dd6 Compare October 28, 2025 18:06
@jcrist
Member Author

jcrist commented Oct 28, 2025

I updated ci to int64_t, but left everything else the same.

  • Some operations only support int; we'd have to downcast in those locations to avoid a compile-time error, which didn't necessarily seem better than what I have here.
  • matrixVectorBinaryMult actually does the wrong thing if provided int64_t types (I haven't looked into why), while it continues to give the correct answer when passed int.

Feels to me like the solution I have here is the best option for now. If you believe otherwise, I'd love to hear some specific suggestions for improvements.

@jcrist
Member Author

jcrist commented Oct 29, 2025

/merge

@rapids-bot rapids-bot Bot merged commit c7cd9ca into rapidsai:main Oct 29, 2025
105 of 106 checks passed
@jcrist jcrist deleted the fix-overflow-in-cd branch October 29, 2025 12:23

Labels

bug (Something isn't working), CUDA/C++, non-breaking (Non-breaking change)

Development

Successfully merging this pull request may close these issues.

ElasticNet.fit causes cublas error on large enough data

3 participants