diff --git a/cpp/src/solver/cd.cuh b/cpp/src/solver/cd.cuh index feb0fae2c4..51f534318c 100644 --- a/cpp/src/solver/cd.cuh +++ b/cpp/src/solver/cd.cuh @@ -211,8 +211,9 @@ void cdFit(const raft::handle_t& handle, math_t scalar = math_t(n_rows) + l2_alpha; raft::matrix::setValue(squared.data(), squared.data(), scalar, n_cols, stream); } else { + /* (n_cols * n_rows) may overflow, upcast for indexing */ raft::linalg::colNorm( - squared.data(), input, n_cols, n_rows, stream); + squared.data(), input, int64_t(n_cols), int64_t(n_rows), stream); raft::linalg::addScalar(squared.data(), squared.data(), l2_alpha, n_cols, stream); } @@ -233,10 +234,10 @@ void cdFit(const raft::handle_t& handle, for (int j = 0; j < n_cols; j++) { raft::common::nvtx::range iter_scope("ML::Solver::cdFit::col-%d", j); - int ci = ri[j]; + int64_t ci = ri[j]; math_t* coef_loc = coef + ci; math_t* squared_loc = squared.data() + ci; - math_t* input_col_loc = input + (ci * n_rows); + math_t* input_col_loc = input + (ci * int64_t(n_rows)); // remember current coef raft::copy(&(convStateLoc->coef), coef_loc, 1, stream);