Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions java/cuvs-java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
<maven.compiler.source>22</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<version>${project.version}</version>
</properties>

<distributionManagement>
Expand Down Expand Up @@ -155,6 +156,22 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>properties-maven-plugin</artifactId>
<version>1.2.1</version>
<executions>
<execution>
<phase>generate-resources</phase>
<goals>
<goal>write-project-properties</goal>
</goals>
<configuration>
<outputFile>${project.build.outputDirectory}/version.properties</outputFile>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.List;
import java.util.ServiceLoader;

/**
Expand All @@ -43,25 +45,34 @@ private static CuVSProvider loadProvider() {
}

static CuVSProvider builtinProvider() {
if (Runtime.version().feature() > 21 && isLinuxAmd64()) {
var supportedJavaRuntime = Runtime.version().feature() > 21;
var supportedOs = System.getProperty("os.name").startsWith("Linux");
var supportedArchitecture = System.getProperty("os.arch").equals("amd64");
if (supportedJavaRuntime && supportedOs && supportedArchitecture) {
try {
var cls = Class.forName("com.nvidia.cuvs.spi.JDKProvider");
var ctr = MethodHandles.lookup().findConstructor(cls, MethodType.methodType(void.class));
var ctr =
MethodHandles.lookup()
.findStatic(cls, "create", MethodType.methodType(CuVSProvider.class));
return (CuVSProvider) ctr.invoke();
} catch (ProviderInitializationException e) {
return new UnsupportedProvider("cannot create JDKProvider: " + e.getMessage());
} catch (Throwable e) {
throw new AssertionError(e);
}
}
return new UnsupportedProvider();
}
List<String> unsupportedReasons = new ArrayList<>();
if (!supportedJavaRuntime) {
unsupportedReasons.add("cuvs-java requires Java Runtime version 22 or greater");
}
if (!supportedOs) {
unsupportedReasons.add("cuvs-java supports only Linux");
}
if (!supportedArchitecture) {
unsupportedReasons.add("cuvs-java supports only x86");
}

/**
* Returns true iff the architecture is x64 (amd64) and the OS Linux
* (the * OS we currently support for the native lib).
*/
static boolean isLinuxAmd64() {
String name = System.getProperty("os.name");
return (name.startsWith("Linux")) && System.getProperty("os.arch").equals("amd64");
return new UnsupportedProvider(String.join("; ", unsupportedReasons));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.cuvs.spi;

class ProviderInitializationException extends Exception {
ProviderInitializationException(String message, Throwable cause) {
super(message, cause);
}

public ProviderInitializationException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,51 +24,57 @@
*/
final class UnsupportedProvider implements CuVSProvider {

private final String reasons;

public UnsupportedProvider(String reasons) {
this.reasons = reasons;
}

@Override
public CuVSResources newCuVSResources(Path tempDirectory) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public BruteForceIndex.Builder newBruteForceIndexBuilder(CuVSResources cuVSResources) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public CagraIndex.Builder newCagraIndexBuilder(CuVSResources cuVSResources) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public HnswIndex.Builder newHnswIndexBuilder(CuVSResources cuVSResources) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public TieredIndex.Builder newTieredIndexBuilder(CuVSResources cuVSResources) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public CagraIndex mergeCagraIndexes(CagraIndex[] indexes) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(
long size, long dimensions, CuVSMatrix.DataType dataType) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
CuVSResources cuVSResources, long size, long dimensions, CuVSMatrix.DataType dataType) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public GPUInfoProvider gpuInfoProvider() {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
Expand All @@ -79,26 +85,26 @@ public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public MethodHandle newNativeMatrixBuilder() {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public CuVSMatrix newMatrixFromArray(float[][] vectors) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public CuVSMatrix newMatrixFromArray(int[][] vectors) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}

@Override
public CuVSMatrix newMatrixFromArray(byte[][] vectors) {
throw new UnsupportedOperationException();
throw new UnsupportedOperationException(reasons);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.cuvs.internal.common;

import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;

public class NativeLibraryUtils {

private NativeLibraryUtils() {}

private static final SymbolLookup LOOKUP =
SymbolLookup.libraryLookup(System.mapLibraryName("jvm"), Arena.ofAuto())
.or(SymbolLookup.loaderLookup())
.or(Linker.nativeLinker().defaultLookup());

// void * JVM_LoadLibrary(const char *name, jboolean throwException);
public static MethodHandle JVM_LoadLibrary$mh =
Linker.nativeLinker()
.downcallHandle(
LOOKUP.find("JVM_LoadLibrary").orElseThrow(),
FunctionDescriptor.of(
ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_BOOLEAN));
// void JVM_UnloadLibrary(void * handle);
public static MethodHandle JVM_UnloadLibrary$mh =
Linker.nativeLinker()
.downcallHandle(
LOOKUP.find("JVM_UnloadLibrary").orElseThrow(),
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS));
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
*/
package com.nvidia.cuvs.spi;

import static com.nvidia.cuvs.internal.common.NativeLibraryUtils.JVM_LoadLibrary$mh;
import static com.nvidia.cuvs.internal.common.NativeLibraryUtils.JVM_UnloadLibrary$mh;
import static com.nvidia.cuvs.internal.common.Util.*;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsVersionGet;
import static com.nvidia.cuvs.internal.panama.headers_h.uint16_t;

import com.nvidia.cuvs.*;
import com.nvidia.cuvs.internal.*;
import com.nvidia.cuvs.internal.common.Util;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
Expand All @@ -28,11 +34,77 @@
import java.nio.file.Path;
import java.util.Locale;
import java.util.Objects;
import java.util.Properties;

final class JDKProvider implements CuVSProvider {

private static final MethodHandle createNativeDataset$mh = createNativeDatasetBuilder();

static CuVSProvider create() throws Throwable {
var mavenVersion = readMavenVersionFromPropertiesOrNull();

try (var localArena = Arena.ofConfined()) {
var majorPtr = localArena.allocate(uint16_t);
var minorPtr = localArena.allocate(uint16_t);
var patchPtr = localArena.allocate(uint16_t);
checkCuVSError(cuvsVersionGet(majorPtr, minorPtr, patchPtr), "cuvsVersionGet");
var major = majorPtr.get(uint16_t, 0);
var minor = minorPtr.get(uint16_t, 0);
var patch = patchPtr.get(uint16_t, 0);

var cuvsVersionString = String.format(Locale.ROOT, "%02d.%02d.%d", major, minor, patch);
if (mavenVersion != null && !cuvsVersionString.equals(mavenVersion)) {
throw new ProviderInitializationException(
String.format(
Locale.ROOT,
"libcuvs_c version mismatch: expected [%s], found [%s]",
mavenVersion,
cuvsVersionString));
}
} catch (ExceptionInInitializerError e) {
if (e.getCause() instanceof IllegalArgumentException) {
// Try to find if we failed to load libcuvs and why
// jextract loads the dynamic library with SymbolLookup.libraryLookup; this uses
// RawNativeLibraries::load
// https://github.com/openjdk/jdk/blob/master/src/java.base/share/native/libjava/RawNativeLibraries.c#L58
// RawNativeLibraries::load in turn calls JVM_LoadLibrary. Unfortunately, it calls it with a
// JNI_FALSE parameter for throwException, which means that the detailed error messages are
// not surfaced.
// Here we try and load it again, with throwException true, so we can see what's broken
try (var localArena = Arena.ofConfined()) {
var name = localArena.allocateFrom(System.mapLibraryName("cuvs_c"));
Object lib = JVM_LoadLibrary$mh.invoke(name, true);
if (lib != null) {
// It wasn't a problem with library loading, so undo what we did
JVM_UnloadLibrary$mh.invoke(lib);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ldematte nice sleuthing. Let's try to use System.localLibrary earlier so that we can get similar improved error messages.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea! Less reflection :)
A first test seems to indicate that System.loadLibrary gives a different and less precise message. I'll dig into that.
In any case, trying to load the library before seems best.

}
} catch (Throwable ex) {
if (ex instanceof UnsatisfiedLinkError ulex) {
throw new ProviderInitializationException(ulex.getMessage(), ulex);
}
throw new ProviderInitializationException("error while loading libcuvs", ex);
}
} else {
throw e.getCause() != null ? e.getCause() : e;
}
}
return new JDKProvider();
}

/**
* Read cuvs-java version from Maven generated properties, or null if these are not available
*/
private static String readMavenVersionFromPropertiesOrNull() {
var properties = new Properties();

try (var is = JDKProvider.class.getClassLoader().getResourceAsStream("version.properties")) {
properties.load(is);
return properties.getProperty("version");
} catch (IOException e) {
return null;
}
}

static MethodHandle createNativeDatasetBuilder() {
try {
var lookup = MethodHandles.lookup();
Expand Down