Skip to content

Commit 841b107

Browse files
authored
Merge branch 'branch-25.10' into java/device-matrix-cagra-index-tests
2 parents 00f10f6 + 12ebfa1 commit 841b107

13 files changed

Lines changed: 294 additions & 195 deletions

File tree

.github/workflows/build.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ jobs:
5151
fail-fast: false
5252
matrix:
5353
cuda_version:
54-
- '12.9.1'
55-
- '13.0.0'
54+
- &latest_cuda12 '12.9.1'
55+
- &latest_cuda13 '13.0.1'
5656
with:
5757
build_type: ${{ inputs.build_type || 'branch' }}
5858
branch: ${{ inputs.branch }}
@@ -72,8 +72,8 @@ jobs:
7272
fail-fast: false
7373
matrix:
7474
cuda_version:
75-
- '12.9.1'
76-
- '13.0.0'
75+
- *latest_cuda12
76+
- *latest_cuda13
7777
with:
7878
build_type: ${{ inputs.build_type || 'branch' }}
7979
branch: ${{ inputs.branch }}
@@ -93,8 +93,8 @@ jobs:
9393
fail-fast: false
9494
matrix:
9595
cuda_version:
96-
- '12.9.1'
97-
- '13.0.0'
96+
- *latest_cuda12
97+
- *latest_cuda13
9898
with:
9999
build_type: ${{ inputs.build_type || 'branch' }}
100100
branch: ${{ inputs.branch }}

.github/workflows/pr.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ jobs:
159159
fail-fast: false
160160
matrix:
161161
cuda_version:
162-
- '12.9.1'
163-
- '13.0.0'
162+
- &latest_cuda12 '12.9.1'
163+
- &latest_cuda13 '13.0.1'
164164
with:
165165
build_type: pull-request
166166
node_type: "gpu-l4-latest-1"
@@ -189,8 +189,8 @@ jobs:
189189
fail-fast: false
190190
matrix:
191191
cuda_version:
192-
- '12.9.1'
193-
- '13.0.0'
192+
- *latest_cuda12
193+
- *latest_cuda13
194194
with:
195195
build_type: pull-request
196196
node_type: "gpu-l4-latest-1"
@@ -207,8 +207,8 @@ jobs:
207207
fail-fast: false
208208
matrix:
209209
cuda_version:
210-
- '12.9.1'
211-
- '13.0.0'
210+
- *latest_cuda12
211+
- *latest_cuda13
212212
with:
213213
build_type: pull-request
214214
node_type: "gpu-l4-latest-1"

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
matrix:
6161
cuda_version:
6262
- '12.9.1'
63-
- '13.0.0'
63+
- '13.0.1'
6464
with:
6565
build_type: ${{ inputs.build_type }}
6666
branch: ${{ inputs.branch }}

cpp/src/cluster/detail/kmeans_mg.cuh

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,20 @@ void initKMeansPlusPlus(const raft::resources& handle,
175175
// X which will be used as the initial centroid for kmeans++
176176
// 1.3 - Communicate the initial centroid chosen by rank-r' to all other
177177
// ranks
178-
std::mt19937 gen(params.rng_state.seed);
179-
std::uniform_int_distribution<> dis(0, n_rank - 1);
180-
int rp = dis(gen);
178+
// Choose rp on rank 0 and broadcast to all ranks to guarantee agreement
179+
int rp = 0;
180+
if (my_rank == KMEANS_COMM_ROOT) {
181+
std::mt19937 gen(params.rng_state.seed);
182+
std::uniform_int_distribution<> dis(0, n_rank - 1);
183+
rp = dis(gen);
184+
}
185+
{
186+
rmm::device_scalar<int> rp_d(stream);
187+
raft::copy(rp_d.data(), &rp, 1, stream);
188+
comm.bcast<int>(rp_d.data(), 1, /*root=*/KMEANS_COMM_ROOT, stream);
189+
raft::copy(&rp, rp_d.data(), 1, stream);
190+
raft::resource::sync_stream(handle);
191+
}
181192

182193
// buffer to flag the sample that is chosen as initial centroids
183194
std::vector<std::uint8_t> h_isSampleCentroid(n_samples);

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSDeviceMatrix.java

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -30,67 +30,4 @@ default CuVSHostMatrix toHost() {
3030
toHost(hostMatrix);
3131
return hostMatrix;
3232
}
33-
34-
default CuVSDeviceMatrix toDevice(CuVSResources resources) {
35-
return new CuVSDeviceMatrixDelegate(this);
36-
}
37-
38-
class CuVSDeviceMatrixDelegate implements CuVSDeviceMatrix {
39-
40-
private final CuVSDeviceMatrix deviceMatrix;
41-
42-
private CuVSDeviceMatrixDelegate(CuVSDeviceMatrix deviceMatrix) {
43-
this.deviceMatrix = deviceMatrix;
44-
}
45-
46-
@Override
47-
public long size() {
48-
return deviceMatrix.size();
49-
}
50-
51-
@Override
52-
public long columns() {
53-
return deviceMatrix.columns();
54-
}
55-
56-
@Override
57-
public DataType dataType() {
58-
return deviceMatrix.dataType();
59-
}
60-
61-
@Override
62-
public RowView getRow(long row) {
63-
return deviceMatrix.getRow(row);
64-
}
65-
66-
@Override
67-
public void toArray(int[][] array) {
68-
deviceMatrix.toArray(array);
69-
}
70-
71-
@Override
72-
public void toArray(float[][] array) {
73-
deviceMatrix.toArray(array);
74-
}
75-
76-
@Override
77-
public void toArray(byte[][] array) {
78-
deviceMatrix.toArray(array);
79-
}
80-
81-
@Override
82-
public void toHost(CuVSHostMatrix hostMatrix) {
83-
deviceMatrix.toHost(hostMatrix);
84-
}
85-
86-
@Override
87-
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
88-
this.deviceMatrix.toDevice(deviceMatrix, cuVSResources);
89-
}
90-
91-
@Override
92-
public void close() {
93-
// Do nothing
94-
}
95-
}
9633
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSHostMatrix.java

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,76 +21,9 @@
2121
public interface CuVSHostMatrix extends CuVSMatrix {
2222
int get(int row, int col);
2323

24-
default CuVSHostMatrix toHost() {
25-
return new CuVSHostMatrixDelegate(this);
26-
}
27-
2824
default CuVSDeviceMatrix toDevice(CuVSResources resources) {
2925
var deviceMatrix = CuVSMatrix.deviceBuilder(resources, size(), columns(), dataType()).build();
3026
toDevice(deviceMatrix, resources);
3127
return deviceMatrix;
3228
}
33-
34-
class CuVSHostMatrixDelegate implements CuVSHostMatrix {
35-
private final CuVSHostMatrix hostMatrix;
36-
37-
public CuVSHostMatrixDelegate(CuVSHostMatrix cuVSHostMatrix) {
38-
this.hostMatrix = cuVSHostMatrix;
39-
}
40-
41-
@Override
42-
public int get(int row, int col) {
43-
return hostMatrix.get(row, col);
44-
}
45-
46-
@Override
47-
public long size() {
48-
return hostMatrix.size();
49-
}
50-
51-
@Override
52-
public long columns() {
53-
return hostMatrix.columns();
54-
}
55-
56-
@Override
57-
public DataType dataType() {
58-
return hostMatrix.dataType();
59-
}
60-
61-
@Override
62-
public RowView getRow(long row) {
63-
return hostMatrix.getRow(row);
64-
}
65-
66-
@Override
67-
public void toArray(int[][] array) {
68-
hostMatrix.toArray(array);
69-
}
70-
71-
@Override
72-
public void toArray(float[][] array) {
73-
hostMatrix.toArray(array);
74-
}
75-
76-
@Override
77-
public void toArray(byte[][] array) {
78-
hostMatrix.toArray(array);
79-
}
80-
81-
@Override
82-
public void toHost(CuVSHostMatrix hostMatrix) {
83-
this.hostMatrix.toHost(hostMatrix);
84-
}
85-
86-
@Override
87-
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
88-
hostMatrix.toDevice(deviceMatrix, cuVSResources);
89-
}
90-
91-
@Override
92-
public void close() {
93-
// Do nothing
94-
}
95-
}
9629
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ private BruteForceIndexImpl(
7878
Objects.requireNonNull(dataset);
7979
try (dataset) {
8080
this.resources = resources;
81-
assert dataset instanceof CuVSMatrixBaseImpl;
82-
this.bruteForceIndexReference = build((CuVSMatrixBaseImpl) dataset, bruteForceIndexParams);
81+
assert dataset instanceof CuVSMatrixInternal;
82+
this.bruteForceIndexReference = build((CuVSMatrixInternal) dataset, bruteForceIndexParams);
8383
}
8484
}
8585

@@ -124,7 +124,7 @@ public void close() {
124124
* index
125125
*/
126126
private IndexReference build(
127-
CuVSMatrixBaseImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
127+
CuVSMatrixInternal dataset, BruteForceIndexParams bruteForceIndexParams) {
128128
long rows = dataset.size();
129129
long cols = dataset.columns();
130130

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ private CagraIndexImpl(
7777
CagraIndexParams indexParameters, CuVSMatrix dataset, CuVSResources resources) {
7878
Objects.requireNonNull(dataset);
7979
this.resources = resources;
80-
assert dataset instanceof CuVSMatrixBaseImpl;
81-
this.cagraIndexReference = build(indexParameters, (CuVSMatrixBaseImpl) dataset);
80+
assert dataset instanceof CuVSMatrixInternal;
81+
this.cagraIndexReference = build(indexParameters, (CuVSMatrixInternal) dataset);
8282
}
8383

8484
/**
@@ -123,11 +123,11 @@ private CagraIndexImpl(
123123

124124
this.resources = resources;
125125

126-
assert graph instanceof CuVSMatrixBaseImpl;
127-
assert dataset instanceof CuVSMatrixBaseImpl;
126+
assert graph instanceof CuVSMatrixInternal;
127+
assert dataset instanceof CuVSMatrixInternal;
128128

129129
this.cagraIndexReference =
130-
fromGraph(metric, (CuVSMatrixBaseImpl) graph, (CuVSMatrixBaseImpl) dataset);
130+
fromGraph(metric, (CuVSMatrixInternal) graph, (CuVSMatrixInternal) dataset);
131131
}
132132

133133
private void checkNotDestroyed() {
@@ -160,7 +160,7 @@ public void close() {
160160
* @return an instance of {@link IndexReference} that holds the pointer to the
161161
* index
162162
*/
163-
private IndexReference build(CagraIndexParams indexParameters, CuVSMatrixBaseImpl dataset) {
163+
private IndexReference build(CagraIndexParams indexParameters, CuVSMatrixInternal dataset) {
164164
long rows = dataset.size();
165165

166166
try (var indexParams = segmentFromIndexParams(indexParameters);
@@ -409,8 +409,8 @@ public CuVSDeviceMatrix getGraph() {
409409

410410
private IndexReference fromGraph(
411411
CagraIndexParams.CuvsDistanceType metric,
412-
CuVSMatrixBaseImpl graph,
413-
CuVSMatrixBaseImpl dataset) {
412+
CuVSMatrixInternal graph,
413+
CuVSMatrixInternal dataset) {
414414
try (var localArena = Arena.ofConfined()) {
415415
var index = createCagraIndex();
416416
try (var resourcesAccess = resources.access()) {

0 commit comments

Comments
 (0)