/*
- * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
+ * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
+#include <thrust/iterator/counting_iterator.h>
#include <thrust/reverse.h>
+#include <thrust/transform.h>

#include <cuvs/distance/distance.hpp>
#include <cuvs/distance/grammian.hpp>
@@ -40,6 +42,50 @@ namespace SVM {

namespace { // unnamed namespace to avoid multiple definition error

+/**
+ * @brief Extract columns from a matrix for precomputed kernels
+ *
+ * Given a matrix src of shape (n_rows_src, n_cols_src), extract columns
+ * specified by col_indices and store in dst of shape (n_rows_src, n_cols_dst).
+ *
+ * @param [out] dst destination matrix, size [n_rows_src x n_cols_dst]
+ * @param [in] src source matrix, size [n_rows_src x n_cols_src]
+ * @param [in] n_rows_src number of rows in source matrix
+ * @param [in] col_indices column indices to extract, size [n_cols_dst]
+ * @param [in] n_cols_dst number of columns to extract
+ */
+template <typename math_t>
+CUML_KERNEL void extractColumnsKernel(
+  math_t* dst, const math_t* src, int n_rows_src, const int* col_indices, int n_cols_dst)
+{
+  int64_t tid = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * blockDim.x;
+  int64_t total = static_cast<int64_t>(n_rows_src) * n_cols_dst;
+  if (tid < total) {
+    int64_t row = tid % n_rows_src;
+    int64_t col = tid / n_rows_src;
+    int src_col = col_indices[col];
+    // Both source and destination are column-major:
+    // src[row, col] = src[row + col * n_rows_src]
+    // dst[row, col] = dst[row + col * n_rows_src] = dst[tid]
+    dst[tid] = src[row + static_cast<int64_t>(src_col) * n_rows_src];
+  }
+}
+
+template <typename math_t>
+void extractColumnsForPrecomputed(math_t* dst,
+                                  const math_t* src,
+                                  int n_rows_src,
+                                  const int* col_indices,
+                                  int n_cols_dst,
+                                  cudaStream_t stream)
+{
+  int total = n_rows_src * n_cols_dst;
+  int TPB = 256;
+  int n_blocks = raft::ceildiv(total, TPB);
+  extractColumnsKernel<math_t>
+    <<<n_blocks, TPB, 0, stream>>>(dst, src, n_rows_src, col_indices, n_cols_dst);
+}
+
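+// Illustrative example of the column extraction above (not part of the helper's API):
+// for a column-major 3 x 4 src and col_indices = {2, 0}, the kernel copies
+//   dst[row + 0 * 3] = src[row + 2 * 3]  (dst column 0 <- src column 2)
+//   dst[row + 1 * 3] = src[row + 0 * 3]  (dst column 1 <- src column 0)
+// for row = 0, 1, 2, so dst is the 3 x 2 matrix made of the selected columns.
+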
/**
 * @brief Re-raise working set indexes to SVR scope [0..2*n_rows)
 *
@@ -322,6 +368,8 @@ class KernelCache {
   * @param dense_extract_byte_limit sparse rows will be extracted as dense
   *        up to this limit to speed up kernel computation. Only valid
   *        for sparse input. (default 1GB)
+   * @param is_precomputed if true, the matrix is a precomputed kernel matrix
+   *        and no kernel computation is performed
   */
  KernelCache(const raft::handle_t& handle,
              MatrixViewType matrix,
@@ -333,7 +381,8 @@ class KernelCache {
              float cache_size = 200,
              SvmType svmType = C_SVC,
              size_t kernel_tile_byte_limit = 1 << 30,
-              size_t dense_extract_byte_limit = 1 << 30)
+              size_t dense_extract_byte_limit = 1 << 30,
+              bool is_precomputed = false)
    : batch_cache(n_rows, cache_size, handle.get_stream()),
      handle(handle),
      kernel(kernel),
@@ -343,6 +392,7 @@ class KernelCache {
      n_cols(n_cols),
      n_ws(n_ws),
      svmType(svmType),
+      is_precomputed(is_precomputed),
      kernel_tile(0, handle.get_stream()),
      matrix_l2(0, handle.get_stream()),
      matrix_l2_ws(0, handle.get_stream()),
@@ -353,7 +403,7 @@ class KernelCache {
      indptr_batched(0, handle.get_stream()),
      ws_cache_idx(n_ws * 2, handle.get_stream())
  {
-    ASSERT(kernel != nullptr, "Kernel pointer required for KernelCache!");
+    ASSERT(kernel != nullptr || is_precomputed, "Kernel pointer required for KernelCache!");
    stream = handle.get_stream();

    batching_enabled = false;
@@ -386,8 +436,8 @@ class KernelCache {
      x_ws_dense.resize(n_ws * static_cast<size_t>(n_cols), stream);
    }

-    // store matrix l2 norm for RBF kernels
-    if (kernel_type == cuvs::distance::kernels::KernelType::RBF) {
+    // store matrix l2 norm for RBF kernels (not needed for precomputed)
+    if (!is_precomputed && kernel_type == cuvs::distance::kernels::KernelType::RBF) {
      matrix_l2.resize(n_rows, stream);
      matrix_l2_ws.resize(n_ws, stream);
      ML::SVM::matrixRowNorm(handle, matrix, matrix_l2.data(), raft::linalg::NormType::L2Norm);
@@ -507,33 +557,41 @@ class KernelCache {
      ML::SVM::extractRows<math_t>(matrix, x_ws_dense.data(), ws_idx_mod.data(), n_ws, handle);
    }

-    // extract dot array for RBF
-    if (kernel_type == cuvs::distance::kernels::KernelType::RBF) {
-      selectValueSubset(matrix_l2_ws.data(), matrix_l2.data(), ws_idx_mod.data(), n_ws);
-    }
+    if (is_precomputed) {
+      // For precomputed kernels, x_ws_dense contains K[ws, :] (shape n_ws x n_cols)
+      // We need to extract columns ws to get K[ws, ws]
+      // Since n_cols == n_rows for precomputed, we extract columns using ws_idx_mod
+      extractColumnsForPrecomputed(
+        kernel_tile.data(), x_ws_dense.data(), n_ws, ws_idx_mod.data(), n_ws, stream);
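+      // kernel_tile now holds the n_ws x n_ws working-set tile K[ws, ws] in column-major order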
+    } else {
+      // extract dot array for RBF
+      if (kernel_type == cuvs::distance::kernels::KernelType::RBF) {
+        selectValueSubset(matrix_l2_ws.data(), matrix_l2.data(), ws_idx_mod.data(), n_ws);
+      }

-    // compute kernel
-    {
-      if (sparse_extract) {
-        auto ws_view = getViewWithFixedDimension(*x_ws_csr, n_ws, n_cols);
-        KernelOp(handle,
-                 kernel,
-                 ws_view,
-                 ws_view,
-                 kernel_tile.data(),
-                 matrix_l2_ws.data(),
-                 matrix_l2_ws.data());
-      } else {
-        KernelOp(handle,
-                 kernel,
-                 x_ws_dense.data(),
-                 n_ws,
-                 n_cols,
-                 x_ws_dense.data(),
-                 n_ws,
-                 kernel_tile.data(),
-                 matrix_l2_ws.data(),
-                 matrix_l2_ws.data());
+      // compute kernel
+      {
+        if (sparse_extract) {
+          auto ws_view = getViewWithFixedDimension(*x_ws_csr, n_ws, n_cols);
+          KernelOp(handle,
+                   kernel,
+                   ws_view,
+                   ws_view,
+                   kernel_tile.data(),
+                   matrix_l2_ws.data(),
+                   matrix_l2_ws.data());
+        } else {
+          KernelOp(handle,
+                   kernel,
+                   x_ws_dense.data(),
+                   n_ws,
+                   n_cols,
+                   x_ws_dense.data(),
+                   n_ws,
+                   kernel_tile.data(),
+                   matrix_l2_ws.data(),
+                   matrix_l2_ws.data());
+        }
      }
    }
    return kernel_tile.data();
@@ -641,24 +699,51 @@ class KernelCache {
    int* ws_idx_new = batch_descriptor.nz_da_idx + n_cached;
    math_t* tile_new = kernel_tile.data() + (size_t)n_cached * batch_size;

-    auto batch_matrix = getMatrixBatch(
-      matrix, batch_size, offset, host_indptr.data(), indptr_batched.data(), stream);
-
-    // compute kernel
-    math_t* norm_with_offset = matrix_l2.data() != nullptr ? matrix_l2.data() + offset : nullptr;
-    if (sparse_extract) {
-      auto ws_view = getViewWithFixedDimension(*x_ws_csr, n_uncached, n_cols);
-      KernelOp(
-        handle, kernel, batch_matrix, ws_view, tile_new, norm_with_offset, matrix_l2_ws.data());
+    if (is_precomputed) {
+      // For precomputed kernels, extract K[offset:offset+batch_size, ws_idx_new]
+      // Input matrix is column-major: K[row, col] = K[row + col * n_rows]
+      // Output tile_new is column-major: tile_new[i, j] = tile_new[i + j * batch_size]
+      if constexpr (isDenseType<MatrixViewType>()) {
+        const math_t* matrix_data = getDenseData(matrix);
+        thrust::counting_iterator<int> iter(0);
+        int n_elems = batch_size * n_uncached;
+        int matrix_rows = n_rows; // Copy member to local for lambda capture
+        thrust::transform(
+          thrust::cuda::par.on(stream),
+          iter,
+          iter + n_elems,
+          tile_new,
+          [matrix_data, ws_idx_new, matrix_rows, offset, batch_size] __device__(int tid) {
+            // Column-major output: tile_new[row, col] = tile_new[row + col * batch_size]
+            int row = tid % batch_size;
+            int col = tid / batch_size;
+            int src_row = offset + row;
+            int src_col = ws_idx_new[col];
+            // Column-major input: K[row, col] = K[row + col * matrix_rows]
+            // 64-bit index to avoid overflow for large precomputed matrices
+            return matrix_data[src_row + static_cast<int64_t>(src_col) * matrix_rows];
+          });
+      }
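+      // Note: only dense input is handled here; a precomputed kernel matrix is
+      // expected to be a dense, column-major n_rows x n_rows matrix.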
    } else {
-      KernelOp(handle,
-               kernel,
-               batch_matrix,
-               x_ws_dense.data(),
-               n_uncached,
-               tile_new,
-               norm_with_offset,
-               matrix_l2_ws.data());
+      auto batch_matrix = getMatrixBatch(
+        matrix, batch_size, offset, host_indptr.data(), indptr_batched.data(), stream);
+
+      // compute kernel
+      math_t* norm_with_offset =
+        matrix_l2.data() != nullptr ? matrix_l2.data() + offset : nullptr;
+      if (sparse_extract) {
+        auto ws_view = getViewWithFixedDimension(*x_ws_csr, n_uncached, n_cols);
+        KernelOp(
+          handle, kernel, batch_matrix, ws_view, tile_new, norm_with_offset, matrix_l2_ws.data());
+      } else {
+        KernelOp(handle,
+                 kernel,
+                 batch_matrix,
+                 x_ws_dense.data(),
+                 n_uncached,
+                 tile_new,
+                 norm_with_offset,
+                 matrix_l2_ws.data());
+      }
    }

    RAFT_CUDA_TRY(cudaPeekAtLastError());
@@ -757,6 +842,7 @@ class KernelCache {

  cuvs::distance::kernels::GramMatrixBase<math_t>* kernel;
  cuvs::distance::kernels::KernelType kernel_type;
+  bool is_precomputed; //!< if true, the matrix is a precomputed kernel matrix

  int n_rows; //!< number of rows in x
  int n_cols; //!< number of columns in x