Commit 32230d8

Accelerate linear model predict on C-ordered inputs (#7329)
This started out as a cleanup PR, but moved to a performance improvement after some benchmarking.

`LinearRegression`, `ElasticNet`, `Lasso`, and `Ridge` all share the same `predict` method. This calculates `X.dot(coef.T) + intercept`. Previously we used a function from `libcuml` to compute the single target case, and `cupy` to handle the multitarget case.

After some benchmarking, I no longer think using `libcuml` at all here is worth it. It's simpler to always take the `cupy` path, and `cupy` already handles dispatching to cublas appropriately to handle disparate layouts (C vs F).

For F-ordered inputs we see roughly the same performance as before. For C-ordered inputs, we see anything from mild speedups (150 us now, vs 200 us before) on small data, to up to 10x speedup on larger data (0.75 ms now vs 8.4 ms before). Presumably this is due to avoiding unnecessary copies to force a uniform F order as we did before.

Authors:
- Jim Crist-Harif (https://github.com/jcrist)

Approvers:
- Victor Lafargue (https://github.com/viclafargue)
- Simon Adorf (https://github.com/csadorf)

URL: #7329
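To make the layout point concrete, here is a minimal sketch of the `X.dot(coef.T) + intercept` computation on C- and F-ordered inputs. It uses NumPy as a CPU stand-in for `cupy` so that it runs without a GPU; the helper name `linear_predict` and the array sizes are assumptions for illustration and are not part of this commit. It does not reproduce the timings quoted above; it only shows that both layouts give the same result and that forcing F order on a C-ordered array makes a copy.

```python
import numpy as np

def linear_predict(X, coef, intercept):
    """Hypothetical helper: compute X @ coef.T + intercept without changing X's layout."""
    out = X @ coef.T
    out += intercept
    return out

rng = np.random.default_rng(0)
X_c = rng.standard_normal((1000, 8))  # C-ordered (row-major), NumPy's default
X_f = np.asfortranarray(X_c)          # F-ordered (column-major) version of the same values
coef = rng.standard_normal(8)         # single-target coefficients, shape (n_features,)
intercept = 0.5

pred_c = linear_predict(X_c, coef, intercept)
pred_f = linear_predict(X_f, coef, intercept)

# Both layouts produce the same predictions.
layouts_agree = np.allclose(pred_c, pred_f)

# Converting a C-ordered 2-D array to F order allocates new memory,
# which is the kind of copy the commit message suspects was costly.
forced_f_is_copy = not np.shares_memory(X_c, X_f)
```

With `cupy`, the same expression would run on the GPU; whether the copy is really the source of the old slowdown is the author's presumption, not something this sketch measures.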
1 parent df3b839 commit 32230d8

3 files changed

Lines changed: 60 additions & 153 deletions

python/cuml/cuml/linear_model/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@


 set(cython_sources "")
-add_module_gpu_default("base.pyx" ${linearregression_algo} ${elastic_net_algo} ${ridge_algo} ${linear_model_algo})
 add_module_gpu_default("linear_regression.pyx" ${linearregression_algo} ${linear_model_algo})
 add_module_gpu_default("ridge.pyx" ${ridge_algo} ${linear_model_algo})


Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+#
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import cuml.internals
+from cuml.common.doc_utils import generate_docstring
+from cuml.internals.array import CumlArray
+from cuml.internals.input_utils import input_to_cuml_array
+
+
+class LinearPredictMixin:
+    @generate_docstring(
+        return_values={
+            "name": "preds",
+            "type": "dense",
+            "description": "Predicted values",
+            "shape": "(n_samples, 1)",
+        }
+    )
+    @cuml.internals.api_base_return_array_skipall
+    def predict(self, X, *, convert_dtype=True) -> CumlArray:
+        """
+        Predicts `y` values for `X`.
+        """
+        if getattr(self, "coef_", None) is None:
+            raise ValueError(
+                "LinearModel.predict() cannot be called before fit(). "
+                "Please fit the model first."
+            )
+
+        X = input_to_cuml_array(
+            X,
+            check_dtype=self.coef_.dtype,
+            convert_to_dtype=(self.coef_.dtype if convert_dtype else None),
+            check_cols=self.n_features_in_,
+            order="K",
+        ).array
+        X_cp = X.to_output("cupy")
+
+        coef = self.coef_.to_output("cupy")
+
+        intercept = self.intercept_
+        if isinstance(intercept, CumlArray):
+            intercept = intercept.to_output("cupy")
+
+        out = X_cp @ coef.T
+        out += intercept
+
+        return CumlArray(out, index=X.index)
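
As a rough illustration of why one code path can serve both the single-target and multitarget cases, the following sketch mimics the mixin pattern with plain NumPy. The classes `TinyLinearPredictMixin` and `TinyModel` are hypothetical stand-ins, not cuML classes, and the coefficients are set by hand rather than fitted. The assumed shapes are `(n_features,)` for a single target and `(n_targets, n_features)` for multiple targets.

```python
import numpy as np

class TinyLinearPredictMixin:
    """Hypothetical NumPy-only analogue of a shared linear predict mixin."""

    def predict(self, X):
        if getattr(self, "coef_", None) is None:
            raise ValueError("predict() cannot be called before fit().")
        X = np.asarray(X, dtype=self.coef_.dtype)
        # One expression covers both cases: for a 1-D coef_, .T is a no-op.
        out = X @ self.coef_.T
        out += self.intercept_  # scalar or per-target array, broadcast over rows
        return out


class TinyModel(TinyLinearPredictMixin):
    """Hypothetical estimator whose parameters are assigned directly."""

    def __init__(self, coef, intercept):
        self.coef_ = np.asarray(coef, dtype=np.float64)
        self.intercept_ = intercept


X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Single target: coef_ has shape (2,), so the result has shape (3,).
single = TinyModel(coef=[1.0, -1.0], intercept=0.5).predict(X)

# Two targets: coef_ has shape (2, 2), so the result has shape (3, 2).
multi = TinyModel(
    coef=[[1.0, -1.0], [0.0, 2.0]],
    intercept=np.array([0.5, 1.0]),
).predict(X)
```

Here `single` is `[-0.5, -0.5, -0.5]` and `multi` is `[[-0.5, 5.0], [-0.5, 9.0], [-0.5, 13.0]]`; the same matrix product and broadcast add handle both shapes without a branch.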

python/cuml/cuml/linear_model/base.pyx

Lines changed: 0 additions & 152 deletions
This file was deleted.
