-
Notifications
You must be signed in to change notification settings - Fork 184
Go API - [WIP] #212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Go API - [WIP] #212
Changes from 49 commits
7961798
764dc1e
b0fd0b1
b47cf5b
1ec0c8c
d5c0dc3
cc6319e
01af7c9
ed9bf47
4e79d8e
d278abc
b91721d
6b90861
773fd94
7cbc1f9
46ec2f7
9dcffbd
f82216c
486d40a
6f5c5a6
bfdf3be
de3cea0
57e8dc0
ab173dc
2d5fb95
505d8dc
c3360ee
e261c8a
a4890ed
f2bac2d
a684b01
7d45e24
a7084c2
3af1b73
44d9e58
9447f63
04b6532
a082f67
a84e764
853d538
4021229
6dd2044
3e33692
b1a0476
dddb165
e05782a
a315632
3c97864
bd2dd76
bd76cf8
1e5a756
927f2be
f7fac35
d814e37
2708bc0
e634821
8c9105a
4a80cf7
f5b8e72
7b93510
f738da4
51b5747
ff36031
3a45833
78e9b1a
3287f7c
e2b6104
a656870
b70d6e2
ef2187c
23e67f7
aa38afc
35bc1f1
dba8cfc
5d97c8e
bf50e3f
a93d286
20f41b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,39 @@ | ||||
| #!/bin/bash | ||||
| # Copyright (c) 2024, NVIDIA CORPORATION. | ||||
|
|
||||
| set -euo pipefail | ||||
|
|
||||
| rapids-logger "Create test conda environment" | ||||
| . /opt/conda/etc/profile.d/conda.sh | ||||
|
|
||||
| RAPIDS_VERSION="$(rapids-version)" | ||||
|
|
||||
| rapids-dependency-file-generator \ | ||||
| --output conda \ | ||||
| --file-key go \ | ||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you know how I could add this file key to
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a key under Line 70 in 710e9f5
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More explanation in the README here: https://github.com/rapidsai/dependency-file-generator |
||||
| --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml | ||||
|
|
||||
| rapids-mamba-retry env create --yes -f env.yaml -n go | ||||
|
|
||||
| # seeing failures on activating the environment here on unbound locals | ||||
| # apply workaround from https://github.com/conda/conda/issues/8186#issuecomment-532874667 | ||||
| set +eu | ||||
| conda activate go | ||||
| set -eu | ||||
|
|
||||
| rapids-print-env | ||||
|
|
||||
| # Point cgo at the activated conda environment rather than a developer-specific home directory, | ||||
| # so the headers and libraries installed below are found on any machine. | ||||
| export CGO_CFLAGS="-I/usr/local/cuda/include -I${CONDA_PREFIX}/include" | ||||
| export CGO_LDFLAGS="-L/usr/local/cuda/lib64 -L${CONDA_PREFIX}/lib -lcudart -lcuvs -lcuvs_c" | ||||
|
|
||||
| rapids-logger "Downloading artifacts from previous jobs" | ||||
| CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) | ||||
|
|
||||
| # Install the C/C++ libraries the Go bindings link against. | ||||
| # libcuvs/libraft are pinned to RAPIDS_VERSION so the solver picks the artifacts from CPP_CHANNEL | ||||
| # instead of an older release from the default channels. | ||||
| rapids-mamba-retry install \ | ||||
| --channel "${CPP_CHANNEL}" \ | ||||
| "libcuvs=${RAPIDS_VERSION}" \ | ||||
| "libraft=${RAPIDS_VERSION}" \ | ||||
| cuvs | ||||
|
|
||||
| bash ./build.sh go
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| package brute_force | ||
|
|
||
| // #include <cuvs/neighbors/brute_force.h> | ||
| import "C" | ||
|
|
||
| import ( | ||
| "errors" | ||
| "unsafe" | ||
|
|
||
| cuvs "github.com/rapidsai/cuvs/go" | ||
| ) | ||
|
|
||
| // bruteForceIndex pairs a cuVS brute-force index handle with a flag recording whether it has been built. | ||
| type bruteForceIndex struct { | ||
| // index is the C handle filled in by cuvsBruteForceIndexCreate. | ||
| index C.cuvsBruteForceIndex_t | ||
| // trained is set by BuildIndex after a successful build; SearchIndex returns an error while it is false. | ||
| trained bool | ||
| } | ||
|
|
||
| // CreateIndex allocates a brute-force index handle with cuvsBruteForceIndexCreate. | ||
| // The returned index is not yet built; it returns a nil index and the cuVS error on failure. | ||
| func CreateIndex() (*bruteForceIndex, error) { | ||
| var handle C.cuvsBruteForceIndex_t | ||
| if err := cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsBruteForceIndexCreate(&handle))); err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| // trained stays at its zero value (false) until the index is built. | ||
| return &bruteForceIndex{index: handle}, nil | ||
| } | ||
|
|
||
| // Close destroys the index handle with cuvsBruteForceIndexDestroy and returns any cuVS error. | ||
| func (index *bruteForceIndex) Close() error { | ||
| status := C.cuvsBruteForceIndexDestroy(index.index) | ||
| return cuvs.CheckCuvs(cuvs.CuvsError(status)) | ||
| } | ||
|
|
||
| // BuildIndex builds index from Dataset with the given distance metric and marks it as trained on success. | ||
| // | ||
| // metric must be a key of cuvs.CDistances; metric_arg is forwarded unchanged to cuvsBruteForceBuild. | ||
| // It returns an error, without calling into cuVS, when index or Dataset is nil or the metric is unknown. | ||
| func BuildIndex[T any](Resources cuvs.Resource, Dataset *cuvs.Tensor[T], metric cuvs.Distance, metric_arg float32, index *bruteForceIndex) error { | ||
| // Reject nil pointers up front: they would otherwise panic when dereferenced below. | ||
| if index == nil { | ||
| return errors.New("cuvs: index is nil") | ||
| } | ||
| if Dataset == nil { | ||
| return errors.New("cuvs: dataset is nil") | ||
| } | ||
|
|
||
| cMetric, exists := cuvs.CDistances[metric] | ||
| if !exists { | ||
| return errors.New("cuvs: invalid distance metric") | ||
| } | ||
|
|
||
| status := C.cuvsBruteForceBuild(C.cuvsResources_t(Resources.Resource), (*C.DLManagedTensor)(unsafe.Pointer(Dataset.C_tensor)), C.cuvsDistanceType(cMetric), C.float(metric_arg), index.index) | ||
| if err := cuvs.CheckCuvs(cuvs.CuvsError(status)); err != nil { | ||
| return err | ||
| } | ||
|
|
||
| // Only a successful build makes the index searchable. | ||
| index.trained = true | ||
| return nil | ||
| } | ||
|
|
||
| // SearchIndex runs cuvsBruteForceSearch for queries against a built index, with no prefilter, | ||
| // passing neighbors and distances as the output tensors. | ||
| // | ||
| // It returns an error, without calling into cuVS, if the index has not been built or any tensor is nil. | ||
| func SearchIndex[T any](resources cuvs.Resource, index bruteForceIndex, queries *cuvs.Tensor[T], neighbors *cuvs.Tensor[int64], distances *cuvs.Tensor[T]) error { | ||
| if !index.trained { | ||
| return errors.New("index needs to be built before calling search") | ||
| } | ||
| // Nil tensors would panic on dereference below; report them as an error instead. | ||
| if queries == nil || neighbors == nil || distances == nil { | ||
| return errors.New("cuvs: queries, neighbors and distances tensors must not be nil") | ||
| } | ||
|
|
||
| prefilter := C.cuvsFilter{ | ||
| addr: 0, | ||
| _type: C.NO_FILTER, | ||
| } | ||
|
|
||
| // Convert the resource handle with C.cuvsResources_t, matching BuildIndex, rather than assuming it is a C.ulong. | ||
| status := C.cuvsBruteForceSearch(C.cuvsResources_t(resources.Resource), index.index, (*C.DLManagedTensor)(unsafe.Pointer(queries.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(neighbors.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(distances.C_tensor)), prefilter) | ||
| return cuvs.CheckCuvs(cuvs.CuvsError(status)) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| package brute_force | ||
|
|
||
| import ( | ||
| "math/rand/v2" | ||
| "testing" | ||
|
|
||
| cuvs "github.com/rapidsai/cuvs/go" | ||
| ) | ||
|
|
||
| // TestBruteForce builds a brute-force index over random data, searches it with the first nQueries | ||
| // dataset rows, and checks that each query gets itself back as nearest neighbor at near-zero distance. | ||
| func TestBruteForce(t *testing.T) { | ||
| const ( | ||
| nDataPoints = 1024 | ||
| nFeatures = 16 | ||
| nQueries = 4 | ||
| k = 4 | ||
| epsilon = 0.001 | ||
| ) | ||
|
|
||
| resource, err := cuvs.NewResource(nil) | ||
| if err != nil { | ||
| t.Fatalf("error creating resource: %v", err) | ||
| } | ||
| defer resource.Close() | ||
|
|
||
| testDataset := make([][]float32, nDataPoints) | ||
| for i := range testDataset { | ||
| testDataset[i] = make([]float32, nFeatures) | ||
| for j := range testDataset[i] { | ||
| testDataset[i][j] = rand.Float32() | ||
| } | ||
| } | ||
|
|
||
| dataset, err := cuvs.NewTensor(testDataset) | ||
| if err != nil { | ||
| t.Fatalf("error creating dataset tensor: %v", err) | ||
| } | ||
| defer dataset.Close() | ||
|
|
||
| // Check the error before deferring Close: a failed CreateIndex returns a nil index. | ||
| index, err := CreateIndex() | ||
| if err != nil { | ||
| t.Fatalf("error creating index: %v", err) | ||
| } | ||
| defer index.Close() | ||
|
|
||
| // use the first 4 points from the dataset as queries : will test that we get them back | ||
| // as their own nearest neighbor | ||
| queries, err := cuvs.NewTensor(testDataset[:nQueries]) | ||
| if err != nil { | ||
| t.Fatalf("error creating queries tensor: %v", err) | ||
| } | ||
| defer queries.Close() | ||
|
|
||
| neighbors, err := cuvs.NewTensorOnDevice[int64](&resource, []int64{int64(nQueries), int64(k)}) | ||
| if err != nil { | ||
| t.Fatalf("error creating neighbors tensor: %v", err) | ||
| } | ||
| defer neighbors.Close() | ||
|
|
||
| distances, err := cuvs.NewTensorOnDevice[float32](&resource, []int64{int64(nQueries), int64(k)}) | ||
| if err != nil { | ||
| t.Fatalf("error creating distances tensor: %v", err) | ||
| } | ||
| defer distances.Close() | ||
|
|
||
| if _, err := dataset.ToDevice(&resource); err != nil { | ||
| t.Fatalf("error moving dataset to device: %v", err) | ||
| } | ||
|
|
||
| if err := BuildIndex(resource, &dataset, cuvs.DistanceL2, 2.0, index); err != nil { | ||
| t.Fatalf("error building index: %v", err) | ||
| } | ||
|
|
||
| if err := resource.Sync(); err != nil { | ||
| t.Fatalf("error syncing resource: %v", err) | ||
| } | ||
|
|
||
| if _, err := queries.ToDevice(&resource); err != nil { | ||
| t.Fatalf("error moving queries to device: %v", err) | ||
| } | ||
|
|
||
| if err := SearchIndex(resource, *index, &queries, &neighbors, &distances); err != nil { | ||
| t.Fatalf("error searching index: %v", err) | ||
| } | ||
|
|
||
| if _, err := neighbors.ToHost(&resource); err != nil { | ||
| t.Fatalf("error moving neighbors to host: %v", err) | ||
| } | ||
|
|
||
| if _, err := distances.ToHost(&resource); err != nil { | ||
| t.Fatalf("error moving distances to host: %v", err) | ||
| } | ||
|
|
||
| if err := resource.Sync(); err != nil { | ||
| t.Fatalf("error syncing resource: %v", err) | ||
| } | ||
|
|
||
| neighborsSlice, err := neighbors.Slice() | ||
| if err != nil { | ||
| t.Fatalf("error getting neighbors slice: %v", err) | ||
| } | ||
|
|
||
| for i := range neighborsSlice { | ||
| if neighborsSlice[i][0] != int64(i) { | ||
| t.Error("wrong neighbor, expected", i, "got", neighborsSlice[i][0]) | ||
| } | ||
| } | ||
|
|
||
| distancesSlice, err := distances.Slice() | ||
| if err != nil { | ||
| t.Fatalf("error getting distances slice: %v", err) | ||
| } | ||
|
|
||
| for i := range distancesSlice { | ||
| if distancesSlice[i][0] >= epsilon || distancesSlice[i][0] <= -epsilon { | ||
| t.Error("distance should be close to 0, got", distancesSlice[i][0]) | ||
| } | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| package cagra | ||
|
|
||
| // #include <cuvs/neighbors/cagra.h> | ||
| import "C" | ||
|
|
||
| import ( | ||
| "errors" | ||
| "unsafe" | ||
|
|
||
| cuvs "github.com/rapidsai/cuvs/go" | ||
| ) | ||
|
|
||
| // CagraIndex pairs a cuVS CAGRA index handle with a flag recording whether it has been built. | ||
| type CagraIndex struct { | ||
| // index is the C handle filled in by cuvsCagraIndexCreate. | ||
| index C.cuvsCagraIndex_t | ||
| // trained is set by BuildIndex after a successful build; ExtendIndex and SearchIndex return an error while it is false. | ||
| trained bool | ||
| } | ||
|
|
||
| // CreateIndex allocates a CAGRA index handle with cuvsCagraIndexCreate. | ||
| // The returned index is not yet built; it returns a nil index and the cuVS error on failure. | ||
| func CreateIndex() (*CagraIndex, error) { | ||
| var handle C.cuvsCagraIndex_t | ||
| if err := cuvs.CheckCuvs(cuvs.CuvsError(C.cuvsCagraIndexCreate(&handle))); err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| created := &CagraIndex{index: handle} | ||
| return created, nil | ||
| } | ||
|
|
||
| // BuildIndex builds index from dataset with cuvsCagraBuild using params, and marks it as trained on success. | ||
| // | ||
| // It returns an error, without calling into cuVS, when params, dataset or index is nil. | ||
| func BuildIndex[T any](Resources cuvs.Resource, params *IndexParams, dataset *cuvs.Tensor[T], index *CagraIndex) error { | ||
| // Nil pointers would panic on dereference below; report them as an error instead. | ||
| if params == nil || dataset == nil || index == nil { | ||
| return errors.New("cuvs: params, dataset and index must not be nil") | ||
| } | ||
| // Convert the resource handle with C.cuvsResources_t, matching SearchIndex, rather than assuming it is a C.ulong. | ||
| status := C.cuvsCagraBuild(C.cuvsResources_t(Resources.Resource), params.params, (*C.DLManagedTensor)(unsafe.Pointer(dataset.C_tensor)), index.index) | ||
| if err := cuvs.CheckCuvs(cuvs.CuvsError(status)); err != nil { | ||
| return err | ||
| } | ||
| // Only a successful build makes the index usable by ExtendIndex and SearchIndex. | ||
| index.trained = true | ||
| return nil | ||
| } | ||
|
|
||
| // ExtendIndex calls cuvsCagraExtend on an already built index, passing additional_dataset and return_dataset through. | ||
| // | ||
| // It returns an error, without calling into cuVS, when any pointer argument is nil or the index has not been built. | ||
| func ExtendIndex[T any](Resources cuvs.Resource, params *ExtendParams, additional_dataset *cuvs.Tensor[T], return_dataset *cuvs.Tensor[T], index *CagraIndex) error { | ||
| // Nil pointers would panic on dereference below; report them as an error instead. | ||
| if params == nil || additional_dataset == nil || return_dataset == nil || index == nil { | ||
| return errors.New("cuvs: params, datasets and index must not be nil") | ||
| } | ||
| if !index.trained { | ||
| return errors.New("index needs to be built before calling extend") | ||
| } | ||
| // Convert the resource handle with C.cuvsResources_t, matching SearchIndex, rather than assuming it is a C.ulong. | ||
| status := C.cuvsCagraExtend(C.cuvsResources_t(Resources.Resource), params.params, (*C.DLManagedTensor)(unsafe.Pointer(additional_dataset.C_tensor)), (*C.DLManagedTensor)(unsafe.Pointer(return_dataset.C_tensor)), index.index) | ||
| return cuvs.CheckCuvs(cuvs.CuvsError(status)) | ||
| } | ||
|
|
||
| // Close destroys the index handle with cuvsCagraIndexDestroy and returns any cuVS error. | ||
| func (index *CagraIndex) Close() error { | ||
| status := C.cuvsCagraIndexDestroy(index.index) | ||
| return cuvs.CheckCuvs(cuvs.CuvsError(status)) | ||
| } | ||
|
|
||
| // SearchIndex runs cuvsCagraSearch for queries against a built index using params, | ||
| // passing neighbors and distances as the output tensors. | ||
| // It returns an error without calling into cuVS if the index has not been built. | ||
| func SearchIndex[T any](Resources cuvs.Resource, params *SearchParams, index *CagraIndex, queries *cuvs.Tensor[T], neighbors *cuvs.Tensor[uint32], distances *cuvs.Tensor[T]) error { | ||
| if built := index.trained; !built { | ||
| return errors.New("index needs to be built before calling search") | ||
| } | ||
|
|
||
| // Arguments are prepared in the same left-to-right order as the C call uses them. | ||
| res := C.cuvsResources_t(Resources.Resource) | ||
| cParams := params.params | ||
| cIndex := index.index | ||
| cQueries := (*C.DLManagedTensor)(unsafe.Pointer(queries.C_tensor)) | ||
| cNeighbors := (*C.DLManagedTensor)(unsafe.Pointer(neighbors.C_tensor)) | ||
| cDistances := (*C.DLManagedTensor)(unsafe.Pointer(distances.C_tensor)) | ||
|
|
||
| status := C.cuvsCagraSearch(res, cParams, cIndex, cQueries, cNeighbors, cDistances) | ||
| return cuvs.CheckCuvs(cuvs.CuvsError(status)) | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@benfred do you know how to install/find dlpack?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rust and java bindings both install through cmake - but it might be easier to add to the dependencies.yaml to add to the conda environment. I tried that out in the last commit - hopefully this will resolve
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like we got past the dlpack error - but now I'm seeing
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like it installs the previous release from the conda channels, which doesn't yet have some of the features this relies on. Is there any way to do this through conda, or would cmake be better? @benfred
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm - it should be grabbing the built cpp artifacts from S3 from the conda cpp build, but it looks like it's installing 24.12 libcuvs instead:
(from the build-go GHA log).
fwiw I just checked the rust build, which is set up very similarly - and it seems to be pulling the right libcuvs package.
I'm not quite sure why this is going wrong here - @bdice or @vyasr do you have any ideas?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the difference between the build_go.sh and build_rust.sh build scripts, I think it might just be that the rust build is specifying the rapids version.
Going to try updating the go build to match.